From aee752b681259e26ccf37308654805b610e849e8 Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Fri, 12 Jun 2026 14:01:55 -0700 Subject: [PATCH 1/6] Decode prompt token trace payloads Hydrate prompt_token_ids from Fireworks tracing payloads so RemoteRolloutProcessor can pass token-native prompt IDs through assistant turn metadata. Co-authored-by: Cursor --- eval_protocol/adapters/fireworks_tracing.py | 16 ++ eval_protocol/adapters/pti_deserializer.py | 98 ++++++++ eval_protocol/pytest/tracing_utils.py | 2 + .../test_remote_rollout_prompt_token_ids.py | 209 ++++++++++++++++++ ...test_fireworks_tracing_prompt_token_ids.py | 73 ++++++ tests/pytest/test_tracing_utils.py | 4 + 6 files changed, 402 insertions(+) create mode 100644 eval_protocol/adapters/pti_deserializer.py create mode 100644 scripts/test_remote_rollout_prompt_token_ids.py create mode 100644 tests/adapters/test_fireworks_tracing_prompt_token_ids.py diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 62a632e6..380dc876 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -17,6 +17,7 @@ from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message from .base import BaseAdapter from .lp_deserializer import decompress_and_parse_lp +from .pti_deserializer import decompress_and_parse_pti from .r3_deserializer import decompress_and_parse_r3 from .utils import extract_messages_from_data from ..common_utils import get_user_agent @@ -142,6 +143,21 @@ def convert_trace_dict_to_evaluation_row( e, ) + prompt_ids_payload = payloads.get("prompt_token_ids") + if isinstance(prompt_ids_payload, dict) and prompt_ids_payload.get("data"): + try: + prompt_token_ids, pti_meta = decompress_and_parse_pti(prompt_ids_payload["data"]) + if execution_metadata.extra is None: + execution_metadata.extra = {} + execution_metadata.extra["prompt_token_ids"] = prompt_token_ids + 
execution_metadata.extra["prompt_token_ids_metadata"] = pti_meta + except Exception as e: + logger.warning( + "Failed to decompress prompt token IDs payload for trace %s: %s", + trace.get("id"), + e, + ) + return EvaluationRow( messages=messages, tools=tools, diff --git a/eval_protocol/adapters/pti_deserializer.py b/eval_protocol/adapters/pti_deserializer.py new file mode 100644 index 00000000..cae74cc0 --- /dev/null +++ b/eval_protocol/adapters/pti_deserializer.py @@ -0,0 +1,98 @@ +"""PTI/v1 binary deserializer for prompt token ID payloads. + +Implements the inverse of the tracing gateway's +``prompt_token_ids_serializer.serialize_prompt_token_ids``. +""" + +from __future__ import annotations + +import base64 +import struct +from typing import Any, Dict, List, Tuple + +import zstandard as zstd + +MAGIC = b"PTI1" +HEADER_VERSION = 1 +ENTRY_FORMAT = " Dict[str, Any]: + if len(raw) < HEADER_SIZE: + raise ValueError(f"Payload too short for PTI/v1 header: {len(raw)} < {HEADER_SIZE}") + + ( + magic, + version, + flags, + reserved_u16, + token_count, + body_byte_length, + reserved_u64, + ) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE]) + + if magic != MAGIC: + raise ValueError(f"Bad PTI/v1 magic: {magic!r}") + if version != HEADER_VERSION: + raise ValueError(f"Unsupported PTI/v1 header version: {version}") + + return { + "flags": flags, + "reserved_u16": reserved_u16, + "token_count": token_count, + "body_byte_length": body_byte_length, + "reserved_u64": reserved_u64, + } + + +def parse_prompt_token_ids(raw: bytes) -> Tuple[List[int], Dict[str, Any]]: + """Parse uncompressed PTI/v1 bytes into prompt token IDs and metadata.""" + header = _parse_header(raw) + token_count = header["token_count"] + body_byte_length = header["body_byte_length"] + + if token_count == 0: + raise ValueError("PTI/v1 token_count must be > 0") + if body_byte_length != token_count * ENTRY_SIZE: + raise ValueError( + f"body_byte_length ({body_byte_length}) != token_count * {ENTRY_SIZE} " + 
f"({token_count * ENTRY_SIZE})" + ) + + expected_len = HEADER_SIZE + body_byte_length + if len(raw) != expected_len: + raise ValueError(f"PTI/v1 payload length mismatch: {len(raw)} != {expected_len}") + + token_ids: List[int] = [] + offset = HEADER_SIZE + for _ in range(token_count): + (token_id,) = struct.unpack(ENTRY_FORMAT, raw[offset : offset + ENTRY_SIZE]) + offset += ENTRY_SIZE + token_ids.append(token_id) + + metadata: Dict[str, Any] = { + "scope": "prompt_only", + "token_count": token_count, + } + header.update(metadata) + return token_ids, header + + +def decompress_and_parse_pti(data_b64: str) -> Tuple[List[int], Dict[str, Any]]: + """Decompress and unpack a PTI/v1 prompt token ID payload. + + Args: + data_b64: Base64-encoded zstd-compressed PTI binary blob from + ``payloads.prompt_token_ids.data``. + + Returns: + ``(token_ids, metadata)`` where ``token_ids`` is the prompt token ID + sequence and ``metadata`` includes ``token_count``. + """ + compressed = base64.b64decode(data_b64) + decompressor = zstd.ZstdDecompressor() + raw = decompressor.decompress(compressed) + return parse_prompt_token_ids(raw) diff --git a/eval_protocol/pytest/tracing_utils.py b/eval_protocol/pytest/tracing_utils.py index 279d1055..e0d5db73 100644 --- a/eval_protocol/pytest/tracing_utils.py +++ b/eval_protocol/pytest/tracing_utils.py @@ -63,6 +63,8 @@ def _merge_payloads_into_longest_row(longest_row: EvaluationRow, rows: List[Eval for key in ( "completion_logprobs", "completion_token_ids", + "prompt_token_ids", + "prompt_token_ids_metadata", "logprobs_metadata", "routing_matrices", "routing_metadata", diff --git a/scripts/test_remote_rollout_prompt_token_ids.py b/scripts/test_remote_rollout_prompt_token_ids.py new file mode 100644 index 00000000..a143772f --- /dev/null +++ b/scripts/test_remote_rollout_prompt_token_ids.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +"""E2E check: RemoteRolloutProcessor reads prompt_token_ids trace payloads. 
+ +This starts a tiny local `/init` server, sends one chat completion through the +Fireworks tracing gateway with `return_token_ids`, and verifies that +RemoteRolloutProcessor hydrates `assistant_turn_payloads[*].prompt_token_ids`. +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import os +import sys +import socket +import threading +import time +from pathlib import Path +from typing import Any + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +import uvicorn +from fastapi import FastAPI +from openai import OpenAI + +from eval_protocol import FireworksTracingHttpHandler, InitRequest, RolloutIdFilter, Status +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig + +logger = logging.getLogger("remote_rollout_prompt_token_ids") +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def _message_to_dict(message: Message | dict[str, Any]) -> dict[str, Any]: + if isinstance(message, Message): + return message.dump_mdoel_for_chat_completion_request() + return {k: v for k, v in dict(message).items() if v is not None} + + +def _make_app(gateway_url: str) -> FastAPI: + app = FastAPI() + app_logger = logging.getLogger(f"{__name__}.server") + app_logger.setLevel(logging.INFO) + + @app.get("/") + def health() -> dict[str, str]: + return {"status": "ok"} + + @app.post("/init") + def init(req: InitRequest) -> dict[str, str]: + rollout_logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}") + rollout_logger.addFilter(RolloutIdFilter(req.metadata.rollout_id)) + if not any(isinstance(handler, FireworksTracingHttpHandler) for handler in rollout_logger.handlers): + 
rollout_logger.addHandler(FireworksTracingHttpHandler(gateway_base_url=gateway_url)) + rollout_logger.setLevel(logging.INFO) + + def _worker() -> None: + try: + conversation = [_message_to_dict(message) for message in (req.messages or [])] + params = dict(req.completion_params or {}) + params.pop("base_url", None) + params["extra_body"] = { + **dict(params.get("extra_body") or {}), + "return_token_ids": True, + } + params.setdefault("temperature", 0) + params.setdefault("max_tokens", 8) + + if not req.model_base_url: + raise ValueError("model_base_url is required") + if not params.get("model"): + raise ValueError("completion_params.model is required") + + client = OpenAI(base_url=req.model_base_url, api_key=req.api_key) + response = client.chat.completions.create(messages=conversation, **params) + content = response.choices[0].message.content or "" + logger.info("remote server generated content=%r", content) + + rollout_logger.info( + "rollout %s finished", + req.metadata.rollout_id, + extra={"status": Status.rollout_finished()}, + ) + except Exception as exc: + rollout_logger.exception( + "rollout %s failed", + req.metadata.rollout_id, + extra={"status": Status.rollout_unknown_error(str(exc))}, + ) + + threading.Thread(target=_worker, daemon=True).start() + return {"status": "started"} + + return app + + +def _wait_ready(url: str, timeout_seconds: float = 30.0) -> None: + import requests + + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url, timeout=2) + if resp.status_code == 200: + return + except Exception: + pass + time.sleep(0.2) + raise TimeoutError(f"server not ready: {url}") + + +async def _run(args: argparse.Namespace) -> None: + api_key = args.api_key or os.getenv("FIREWORKS_DEV_API_KEY") or os.getenv("FIREWORKS_API_KEY") + if not api_key: + raise ValueError("Set FIREWORKS_DEV_API_KEY or FIREWORKS_API_KEY") + + # FireworksTracingHttpHandler reads FIREWORKS_API_KEY. 
+ os.environ["FIREWORKS_API_KEY"] = api_key + os.environ["EP_REMOTE_API_KEY"] = api_key + + port = args.port or _free_port() + remote_base_url = f"http://127.0.0.1:{port}" + app = _make_app(args.gateway_url) + config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning") + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + _wait_ready(f"{remote_base_url}/") + + rollout_id = f"rrp-prompt-ids-{int(time.time())}" + row = EvaluationRow( + messages=[Message(role="user", content="Reply with exactly: ok")], + ) + row.input_metadata.row_id = "row-0" + row.input_metadata.completion_params = { + "model": args.model, + "base_url": args.api_base_url, + "temperature": 0, + "max_tokens": 8, + } + row.execution_metadata.rollout_id = rollout_id + row.execution_metadata.invocation_id = "inv-0" + row.execution_metadata.experiment_id = "fir2-1747-rrp-e2e" + row.execution_metadata.run_id = "run-0" + + processor = RemoteRolloutProcessor( + remote_base_url=remote_base_url, + model_base_url=args.gateway_url, + include_payloads=True, + timeout_seconds=args.timeout_seconds, + poll_interval=args.poll_interval, + ) + try: + task = processor( + [row], + RolloutProcessorConfig( + completion_params=row.input_metadata.completion_params, + mcp_config_path="", + semaphore=asyncio.Semaphore(1), + steps=1, + ), + )[0] + completed = await task + finally: + await processor.acleanup() + server.should_exit = True + thread.join(timeout=5) + + extra = completed.execution_metadata.extra or {} + turn_payloads = extra.get("assistant_turn_payloads") or [] + prompt_ids = None + if turn_payloads: + prompt_ids = turn_payloads[0].get("prompt_token_ids") + if prompt_ids is None: + prompt_ids = extra.get("prompt_token_ids") + + print(f"rollout_id={rollout_id}") + print(f"messages={len(completed.messages)}") + print(f"assistant_turn_payloads={turn_payloads}") + print(f"prompt_token_ids_len={len(prompt_ids) if isinstance(prompt_ids, list) 
else None}") + print(f"prompt_token_ids_head={prompt_ids[:8] if isinstance(prompt_ids, list) else None}") + + if not isinstance(prompt_ids, list) or not prompt_ids: + raise AssertionError("RemoteRolloutProcessor did not hydrate prompt_token_ids") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--gateway-url", default=os.getenv("EP_MODEL_BASE_URL", "https://litellm-gateway-dev-j4kzagdteq-uc.a.run.app")) + parser.add_argument("--api-base-url", default=os.getenv("FIREWORKS_API_BASE_URL", "https://dev.api.fireworks.ai/inference/v1")) + parser.add_argument("--model", default=os.getenv("TRACING_E2E_MODEL", "accounts/pyroworks-dev/deployments/malaysia2-intended-butterfly")) + parser.add_argument("--api-key", default=None) + parser.add_argument("--port", type=int, default=0) + parser.add_argument("--timeout-seconds", type=float, default=180.0) + parser.add_argument("--poll-interval", type=float, default=2.0) + asyncio.run(_run(parser.parse_args())) + + +if __name__ == "__main__": + main() diff --git a/tests/adapters/test_fireworks_tracing_prompt_token_ids.py b/tests/adapters/test_fireworks_tracing_prompt_token_ids.py new file mode 100644 index 00000000..869bac29 --- /dev/null +++ b/tests/adapters/test_fireworks_tracing_prompt_token_ids.py @@ -0,0 +1,73 @@ +"""Tests for prompt token ID payload handling in fireworks_tracing adapter.""" + +from __future__ import annotations + +import base64 +import struct + +import pytest +import zstandard as zstd + +pytest.importorskip("mcp") + +from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row +from eval_protocol.adapters.pti_deserializer import ( + ENTRY_FORMAT, + ENTRY_SIZE, + HEADER_FORMAT, + MAGIC, + decompress_and_parse_pti, +) + + +def _pti_b64(token_ids: list[int]) -> str: + token_count = len(token_ids) + body_byte_length = token_count * ENTRY_SIZE + header = struct.pack( + HEADER_FORMAT, + MAGIC, + 1, + 0, + 0, + token_count, + 
body_byte_length, + 0, + ) + body = b"".join(struct.pack(ENTRY_FORMAT, token_id) for token_id in token_ids) + compressed = zstd.ZstdCompressor().compress(header + body) + return base64.b64encode(compressed).decode("ascii") + + +def test_decompress_and_parse_pti_round_trip(): + token_ids, metadata = decompress_and_parse_pti(_pti_b64([101, 102, 103])) + + assert token_ids == [101, 102, 103] + assert metadata["scope"] == "prompt_only" + assert metadata["token_count"] == 3 + + +def test_trace_adapter_attaches_prompt_token_ids_metadata(): + trace = { + "id": "trace-pti", + "input": { + "messages": [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ], + }, + "output": {"role": "assistant", "content": "hello"}, + "payloads": { + "prompt_token_ids": { + "data": _pti_b64([201, 202, 203]), + "manifest": {"PayloadVersion": "pti/v1"}, + }, + }, + } + + row = convert_trace_dict_to_evaluation_row(trace) + + assert row is not None + extra = row.execution_metadata.extra + assert extra is not None + assert extra["prompt_token_ids"] == [201, 202, 203] + assert extra["prompt_token_ids_metadata"]["token_count"] == 3 diff --git a/tests/pytest/test_tracing_utils.py b/tests/pytest/test_tracing_utils.py index 58ec55c1..98246ab1 100644 --- a/tests/pytest/test_tracing_utils.py +++ b/tests/pytest/test_tracing_utils.py @@ -13,6 +13,7 @@ def test_merge_payloads_into_longest_row_preserves_each_assistant_turn(): execution_metadata=ExecutionMetadata( extra={ "completion_logprobs": [-0.1, -0.2], + "prompt_token_ids": [101, 102], "routing_matrices": ["first-matrix"], "routing_metadata": {"total_token_count": 1}, }, @@ -28,6 +29,7 @@ def test_merge_payloads_into_longest_row_preserves_each_assistant_turn(): execution_metadata=ExecutionMetadata( extra={ "completion_logprobs": [-0.3], + "prompt_token_ids": [101, 102, 103, 104], "routing_matrices": ["second-matrix"], "routing_metadata": {"total_token_count": 1}, }, @@ -45,12 +47,14 @@ def 
test_merge_payloads_into_longest_row_preserves_each_assistant_turn(): { "assistant_turn_index": 0, "completion_logprobs": [-0.1, -0.2], + "prompt_token_ids": [101, 102], "routing_matrices": ["first-matrix"], "routing_metadata": {"total_token_count": 1}, }, { "assistant_turn_index": 1, "completion_logprobs": [-0.3], + "prompt_token_ids": [101, 102, 103, 104], "routing_matrices": ["second-matrix"], "routing_metadata": {"total_token_count": 1}, }, From 97ffb52165a3199e72dc20d6a95effac52debfc5 Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Mon, 15 Jun 2026 17:25:46 -0700 Subject: [PATCH 2/6] Extract gateway payload decoders into standalone eval_protocol.tracing Move the per-payload binary deserializers out of adapters/ into a dependency-light eval_protocol/tracing package: a PayloadType StrEnum, DecodedPayload, and a decode_payloads registry (master decode) usable without EvaluationRow/rollout machinery. Refactor FireworksTracingAdapter to use the registry instead of three copy-pasted decode blocks, and decode pti/v1 as zstd(JSON int array) to match the gateway. Adds tests/tracing and a README. 
Co-authored-by: Cursor --- eval_protocol/adapters/fireworks_tracing.py | 93 ++++++------- eval_protocol/adapters/pti_deserializer.py | 98 -------------- eval_protocol/tracing/README.md | 109 +++++++++++++++ eval_protocol/tracing/__init__.py | 43 ++++++ eval_protocol/tracing/_decompress.py | 17 +++ .../logprobs.py} | 17 +++ eval_protocol/tracing/prompt_token_ids.py | 31 +++++ eval_protocol/tracing/registry.py | 88 +++++++++++++ .../router_replay.py} | 15 +++ eval_protocol/tracing/types.py | 40 ++++++ .../test_fireworks_tracing_logprobs.py | 2 +- ...test_fireworks_tracing_prompt_token_ids.py | 38 ++---- tests/adapters/test_lp_deserializer.py | 2 +- tests/adapters/test_r3_deserializer.py | 2 +- tests/tracing/test_registry.py | 124 ++++++++++++++++++ 15 files changed, 535 insertions(+), 184 deletions(-) delete mode 100644 eval_protocol/adapters/pti_deserializer.py create mode 100644 eval_protocol/tracing/README.md create mode 100644 eval_protocol/tracing/__init__.py create mode 100644 eval_protocol/tracing/_decompress.py rename eval_protocol/{adapters/lp_deserializer.py => tracing/logprobs.py} (86%) create mode 100644 eval_protocol/tracing/prompt_token_ids.py create mode 100644 eval_protocol/tracing/registry.py rename eval_protocol/{adapters/r3_deserializer.py => tracing/router_replay.py} (93%) create mode 100644 eval_protocol/tracing/types.py create mode 100644 tests/tracing/test_registry.py diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 380dc876..89a2bc94 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -15,10 +15,8 @@ import os from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message +from eval_protocol.tracing import PayloadType, decode_payloads from .base import BaseAdapter -from .lp_deserializer import decompress_and_parse_lp -from .pti_deserializer import decompress_and_parse_pti -from .r3_deserializer import 
decompress_and_parse_r3 from .utils import extract_messages_from_data from ..common_utils import get_user_agent @@ -103,60 +101,45 @@ def convert_trace_dict_to_evaluation_row( ): break # Break early if we've found all the metadata we need - # Extract router replay payloads when present + # Decode out-of-band gateway payloads (router replay, logprobs, prompt + # token ids) via the standalone tracing decoder registry, then map the + # decoded values onto the row. Format/decoding lives in + # ``eval_protocol.tracing``; this adapter only does EvaluationRow glue. payloads = trace.get("payloads") if isinstance(payloads, dict): - router_replay = payloads.get("router_replay") - if isinstance(router_replay, dict) and router_replay.get("data"): - try: - matrices, r3_meta = decompress_and_parse_r3(router_replay["data"]) - if execution_metadata.extra is None: - execution_metadata.extra = {} - execution_metadata.extra["routing_matrices"] = matrices - execution_metadata.extra["routing_metadata"] = r3_meta - except Exception as e: - logger.warning("Failed to decompress R3 payload for trace %s: %s", trace.get("id"), e) - - logprobs_payload = payloads.get("logprobs") - if isinstance(logprobs_payload, dict) and logprobs_payload.get("data"): - try: - logprobs, token_ids, lp_meta = decompress_and_parse_lp(logprobs_payload["data"]) - if execution_metadata.extra is None: - execution_metadata.extra = {} - execution_metadata.extra["completion_logprobs"] = logprobs - if token_ids is not None: - execution_metadata.extra["completion_token_ids"] = token_ids - execution_metadata.extra["logprobs_metadata"] = lp_meta - - for i in range(len(messages) - 1, -1, -1): - if messages[i].role == "assistant": - content_entries = [{"logprob": lp} for lp in logprobs] - if token_ids is not None: - for entry, tid in zip(content_entries, token_ids): - entry["token_id"] = tid - messages[i].logprobs = {"content": content_entries} - break - except Exception as e: - logger.warning( - "Failed to decompress 
logprobs payload for trace %s: %s", - trace.get("id"), - e, - ) - - prompt_ids_payload = payloads.get("prompt_token_ids") - if isinstance(prompt_ids_payload, dict) and prompt_ids_payload.get("data"): - try: - prompt_token_ids, pti_meta = decompress_and_parse_pti(prompt_ids_payload["data"]) - if execution_metadata.extra is None: - execution_metadata.extra = {} - execution_metadata.extra["prompt_token_ids"] = prompt_token_ids - execution_metadata.extra["prompt_token_ids_metadata"] = pti_meta - except Exception as e: - logger.warning( - "Failed to decompress prompt token IDs payload for trace %s: %s", - trace.get("id"), - e, - ) + decoded = decode_payloads( + payloads, + on_error=lambda pt, e: logger.warning( + "Failed to decode %s payload for trace %s: %s", pt.value, trace.get("id"), e + ), + ) + if decoded and execution_metadata.extra is None: + execution_metadata.extra = {} + + if (dp := decoded.get(PayloadType.ROUTER_REPLAY)) is not None: + execution_metadata.extra["routing_matrices"] = dp.value + execution_metadata.extra["routing_metadata"] = dp.metadata + + if (dp := decoded.get(PayloadType.LOGPROBS)) is not None: + logprobs = dp.value + token_ids = dp.extras.get("token_ids") + execution_metadata.extra["completion_logprobs"] = logprobs + if token_ids is not None: + execution_metadata.extra["completion_token_ids"] = token_ids + execution_metadata.extra["logprobs_metadata"] = dp.metadata + + for i in range(len(messages) - 1, -1, -1): + if messages[i].role == "assistant": + content_entries = [{"logprob": lp} for lp in logprobs] + if token_ids is not None: + for entry, tid in zip(content_entries, token_ids): + entry["token_id"] = tid + messages[i].logprobs = {"content": content_entries} + break + + if (dp := decoded.get(PayloadType.PROMPT_TOKEN_IDS)) is not None: + execution_metadata.extra["prompt_token_ids"] = dp.value + execution_metadata.extra["prompt_token_ids_metadata"] = dp.metadata return EvaluationRow( messages=messages, diff --git 
a/eval_protocol/adapters/pti_deserializer.py b/eval_protocol/adapters/pti_deserializer.py deleted file mode 100644 index cae74cc0..00000000 --- a/eval_protocol/adapters/pti_deserializer.py +++ /dev/null @@ -1,98 +0,0 @@ -"""PTI/v1 binary deserializer for prompt token ID payloads. - -Implements the inverse of the tracing gateway's -``prompt_token_ids_serializer.serialize_prompt_token_ids``. -""" - -from __future__ import annotations - -import base64 -import struct -from typing import Any, Dict, List, Tuple - -import zstandard as zstd - -MAGIC = b"PTI1" -HEADER_VERSION = 1 -ENTRY_FORMAT = " Dict[str, Any]: - if len(raw) < HEADER_SIZE: - raise ValueError(f"Payload too short for PTI/v1 header: {len(raw)} < {HEADER_SIZE}") - - ( - magic, - version, - flags, - reserved_u16, - token_count, - body_byte_length, - reserved_u64, - ) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE]) - - if magic != MAGIC: - raise ValueError(f"Bad PTI/v1 magic: {magic!r}") - if version != HEADER_VERSION: - raise ValueError(f"Unsupported PTI/v1 header version: {version}") - - return { - "flags": flags, - "reserved_u16": reserved_u16, - "token_count": token_count, - "body_byte_length": body_byte_length, - "reserved_u64": reserved_u64, - } - - -def parse_prompt_token_ids(raw: bytes) -> Tuple[List[int], Dict[str, Any]]: - """Parse uncompressed PTI/v1 bytes into prompt token IDs and metadata.""" - header = _parse_header(raw) - token_count = header["token_count"] - body_byte_length = header["body_byte_length"] - - if token_count == 0: - raise ValueError("PTI/v1 token_count must be > 0") - if body_byte_length != token_count * ENTRY_SIZE: - raise ValueError( - f"body_byte_length ({body_byte_length}) != token_count * {ENTRY_SIZE} " - f"({token_count * ENTRY_SIZE})" - ) - - expected_len = HEADER_SIZE + body_byte_length - if len(raw) != expected_len: - raise ValueError(f"PTI/v1 payload length mismatch: {len(raw)} != {expected_len}") - - token_ids: List[int] = [] - offset = HEADER_SIZE - for _ in 
range(token_count): - (token_id,) = struct.unpack(ENTRY_FORMAT, raw[offset : offset + ENTRY_SIZE]) - offset += ENTRY_SIZE - token_ids.append(token_id) - - metadata: Dict[str, Any] = { - "scope": "prompt_only", - "token_count": token_count, - } - header.update(metadata) - return token_ids, header - - -def decompress_and_parse_pti(data_b64: str) -> Tuple[List[int], Dict[str, Any]]: - """Decompress and unpack a PTI/v1 prompt token ID payload. - - Args: - data_b64: Base64-encoded zstd-compressed PTI binary blob from - ``payloads.prompt_token_ids.data``. - - Returns: - ``(token_ids, metadata)`` where ``token_ids`` is the prompt token ID - sequence and ``metadata`` includes ``token_count``. - """ - compressed = base64.b64decode(data_b64) - decompressor = zstd.ZstdDecompressor() - raw = decompressor.decompress(compressed) - return parse_prompt_token_ids(raw) diff --git a/eval_protocol/tracing/README.md b/eval_protocol/tracing/README.md new file mode 100644 index 00000000..4d1571ae --- /dev/null +++ b/eval_protocol/tracing/README.md @@ -0,0 +1,109 @@ +# `eval_protocol.tracing` — Fireworks tracing-gateway payload decoders + +Standalone helpers for decoding the out-of-band **payloads** the Fireworks +tracing gateway stores alongside a trace (prompt token IDs, completion logprobs, +router-replay routing matrices). + +This package is intentionally self-contained: it depends only on the stdlib and +`zstandard`. It does **not** import `EvaluationRow`, rollout processors, or any +other Eval Protocol machinery, so you can use it even if you are not using EP for +rollouts — just point at it for extracting gateway payloads. + +## What is a "payload"? + +When you read a trace with payloads included: + +``` +GET {gateway}/v1/traces?rollout_id=...&include_payloads=true +``` + +each trace carries a `payloads` object like: + +```json +{ + "payloads": { + "prompt_token_ids": { + "manifest": { "PayloadVersion": "pti/v1", "...": "..." 
}, + "data": "<base64(zstd(raw_bytes))>" + }, + "logprobs": { "manifest": { "PayloadVersion": "lp/v1" }, "data": "..." }, + "router_replay": { "manifest": { "PayloadVersion": "r3/v1" }, "data": "..." } + } +} +``` + +The `data` field is `base64(zstd(raw_bytes))`. Each payload type has its own +`raw_bytes` encoding (`pti/v1` is a JSON int array; `lp/v1` and `r3/v1` are packed +binary). This package hides all of that. + +## Usage + +Decode everything at once (the common case): + +```python +from eval_protocol.tracing import decode_payloads, PayloadType + +decoded = decode_payloads(trace["payloads"]) + +if PayloadType.PROMPT_TOKEN_IDS in decoded: + token_ids = decoded[PayloadType.PROMPT_TOKEN_IDS].value # List[int] + +if PayloadType.LOGPROBS in decoded: + lp = decoded[PayloadType.LOGPROBS] + logprobs = lp.value # List[float] + token_ids = lp.extras.get("token_ids") # Optional[List[int]] + +if PayloadType.ROUTER_REPLAY in decoded: + matrices = decoded[PayloadType.ROUTER_REPLAY].value # List[Optional[str]] +``` + +If you have the whole trace dict, `decode_trace(trace)` reaches into +`trace["payloads"]` for you. + +Decode a single payload: + +```python +from eval_protocol.tracing import decode_payload, PayloadType + +dp = decode_payload(PayloadType.PROMPT_TOKEN_IDS, trace["payloads"]["prompt_token_ids"]["data"]) +dp.value # List[int] +``` + +### Error handling + +`decode_payloads` isolates per-payload failures: if one payload fails to decode, +the others are still returned. Pass `on_error=callback(payload_type, exc)` to +control logging (defaults to a warning): + +```python +decode_payloads(payloads, on_error=lambda pt, e: print(f"{pt} failed: {e}")) +``` + +## Return type + +`decode_payloads` / `decode_trace` return `Dict[PayloadType, DecodedPayload]`. 
+ +`DecodedPayload` fields: + +| field | meaning | +|----------------|-------------------------------------------------------------------| +| `payload_type` | `PayloadType` enum member | +| `value` | decoded value (type depends on `payload_type`, see below) | +| `metadata` | decoded header/manifest metadata (token counts, scope, etc.) | +| `extras` | type-specific extras (e.g. logprobs `token_ids`) | + +`value` by type: + +| `PayloadType` | `value` | notes | +|---------------------|--------------------------|----------------------------------------------| +| `PROMPT_TOKEN_IDS` | `List[int]` | prompt token ids | +| `LOGPROBS` | `List[float]` | per completion token; ids in `extras["token_ids"]` (or `None`) | +| `ROUTER_REPLAY` | `List[Optional[str]]` | per-token base64 routing matrices; `None` where absent | + +## Adding a new payload type + +1. Add a member to `PayloadType` in `types.py`. +2. Add a `decode_<name>(data_b64) -> DecodedPayload` function in a new module. +3. Register it in `PAYLOAD_DECODERS` in `registry.py`. + +`decode_payloads` picks it up automatically. diff --git a/eval_protocol/tracing/__init__.py b/eval_protocol/tracing/__init__.py new file mode 100644 index 00000000..5870c2a4 --- /dev/null +++ b/eval_protocol/tracing/__init__.py @@ -0,0 +1,43 @@ +"""Decode Fireworks tracing-gateway payloads. + +Standalone, dependency-light helpers (stdlib + ``zstandard`` only) for turning +the binary/JSON ``payloads`` returned by the Fireworks tracing gateway +(``GET /traces?include_payloads=true``) into Python values. No EvaluationRow or +rollout machinery required -- usable on its own. + +Typical use:: + + from eval_protocol.tracing import decode_payloads, PayloadType + + decoded = decode_payloads(trace["payloads"]) + decoded[PayloadType.PROMPT_TOKEN_IDS].value # List[int] + decoded[PayloadType.LOGPROBS].value # List[float] + decoded[PayloadType.ROUTER_REPLAY].value # List[Optional[str]] + +See ``README.md`` in this package for details. 
+""" + +from __future__ import annotations + +from .logprobs import decode_logprobs +from .prompt_token_ids import decode_prompt_token_ids +from .registry import ( + PAYLOAD_DECODERS, + decode_payload, + decode_payloads, + decode_trace, +) +from .router_replay import decode_router_replay +from .types import DecodedPayload, PayloadType + +__all__ = [ + "PayloadType", + "DecodedPayload", + "PAYLOAD_DECODERS", + "decode_payload", + "decode_payloads", + "decode_trace", + "decode_prompt_token_ids", + "decode_logprobs", + "decode_router_replay", +] diff --git a/eval_protocol/tracing/_decompress.py b/eval_protocol/tracing/_decompress.py new file mode 100644 index 00000000..1ab64bfd --- /dev/null +++ b/eval_protocol/tracing/_decompress.py @@ -0,0 +1,17 @@ +"""Shared helper for the Fireworks tracing-gateway payload decoders.""" + +from __future__ import annotations + +import base64 + +import zstandard as zstd + + +def decompress_b64(data_b64: str) -> bytes: + """Base64-decode then zstd-decompress a gateway ``payloads.*.data`` blob. + + The gateway stores every payload as ``base64(zstd(raw_bytes))``; this is the + common first step every decoder shares before interpreting ``raw_bytes``. 
+ """ + compressed = base64.b64decode(data_b64) + return zstd.ZstdDecompressor().decompress(compressed) diff --git a/eval_protocol/adapters/lp_deserializer.py b/eval_protocol/tracing/logprobs.py similarity index 86% rename from eval_protocol/adapters/lp_deserializer.py rename to eval_protocol/tracing/logprobs.py index 57aa4f46..f5134108 100644 --- a/eval_protocol/adapters/lp_deserializer.py +++ b/eval_protocol/tracing/logprobs.py @@ -12,6 +12,8 @@ import zstandard as zstd +from .types import DecodedPayload, PayloadType + MAGIC = b"LP01" HEADER_VERSION = 1 MISSING_TOKEN_ID = -1 @@ -107,3 +109,18 @@ def decompress_and_parse_lp(data_b64: str) -> Tuple[List[float], Optional[List[i decompressor = zstd.ZstdDecompressor() raw = decompressor.decompress(compressed) return parse_logprobs(raw) + + +def decode_logprobs(data_b64: str) -> DecodedPayload: + """Decode a gateway ``payloads.logprobs.data`` blob into a ``DecodedPayload``. + + ``value`` is the per-completion-token logprob list; per-token ids (when all + valid) are available under ``extras["token_ids"]``. + """ + logprobs, token_ids, metadata = decompress_and_parse_lp(data_b64) + return DecodedPayload( + payload_type=PayloadType.LOGPROBS, + value=logprobs, + metadata=metadata, + extras={"token_ids": token_ids}, + ) diff --git a/eval_protocol/tracing/prompt_token_ids.py b/eval_protocol/tracing/prompt_token_ids.py new file mode 100644 index 00000000..5019331a --- /dev/null +++ b/eval_protocol/tracing/prompt_token_ids.py @@ -0,0 +1,31 @@ +"""``pti/v1`` decoder for prompt token ID payloads. + +Inverse of the tracing gateway's ``serialize_prompt_token_ids``: the gateway +stores prompt token IDs as ``base64(zstd(json.dumps(token_ids)))`` -- a compact +JSON int array, no bespoke binary header. 
+""" + +from __future__ import annotations + +import json +from typing import Any, Dict, List, Tuple + +from ._decompress import decompress_b64 +from .types import DecodedPayload, PayloadType + + +def parse_prompt_token_ids(raw: bytes) -> Tuple[List[int], Dict[str, Any]]: + """Parse uncompressed ``pti/v1`` bytes (a JSON int array) into ids + metadata.""" + token_ids = json.loads(raw) + metadata: Dict[str, Any] = {"scope": "prompt_only", "token_count": len(token_ids)} + return token_ids, metadata + + +def decode_prompt_token_ids(data_b64: str) -> DecodedPayload: + """Decode a gateway ``payloads.prompt_token_ids.data`` blob.""" + token_ids, metadata = parse_prompt_token_ids(decompress_b64(data_b64)) + return DecodedPayload( + payload_type=PayloadType.PROMPT_TOKEN_IDS, + value=token_ids, + metadata=metadata, + ) diff --git a/eval_protocol/tracing/registry.py b/eval_protocol/tracing/registry.py new file mode 100644 index 00000000..62027376 --- /dev/null +++ b/eval_protocol/tracing/registry.py @@ -0,0 +1,88 @@ +"""Decoder registry + master decode for tracing-gateway payloads. + +Adding a new payload type is a single entry in ``PAYLOAD_DECODERS`` (plus its +decoder module). Callers use the master :func:`decode_payloads` / +:func:`decode_trace` and never stitch per-type decoders together. +""" + +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, Optional + +from .logprobs import decode_logprobs +from .prompt_token_ids import decode_prompt_token_ids +from .router_replay import decode_router_replay +from .types import DecodedPayload, PayloadType + +logger = logging.getLogger(__name__) + +# Callback invoked when a single payload fails to decode: (payload_type, exc). 
+OnError = Callable[[PayloadType, Exception], None] + +PAYLOAD_DECODERS: Dict[PayloadType, Callable[[str], DecodedPayload]] = { + PayloadType.PROMPT_TOKEN_IDS: decode_prompt_token_ids, + PayloadType.LOGPROBS: decode_logprobs, + PayloadType.ROUTER_REPLAY: decode_router_replay, +} + + +def decode_payload(payload_type: PayloadType | str, data_b64: str) -> DecodedPayload: + """Decode a single payload by type. + + ``payload_type`` accepts a ``PayloadType`` or its string value, so external + callers can pass either. + """ + ptype = PayloadType(payload_type) + decoder = PAYLOAD_DECODERS.get(ptype) + if decoder is None: + raise ValueError(f"No decoder registered for payload type: {ptype!r}") + return decoder(data_b64) + + +def decode_payloads( + payloads: Dict[str, Any], + *, + on_error: Optional[OnError] = None, +) -> Dict[PayloadType, DecodedPayload]: + """Master decode: run every registered decoder over a gateway ``payloads`` dict. + + Args: + payloads: The ``payloads`` object from a gateway trace (i.e. + ``trace["payloads"]``), mapping payload-type name -> ``{"manifest", "data"}``. + on_error: Optional callback ``(payload_type, exc)`` invoked when a present + payload fails to decode. Defaults to logging a warning. A failure in + one payload never blocks the others. + + Returns: + ``{PayloadType: DecodedPayload}`` for every payload that is present and + decodes successfully. Only known ``PayloadType`` members are considered, + so unknown payload types from a newer gateway are ignored rather than + raising. 
+ """ + if not isinstance(payloads, dict): + return {} + + decoded: Dict[PayloadType, DecodedPayload] = {} + for ptype, decoder in PAYLOAD_DECODERS.items(): + entry = payloads.get(ptype) + if not isinstance(entry, dict) or not entry.get("data"): + continue + try: + decoded[ptype] = decoder(entry["data"]) + except Exception as exc: # noqa: BLE001 - isolate per-payload failures + if on_error is not None: + on_error(ptype, exc) + else: + logger.warning("Failed to decode %s payload: %s", ptype.value, exc) + return decoded + + +def decode_trace( + trace: Dict[str, Any], + *, + on_error: Optional[OnError] = None, +) -> Dict[PayloadType, DecodedPayload]: + """Convenience wrapper around :func:`decode_payloads` for a raw trace dict.""" + payloads = trace.get("payloads") if isinstance(trace, dict) else None + return decode_payloads(payloads or {}, on_error=on_error) diff --git a/eval_protocol/adapters/r3_deserializer.py b/eval_protocol/tracing/router_replay.py similarity index 93% rename from eval_protocol/adapters/r3_deserializer.py rename to eval_protocol/tracing/router_replay.py index 1e3c1a8c..49c4be8e 100644 --- a/eval_protocol/adapters/r3_deserializer.py +++ b/eval_protocol/tracing/router_replay.py @@ -20,6 +20,8 @@ import zstandard as zstd +from .types import DecodedPayload, PayloadType + MAGIC = b"R3V1" HEADER_FORMAT = "<4sBBBBIIIIQ" HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 32 bytes @@ -185,3 +187,16 @@ def decompress_and_parse_r3( matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii") return matrices, metadata + + +def decode_router_replay(data_b64: str) -> DecodedPayload: + """Decode a gateway ``payloads.router_replay.data`` blob into a ``DecodedPayload``. + + ``value`` is the per-token ``List[Optional[str]]`` of base64 routing matrices. 
+ """ + matrices, metadata = decompress_and_parse_r3(data_b64) + return DecodedPayload( + payload_type=PayloadType.ROUTER_REPLAY, + value=matrices, + metadata=metadata, + ) diff --git a/eval_protocol/tracing/types.py b/eval_protocol/tracing/types.py new file mode 100644 index 00000000..f3760f28 --- /dev/null +++ b/eval_protocol/tracing/types.py @@ -0,0 +1,40 @@ +"""Shared types for the Fireworks tracing-gateway payload decoders.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict + + +class PayloadType(str, Enum): + """Known out-of-band trace payload types emitted by the tracing gateway. + + Canonical source of truth for payload-type names on the EP side. Mirrors the + gateway's ``rft_tracing.schemas.PayloadType`` but is defined locally so this + package has no dependency on the gateway/mono codebase. Being a ``str`` enum, + members compare and hash equal to their string value, so they can be used + directly against the gateway's string-keyed ``payloads`` JSON. + """ + + PROMPT_TOKEN_IDS = "prompt_token_ids" + LOGPROBS = "logprobs" + ROUTER_REPLAY = "router_replay" + + +@dataclass(frozen=True) +class DecodedPayload: + """A single decoded gateway payload. 
+ + ``value`` shape depends on ``payload_type``: + - ``PROMPT_TOKEN_IDS`` -> ``List[int]`` + - ``LOGPROBS`` -> ``List[float]`` (per completion token); the + optional per-token ids are in ``extras["token_ids"]`` + - ``ROUTER_REPLAY`` -> ``List[Optional[str]]`` (per-token base64 routing + matrices, ``None`` where absent) + """ + + payload_type: PayloadType + value: Any + metadata: Dict[str, Any] + extras: Dict[str, Any] = field(default_factory=dict) diff --git a/tests/adapters/test_fireworks_tracing_logprobs.py b/tests/adapters/test_fireworks_tracing_logprobs.py index 08dab60b..c38c7acd 100644 --- a/tests/adapters/test_fireworks_tracing_logprobs.py +++ b/tests/adapters/test_fireworks_tracing_logprobs.py @@ -11,7 +11,7 @@ pytest.importorskip("mcp") from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row -from eval_protocol.adapters.lp_deserializer import ( +from eval_protocol.tracing.logprobs import ( ENTRY_FORMAT, ENTRY_SIZE, HEADER_FORMAT, diff --git a/tests/adapters/test_fireworks_tracing_prompt_token_ids.py b/tests/adapters/test_fireworks_tracing_prompt_token_ids.py index 869bac29..16bebe3f 100644 --- a/tests/adapters/test_fireworks_tracing_prompt_token_ids.py +++ b/tests/adapters/test_fireworks_tracing_prompt_token_ids.py @@ -3,7 +3,7 @@ from __future__ import annotations import base64 -import struct +import json import pytest import zstandard as zstd @@ -11,39 +11,21 @@ pytest.importorskip("mcp") from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row -from eval_protocol.adapters.pti_deserializer import ( - ENTRY_FORMAT, - ENTRY_SIZE, - HEADER_FORMAT, - MAGIC, - decompress_and_parse_pti, -) +from eval_protocol.tracing.prompt_token_ids import decode_prompt_token_ids def _pti_b64(token_ids: list[int]) -> str: - token_count = len(token_ids) - body_byte_length = token_count * ENTRY_SIZE - header = struct.pack( - HEADER_FORMAT, - MAGIC, - 1, - 0, - 0, - token_count, - body_byte_length, - 0, - ) - 
body = b"".join(struct.pack(ENTRY_FORMAT, token_id) for token_id in token_ids) - compressed = zstd.ZstdCompressor().compress(header + body) - return base64.b64encode(compressed).decode("ascii") + """Build a gateway pti/v1 payload: base64(zstd(json int array)).""" + raw = json.dumps(token_ids).encode("utf-8") + return base64.b64encode(zstd.ZstdCompressor().compress(raw)).decode("ascii") -def test_decompress_and_parse_pti_round_trip(): - token_ids, metadata = decompress_and_parse_pti(_pti_b64([101, 102, 103])) +def test_decode_prompt_token_ids_round_trip(): + decoded = decode_prompt_token_ids(_pti_b64([101, 102, 103])) - assert token_ids == [101, 102, 103] - assert metadata["scope"] == "prompt_only" - assert metadata["token_count"] == 3 + assert decoded.value == [101, 102, 103] + assert decoded.metadata["scope"] == "prompt_only" + assert decoded.metadata["token_count"] == 3 def test_trace_adapter_attaches_prompt_token_ids_metadata(): diff --git a/tests/adapters/test_lp_deserializer.py b/tests/adapters/test_lp_deserializer.py index 52e04417..c5460614 100644 --- a/tests/adapters/test_lp_deserializer.py +++ b/tests/adapters/test_lp_deserializer.py @@ -8,7 +8,7 @@ import pytest import zstandard as zstd -from eval_protocol.adapters.lp_deserializer import ( +from eval_protocol.tracing.logprobs import ( ENTRY_FORMAT, ENTRY_SIZE, HEADER_FORMAT, diff --git a/tests/adapters/test_r3_deserializer.py b/tests/adapters/test_r3_deserializer.py index 31f058f6..980d3934 100644 --- a/tests/adapters/test_r3_deserializer.py +++ b/tests/adapters/test_r3_deserializer.py @@ -10,7 +10,7 @@ import pytest import zstandard as zstd -from eval_protocol.adapters.r3_deserializer import ( +from eval_protocol.tracing.router_replay import ( HEADER_FORMAT, HEADER_SIZE, MAGIC, diff --git a/tests/tracing/test_registry.py b/tests/tracing/test_registry.py new file mode 100644 index 00000000..b037adad --- /dev/null +++ b/tests/tracing/test_registry.py @@ -0,0 +1,124 @@ +"""Tests for the standalone 
tracing-gateway payload decoder registry.""" + +from __future__ import annotations + +import base64 +import json +import struct + +import zstandard as zstd + +from eval_protocol.tracing import ( + DecodedPayload, + PayloadType, + decode_payload, + decode_payloads, + decode_trace, +) +from eval_protocol.tracing import logprobs as lp_mod +from eval_protocol.tracing import router_replay as r3_mod + + +def _b64_zstd(raw: bytes) -> str: + return base64.b64encode(zstd.ZstdCompressor().compress(raw)).decode("ascii") + + +def _pti_data(token_ids: list[int]) -> str: + return _b64_zstd(json.dumps(token_ids).encode("utf-8")) + + +def _lp_data(tokens: list[tuple[int, float]]) -> str: + body = b"".join(struct.pack(lp_mod.ENTRY_FORMAT, tid, lp) for tid, lp in tokens) + header = struct.pack( + lp_mod.HEADER_FORMAT, lp_mod.MAGIC, 1, 0, 0, len(tokens), len(body), 0 + ) + return _b64_zstd(header + body) + + +def _r3_data_all_mode(matrices: list[bytes]) -> str: + matrix_data = b"".join(matrices) + header = struct.pack( + r3_mod.HEADER_FORMAT, + r3_mod.MAGIC, + 1, # version + r3_mod._SelectorMode.ALL, + r3_mod._RoutingDtype.UINT8, + 0x01, # flags + len(matrices), # total_token_count + len(matrices), # replayed_token_count + 0, # replay_start_token + 0, # selector_byte_length + len(matrix_data), + ) + return _b64_zstd(header + matrix_data) + + +def _all_payloads() -> dict: + return { + "prompt_token_ids": {"manifest": {"PayloadVersion": "pti/v1"}, "data": _pti_data([1, 2, 3])}, + "logprobs": {"manifest": {"PayloadVersion": "lp/v1"}, "data": _lp_data([(7, -0.25), (8, -0.5)])}, + "router_replay": { + "manifest": {"PayloadVersion": "r3/v1"}, + "data": _r3_data_all_mode([b"\x01\x02\x03\x04", b"\x05\x06\x07\x08"]), + }, + } + + +def test_decode_payloads_all_types(): + decoded = decode_payloads(_all_payloads()) + + assert set(decoded) == { + PayloadType.PROMPT_TOKEN_IDS, + PayloadType.LOGPROBS, + PayloadType.ROUTER_REPLAY, + } + assert all(isinstance(dp, DecodedPayload) for dp in 
decoded.values()) + + assert decoded[PayloadType.PROMPT_TOKEN_IDS].value == [1, 2, 3] + + lp = decoded[PayloadType.LOGPROBS] + assert lp.value == [-0.25, -0.5] + assert lp.extras["token_ids"] == [7, 8] + + r3 = decoded[PayloadType.ROUTER_REPLAY] + assert len(r3.value) == 2 + assert base64.b64decode(r3.value[0]) == b"\x01\x02\x03\x04" + + +def test_decode_payload_accepts_str_and_enum(): + data = _pti_data([10, 20]) + via_enum = decode_payload(PayloadType.PROMPT_TOKEN_IDS, data) + via_str = decode_payload("prompt_token_ids", data) + assert via_enum.value == via_str.value == [10, 20] + + +def test_decode_trace_reaches_into_payloads(): + trace = {"id": "t1", "payloads": {"prompt_token_ids": {"data": _pti_data([5, 6])}}} + decoded = decode_trace(trace) + assert decoded[PayloadType.PROMPT_TOKEN_IDS].value == [5, 6] + + +def test_unknown_and_empty_types_are_skipped(): + payloads = { + "some_future_type": {"data": "ignored"}, # unknown -> ignored + "logprobs": {"data": ""}, # present but empty -> skipped + "prompt_token_ids": {"data": _pti_data([9])}, + } + decoded = decode_payloads(payloads) + assert set(decoded) == {PayloadType.PROMPT_TOKEN_IDS} + + +def test_on_error_fires_on_bad_data(): + errors = [] + payloads = {"prompt_token_ids": {"data": "not-valid-base64-zstd-json!!"}} + + decoded = decode_payloads(payloads, on_error=lambda pt, e: errors.append((pt, e))) + + assert decoded == {} + assert len(errors) == 1 + assert errors[0][0] == PayloadType.PROMPT_TOKEN_IDS + + +def test_decode_payloads_non_dict_returns_empty(): + assert decode_payloads(None) == {} + assert decode_trace({"id": "no-payloads"}) == {} From 62524fed9054447a30ce9c0c2d8a5a4a76179dd1 Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Tue, 16 Jun 2026 14:56:12 -0700 Subject: [PATCH 3/6] Trim verbose comment in fireworks_tracing adapter Replace the 4-line narrating comment over the payload-decode block with a single intent line; the code already shows what it does. 
Co-authored-by: Cursor --- eval_protocol/adapters/fireworks_tracing.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 89a2bc94..07cc987c 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -101,10 +101,7 @@ def convert_trace_dict_to_evaluation_row( ): break # Break early if we've found all the metadata we need - # Decode out-of-band gateway payloads (router replay, logprobs, prompt - # token ids) via the standalone tracing decoder registry, then map the - # decoded values onto the row. Format/decoding lives in - # ``eval_protocol.tracing``; this adapter only does EvaluationRow glue. + # Decoding lives in eval_protocol.tracing; here we only map results onto the row. payloads = trace.get("payloads") if isinstance(payloads, dict): decoded = decode_payloads( From d82431809c62d7cb104526d07b0f7d966bba2568 Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Tue, 16 Jun 2026 15:11:09 -0700 Subject: [PATCH 4/6] Slim eval_protocol.tracing public API to the master decode surface Export only PayloadType, DecodedPayload, and decode_payloads/decode_payload/ decode_trace from the package __init__. The per-type decoders and PAYLOAD_DECODERS are internal building blocks, still reachable via submodules. 
Co-authored-by: Cursor --- eval_protocol/tracing/__init__.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/eval_protocol/tracing/__init__.py b/eval_protocol/tracing/__init__.py index 5870c2a4..7ad9d436 100644 --- a/eval_protocol/tracing/__init__.py +++ b/eval_protocol/tracing/__init__.py @@ -19,25 +19,13 @@ from __future__ import annotations -from .logprobs import decode_logprobs -from .prompt_token_ids import decode_prompt_token_ids -from .registry import ( - PAYLOAD_DECODERS, - decode_payload, - decode_payloads, - decode_trace, -) -from .router_replay import decode_router_replay +from .registry import decode_payload, decode_payloads, decode_trace from .types import DecodedPayload, PayloadType __all__ = [ "PayloadType", "DecodedPayload", - "PAYLOAD_DECODERS", - "decode_payload", "decode_payloads", + "decode_payload", "decode_trace", - "decode_prompt_token_ids", - "decode_logprobs", - "decode_router_replay", ] From 214fd2b0021af967f2168a54d93bd266fdcb15cd Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Tue, 16 Jun 2026 15:44:48 -0700 Subject: [PATCH 5/6] Inline base64+zstd decompress into prompt_token_ids; drop _decompress module The shared helper had a single caller (prompt_token_ids); logprobs/router_replay already inline the same base64+zstd step. Inline it there too for consistency and remove the dedicated _decompress.py file. 
Co-authored-by: Cursor --- eval_protocol/tracing/_decompress.py | 17 ----------------- eval_protocol/tracing/prompt_token_ids.py | 7 +++++-- 2 files changed, 5 insertions(+), 19 deletions(-) delete mode 100644 eval_protocol/tracing/_decompress.py diff --git a/eval_protocol/tracing/_decompress.py b/eval_protocol/tracing/_decompress.py deleted file mode 100644 index 1ab64bfd..00000000 --- a/eval_protocol/tracing/_decompress.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Shared helper for the Fireworks tracing-gateway payload decoders.""" - -from __future__ import annotations - -import base64 - -import zstandard as zstd - - -def decompress_b64(data_b64: str) -> bytes: - """Base64-decode then zstd-decompress a gateway ``payloads.*.data`` blob. - - The gateway stores every payload as ``base64(zstd(raw_bytes))``; this is the - common first step every decoder shares before interpreting ``raw_bytes``. - """ - compressed = base64.b64decode(data_b64) - return zstd.ZstdDecompressor().decompress(compressed) diff --git a/eval_protocol/tracing/prompt_token_ids.py b/eval_protocol/tracing/prompt_token_ids.py index 5019331a..ff9d68de 100644 --- a/eval_protocol/tracing/prompt_token_ids.py +++ b/eval_protocol/tracing/prompt_token_ids.py @@ -7,10 +7,12 @@ from __future__ import annotations +import base64 import json from typing import Any, Dict, List, Tuple -from ._decompress import decompress_b64 +import zstandard as zstd + from .types import DecodedPayload, PayloadType @@ -23,7 +25,8 @@ def parse_prompt_token_ids(raw: bytes) -> Tuple[List[int], Dict[str, Any]]: def decode_prompt_token_ids(data_b64: str) -> DecodedPayload: """Decode a gateway ``payloads.prompt_token_ids.data`` blob.""" - token_ids, metadata = parse_prompt_token_ids(decompress_b64(data_b64)) + raw = zstd.ZstdDecompressor().decompress(base64.b64decode(data_b64)) + token_ids, metadata = parse_prompt_token_ids(raw) return DecodedPayload( payload_type=PayloadType.PROMPT_TOKEN_IDS, value=token_ids, From 
38fe102291566a78d0e61603e1e73d1a47247e60 Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Tue, 16 Jun 2026 16:08:19 -0700 Subject: [PATCH 6/6] Type DecodedPayload.token_ids instead of an untyped extras bag Replace the Dict[str, Any] `extras` field (only ever holding logprobs token_ids) with a typed `token_ids: Optional[List[int]]` field. Callers now get a real type (dp.token_ids) instead of Any from extras.get(...). Update adapter, tests, README. Co-authored-by: Cursor --- eval_protocol/adapters/fireworks_tracing.py | 2 +- eval_protocol/tracing/README.md | 6 +++--- eval_protocol/tracing/logprobs.py | 4 ++-- eval_protocol/tracing/types.py | 9 +++++---- tests/tracing/test_registry.py | 2 +- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 07cc987c..0ff86f71 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -119,7 +119,7 @@ def convert_trace_dict_to_evaluation_row( if (dp := decoded.get(PayloadType.LOGPROBS)) is not None: logprobs = dp.value - token_ids = dp.extras.get("token_ids") + token_ids = dp.token_ids execution_metadata.extra["completion_logprobs"] = logprobs if token_ids is not None: execution_metadata.extra["completion_token_ids"] = token_ids diff --git a/eval_protocol/tracing/README.md b/eval_protocol/tracing/README.md index 4d1571ae..10a45109 100644 --- a/eval_protocol/tracing/README.md +++ b/eval_protocol/tracing/README.md @@ -51,7 +51,7 @@ if PayloadType.PROMPT_TOKEN_IDS in decoded: if PayloadType.LOGPROBS in decoded: lp = decoded[PayloadType.LOGPROBS] logprobs = lp.value # List[float] - token_ids = lp.extras.get("token_ids") # Optional[List[int]] + token_ids = lp.token_ids # Optional[List[int]] if PayloadType.ROUTER_REPLAY in decoded: matrices = decoded[PayloadType.ROUTER_REPLAY].value # List[Optional[str]] @@ -90,14 +90,14 @@ decode_payloads(payloads, on_error=lambda pt, e: 
print(f"{pt} failed: {e}")) | `payload_type` | `PayloadType` enum member | | `value` | decoded value (type depends on `payload_type`, see below) | | `metadata` | decoded header/manifest metadata (token counts, scope, etc.) | -| `extras` | type-specific extras (e.g. logprobs `token_ids`) | +| `token_ids` | `Optional[List[int]]` — LOGPROBS per-token ids (else `None`) | `value` by type: | `PayloadType` | `value` | notes | |---------------------|--------------------------|----------------------------------------------| | `PROMPT_TOKEN_IDS` | `List[int]` | prompt token ids | -| `LOGPROBS` | `List[float]` | per completion token; ids in `extras["token_ids"]` (or `None`) | +| `LOGPROBS` | `List[float]` | per completion token; ids in `token_ids` (or `None`) | | `ROUTER_REPLAY` | `List[Optional[str]]` | per-token base64 routing matrices; `None` where absent | ## Adding a new payload type diff --git a/eval_protocol/tracing/logprobs.py b/eval_protocol/tracing/logprobs.py index f5134108..718f24e5 100644 --- a/eval_protocol/tracing/logprobs.py +++ b/eval_protocol/tracing/logprobs.py @@ -115,12 +115,12 @@ def decode_logprobs(data_b64: str) -> DecodedPayload: """Decode a gateway ``payloads.logprobs.data`` blob into a ``DecodedPayload``. ``value`` is the per-completion-token logprob list; per-token ids (when all - valid) are available under ``extras["token_ids"]``. + valid) are available under ``token_ids``. 
""" logprobs, token_ids, metadata = decompress_and_parse_lp(data_b64) return DecodedPayload( payload_type=PayloadType.LOGPROBS, value=logprobs, metadata=metadata, - extras={"token_ids": token_ids}, + token_ids=token_ids, ) diff --git a/eval_protocol/tracing/types.py b/eval_protocol/tracing/types.py index f3760f28..a11f6412 100644 --- a/eval_protocol/tracing/types.py +++ b/eval_protocol/tracing/types.py @@ -2,9 +2,9 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass from enum import Enum -from typing import Any, Dict +from typing import Any, Dict, List, Optional class PayloadType(str, Enum): @@ -29,7 +29,7 @@ class DecodedPayload: ``value`` shape depends on ``payload_type``: - ``PROMPT_TOKEN_IDS`` -> ``List[int]`` - ``LOGPROBS`` -> ``List[float]`` (per completion token); the - optional per-token ids are in ``extras["token_ids"]`` + optional per-token ids are in ``token_ids`` - ``ROUTER_REPLAY`` -> ``List[Optional[str]]`` (per-token base64 routing matrices, ``None`` where absent) """ @@ -37,4 +37,5 @@ class DecodedPayload: payload_type: PayloadType value: Any metadata: Dict[str, Any] - extras: Dict[str, Any] = field(default_factory=dict) + # LOGPROBS only: per-completion-token ids, or None if any were missing. + token_ids: Optional[List[int]] = None diff --git a/tests/tracing/test_registry.py b/tests/tracing/test_registry.py index b037adad..e32492db 100644 --- a/tests/tracing/test_registry.py +++ b/tests/tracing/test_registry.py @@ -78,7 +78,7 @@ def test_decode_payloads_all_types(): lp = decoded[PayloadType.LOGPROBS] assert lp.value == [-0.25, -0.5] - assert lp.extras["token_ids"] == [7, 8] + assert lp.token_ids == [7, 8] r3 = decoded[PayloadType.ROUTER_REPLAY] assert len(r3.value) == 2