diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 62a632e6..0ff86f71 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -15,9 +15,8 @@ import os from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message +from eval_protocol.tracing import PayloadType, decode_payloads from .base import BaseAdapter -from .lp_deserializer import decompress_and_parse_lp -from .r3_deserializer import decompress_and_parse_r3 from .utils import extract_messages_from_data from ..common_utils import get_user_agent @@ -102,45 +101,42 @@ def convert_trace_dict_to_evaluation_row( ): break # Break early if we've found all the metadata we need - # Extract router replay payloads when present + # Decoding lives in eval_protocol.tracing; here we only map results onto the row. payloads = trace.get("payloads") if isinstance(payloads, dict): - router_replay = payloads.get("router_replay") - if isinstance(router_replay, dict) and router_replay.get("data"): - try: - matrices, r3_meta = decompress_and_parse_r3(router_replay["data"]) - if execution_metadata.extra is None: - execution_metadata.extra = {} - execution_metadata.extra["routing_matrices"] = matrices - execution_metadata.extra["routing_metadata"] = r3_meta - except Exception as e: - logger.warning("Failed to decompress R3 payload for trace %s: %s", trace.get("id"), e) - - logprobs_payload = payloads.get("logprobs") - if isinstance(logprobs_payload, dict) and logprobs_payload.get("data"): - try: - logprobs, token_ids, lp_meta = decompress_and_parse_lp(logprobs_payload["data"]) - if execution_metadata.extra is None: - execution_metadata.extra = {} - execution_metadata.extra["completion_logprobs"] = logprobs - if token_ids is not None: - execution_metadata.extra["completion_token_ids"] = token_ids - execution_metadata.extra["logprobs_metadata"] = lp_meta - - for i in range(len(messages) - 1, -1, -1): - if messages[i].role == "assistant": - content_entries = [{"logprob": lp} for lp in logprobs] - if token_ids is not None: - for entry, tid in zip(content_entries, token_ids): - entry["token_id"] = tid - messages[i].logprobs = {"content": content_entries} - break - except Exception as e: - logger.warning( - "Failed to decompress logprobs payload for trace %s: %s", - trace.get("id"), - e, - ) + decoded = decode_payloads( + payloads, + on_error=lambda pt, e: logger.warning( + "Failed to decode %s payload for trace %s: %s", pt.value, trace.get("id"), e + ), + ) + if decoded and execution_metadata.extra is None: + execution_metadata.extra = {} + + if (dp := decoded.get(PayloadType.ROUTER_REPLAY)) is not None: + execution_metadata.extra["routing_matrices"] = dp.value + execution_metadata.extra["routing_metadata"] = dp.metadata + + if (dp := decoded.get(PayloadType.LOGPROBS)) is not None: + logprobs = dp.value + token_ids = dp.token_ids + execution_metadata.extra["completion_logprobs"] = logprobs + if token_ids is not None: + execution_metadata.extra["completion_token_ids"] = token_ids + execution_metadata.extra["logprobs_metadata"] = dp.metadata + + for i in range(len(messages) - 1, -1, -1): + if messages[i].role == "assistant": + content_entries = [{"logprob": lp} for lp in logprobs] + if token_ids is not None: + for entry, tid in zip(content_entries, token_ids): + entry["token_id"] = tid + messages[i].logprobs = {"content": content_entries} + break + + if (dp := decoded.get(PayloadType.PROMPT_TOKEN_IDS)) is not None: + execution_metadata.extra["prompt_token_ids"] = dp.value + execution_metadata.extra["prompt_token_ids_metadata"] = dp.metadata return EvaluationRow( messages=messages, diff --git a/eval_protocol/pytest/tracing_utils.py b/eval_protocol/pytest/tracing_utils.py index 279d1055..e0d5db73 100644 --- a/eval_protocol/pytest/tracing_utils.py +++ b/eval_protocol/pytest/tracing_utils.py @@ -63,6 +63,8 @@ def _merge_payloads_into_longest_row(longest_row: EvaluationRow, rows: List[Eval for key in ( "completion_logprobs", "completion_token_ids", + "prompt_token_ids", + "prompt_token_ids_metadata", "logprobs_metadata", "routing_matrices", "routing_metadata", diff --git a/eval_protocol/tracing/README.md b/eval_protocol/tracing/README.md new file mode 100644 index 00000000..10a45109 --- /dev/null +++ b/eval_protocol/tracing/README.md @@ -0,0 +1,109 @@ +# `eval_protocol.tracing` — Fireworks tracing-gateway payload decoders + +Standalone helpers for decoding the out-of-band **payloads** the Fireworks +tracing gateway stores alongside a trace (prompt token IDs, completion logprobs, +router-replay routing matrices). + +This package is intentionally self-contained: it depends only on the stdlib and +`zstandard`. It does **not** import `EvaluationRow`, rollout processors, or any +other Eval Protocol machinery, so you can use it even if you are not using EP for +rollouts — just point at it for extracting gateway payloads. + +## What is a "payload"? + +When you read a trace with payloads included: + +``` +GET {gateway}/v1/traces?rollout_id=...&include_payloads=true +``` + +each trace carries a `payloads` object like: + +```json +{ + "payloads": { + "prompt_token_ids": { + "manifest": { "PayloadVersion": "pti/v1", "...": "..." }, + "data": "" + }, + "logprobs": { "manifest": { "PayloadVersion": "lp/v1" }, "data": "..." }, + "router_replay": { "manifest": { "PayloadVersion": "r3/v1" }, "data": "..." } + } +} +``` + +The `data` field is `base64(zstd(raw_bytes))`. Each payload type has its own +`raw_bytes` encoding (`pti/v1` is a JSON int array; `lp/v1` and `r3/v1` are packed +binary). This package hides all of that. + +## Usage + +Decode everything at once (the common case): + +```python +from eval_protocol.tracing import decode_payloads, PayloadType + +decoded = decode_payloads(trace["payloads"]) + +if PayloadType.PROMPT_TOKEN_IDS in decoded: + token_ids = decoded[PayloadType.PROMPT_TOKEN_IDS].value # List[int] + +if PayloadType.LOGPROBS in decoded: + lp = decoded[PayloadType.LOGPROBS] + logprobs = lp.value # List[float] + token_ids = lp.token_ids # Optional[List[int]] + +if PayloadType.ROUTER_REPLAY in decoded: + matrices = decoded[PayloadType.ROUTER_REPLAY].value # List[Optional[str]] +``` + +If you have the whole trace dict, `decode_trace(trace)` reaches into +`trace["payloads"]` for you. + +Decode a single payload: + +```python +from eval_protocol.tracing import decode_payload, PayloadType + +dp = decode_payload(PayloadType.PROMPT_TOKEN_IDS, trace["payloads"]["prompt_token_ids"]["data"]) +dp.value # List[int] +``` + +### Error handling + +`decode_payloads` isolates per-payload failures: if one payload fails to decode, +the others are still returned. Pass `on_error=callback(payload_type, exc)` to +control logging (defaults to a warning): + +```python +decode_payloads(payloads, on_error=lambda pt, e: print(f"{pt} failed: {e}")) +``` + +## Return type + +`decode_payloads` / `decode_trace` return `Dict[PayloadType, DecodedPayload]`. + +`DecodedPayload` fields: + +| field | meaning | +|----------------|-------------------------------------------------------------------| +| `payload_type` | `PayloadType` enum member | +| `value` | decoded value (type depends on `payload_type`, see below) | +| `metadata` | decoded header/manifest metadata (token counts, scope, etc.) | +| `token_ids` | `Optional[List[int]]` — LOGPROBS per-token ids (else `None`) | + +`value` by type: + +| `PayloadType` | `value` | notes | +|---------------------|--------------------------|----------------------------------------------| +| `PROMPT_TOKEN_IDS` | `List[int]` | prompt token ids | +| `LOGPROBS` | `List[float]` | per completion token; ids in `token_ids` (or `None`) | +| `ROUTER_REPLAY` | `List[Optional[str]]` | per-token base64 routing matrices; `None` where absent | + +## Adding a new payload type + +1. Add a member to `PayloadType` in `types.py`. +2. Add a `decode_(data_b64) -> DecodedPayload` function in a new module. +3. Register it in `PAYLOAD_DECODERS` in `registry.py`. + +`decode_payloads` picks it up automatically. diff --git a/eval_protocol/tracing/__init__.py b/eval_protocol/tracing/__init__.py new file mode 100644 index 00000000..7ad9d436 --- /dev/null +++ b/eval_protocol/tracing/__init__.py @@ -0,0 +1,31 @@ +"""Decode Fireworks tracing-gateway payloads. + +Standalone, dependency-light helpers (stdlib + ``zstandard`` only) for turning +the binary/JSON ``payloads`` returned by the Fireworks tracing gateway +(``GET /traces?include_payloads=true``) into Python values. No EvaluationRow or +rollout machinery required -- usable on its own. + +Typical use:: + + from eval_protocol.tracing import decode_payloads, PayloadType + + decoded = decode_payloads(trace["payloads"]) + decoded[PayloadType.PROMPT_TOKEN_IDS].value # List[int] + decoded[PayloadType.LOGPROBS].value # List[float] + decoded[PayloadType.ROUTER_REPLAY].value # List[Optional[str]] + +See ``README.md`` in this package for details. +""" + +from __future__ import annotations + +from .registry import decode_payload, decode_payloads, decode_trace +from .types import DecodedPayload, PayloadType + +__all__ = [ + "PayloadType", + "DecodedPayload", + "decode_payloads", + "decode_payload", + "decode_trace", +] diff --git a/eval_protocol/adapters/lp_deserializer.py b/eval_protocol/tracing/logprobs.py similarity index 86% rename from eval_protocol/adapters/lp_deserializer.py rename to eval_protocol/tracing/logprobs.py index 57aa4f46..718f24e5 100644 --- a/eval_protocol/adapters/lp_deserializer.py +++ b/eval_protocol/tracing/logprobs.py @@ -12,6 +12,8 @@ import zstandard as zstd +from .types import DecodedPayload, PayloadType + MAGIC = b"LP01" HEADER_VERSION = 1 MISSING_TOKEN_ID = -1 @@ -107,3 +109,18 @@ def decompress_and_parse_lp(data_b64: str) -> Tuple[List[float], Optional[List[i decompressor = zstd.ZstdDecompressor() raw = decompressor.decompress(compressed) return parse_logprobs(raw) + + +def decode_logprobs(data_b64: str) -> DecodedPayload: + """Decode a gateway ``payloads.logprobs.data`` blob into a ``DecodedPayload``. + + ``value`` is the per-completion-token logprob list; per-token ids (when all + valid) are available under ``token_ids``. + """ + logprobs, token_ids, metadata = decompress_and_parse_lp(data_b64) + return DecodedPayload( + payload_type=PayloadType.LOGPROBS, + value=logprobs, + metadata=metadata, + token_ids=token_ids, + ) diff --git a/eval_protocol/tracing/prompt_token_ids.py b/eval_protocol/tracing/prompt_token_ids.py new file mode 100644 index 00000000..ff9d68de --- /dev/null +++ b/eval_protocol/tracing/prompt_token_ids.py @@ -0,0 +1,34 @@ +"""``pti/v1`` decoder for prompt token ID payloads. + +Inverse of the tracing gateway's ``serialize_prompt_token_ids``: the gateway +stores prompt token IDs as ``base64(zstd(json.dumps(token_ids)))`` -- a compact +JSON int array, no bespoke binary header. +""" + +from __future__ import annotations + +import base64 +import json +from typing import Any, Dict, List, Tuple + +import zstandard as zstd + +from .types import DecodedPayload, PayloadType + + +def parse_prompt_token_ids(raw: bytes) -> Tuple[List[int], Dict[str, Any]]: + """Parse uncompressed ``pti/v1`` bytes (a JSON int array) into ids + metadata.""" + token_ids = json.loads(raw) + metadata: Dict[str, Any] = {"scope": "prompt_only", "token_count": len(token_ids)} + return token_ids, metadata + + +def decode_prompt_token_ids(data_b64: str) -> DecodedPayload: + """Decode a gateway ``payloads.prompt_token_ids.data`` blob.""" + raw = zstd.ZstdDecompressor().decompress(base64.b64decode(data_b64)) + token_ids, metadata = parse_prompt_token_ids(raw) + return DecodedPayload( + payload_type=PayloadType.PROMPT_TOKEN_IDS, + value=token_ids, + metadata=metadata, + ) diff --git a/eval_protocol/tracing/registry.py b/eval_protocol/tracing/registry.py new file mode 100644 index 00000000..62027376 --- /dev/null +++ b/eval_protocol/tracing/registry.py @@ -0,0 +1,88 @@ +"""Decoder registry + master decode for tracing-gateway payloads. + +Adding a new payload type is a single entry in ``PAYLOAD_DECODERS`` (plus its +decoder module). Callers use the master :func:`decode_payloads` / +:func:`decode_trace` and never stitch per-type decoders together. +""" + +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, Optional + +from .logprobs import decode_logprobs +from .prompt_token_ids import decode_prompt_token_ids +from .router_replay import decode_router_replay +from .types import DecodedPayload, PayloadType + +logger = logging.getLogger(__name__) + +# Callback invoked when a single payload fails to decode: (payload_type, exc). +OnError = Callable[[PayloadType, Exception], None] + +PAYLOAD_DECODERS: Dict[PayloadType, Callable[[str], DecodedPayload]] = { + PayloadType.PROMPT_TOKEN_IDS: decode_prompt_token_ids, + PayloadType.LOGPROBS: decode_logprobs, + PayloadType.ROUTER_REPLAY: decode_router_replay, +} + + +def decode_payload(payload_type: PayloadType | str, data_b64: str) -> DecodedPayload: + """Decode a single payload by type. + + ``payload_type`` accepts a ``PayloadType`` or its string value, so external + callers can pass either. + """ + ptype = PayloadType(payload_type) + decoder = PAYLOAD_DECODERS.get(ptype) + if decoder is None: + raise ValueError(f"No decoder registered for payload type: {ptype!r}") + return decoder(data_b64) + + +def decode_payloads( + payloads: Dict[str, Any], + *, + on_error: Optional[OnError] = None, +) -> Dict[PayloadType, DecodedPayload]: + """Master decode: run every registered decoder over a gateway ``payloads`` dict. + + Args: + payloads: The ``payloads`` object from a gateway trace (i.e. + ``trace["payloads"]``), mapping payload-type name -> ``{"manifest", "data"}``. + on_error: Optional callback ``(payload_type, exc)`` invoked when a present + payload fails to decode. Defaults to logging a warning. A failure in + one payload never blocks the others. + + Returns: + ``{PayloadType: DecodedPayload}`` for every payload that is present and + decodes successfully. Only known ``PayloadType`` members are considered, + so unknown payload types from a newer gateway are ignored rather than + raising. + """ + if not isinstance(payloads, dict): + return {} + + decoded: Dict[PayloadType, DecodedPayload] = {} + for ptype, decoder in PAYLOAD_DECODERS.items(): + entry = payloads.get(ptype) + if not isinstance(entry, dict) or not entry.get("data"): + continue + try: + decoded[ptype] = decoder(entry["data"]) + except Exception as exc: # noqa: BLE001 - isolate per-payload failures + if on_error is not None: + on_error(ptype, exc) + else: + logger.warning("Failed to decode %s payload: %s", ptype.value, exc) + return decoded + + +def decode_trace( + trace: Dict[str, Any], + *, + on_error: Optional[OnError] = None, +) -> Dict[PayloadType, DecodedPayload]: + """Convenience wrapper around :func:`decode_payloads` for a raw trace dict.""" + payloads = trace.get("payloads") if isinstance(trace, dict) else None + return decode_payloads(payloads or {}, on_error=on_error) diff --git a/eval_protocol/adapters/r3_deserializer.py b/eval_protocol/tracing/router_replay.py similarity index 93% rename from eval_protocol/adapters/r3_deserializer.py rename to eval_protocol/tracing/router_replay.py index 1e3c1a8c..49c4be8e 100644 --- a/eval_protocol/adapters/r3_deserializer.py +++ b/eval_protocol/tracing/router_replay.py @@ -20,6 +20,8 @@ import zstandard as zstd +from .types import DecodedPayload, PayloadType + MAGIC = b"R3V1" HEADER_FORMAT = "<4sBBBBIIIIQ" HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 32 bytes @@ -185,3 +187,16 @@ def decompress_and_parse_r3( matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii") return matrices, metadata + + +def decode_router_replay(data_b64: str) -> DecodedPayload: + """Decode a gateway ``payloads.router_replay.data`` blob into a ``DecodedPayload``. + + ``value`` is the per-token ``List[Optional[str]]`` of base64 routing matrices. + """ + matrices, metadata = decompress_and_parse_r3(data_b64) + return DecodedPayload( + payload_type=PayloadType.ROUTER_REPLAY, + value=matrices, + metadata=metadata, + ) diff --git a/eval_protocol/tracing/types.py b/eval_protocol/tracing/types.py new file mode 100644 index 00000000..a11f6412 --- /dev/null +++ b/eval_protocol/tracing/types.py @@ -0,0 +1,41 @@ +"""Shared types for the Fireworks tracing-gateway payload decoders.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional + + +class PayloadType(str, Enum): + """Known out-of-band trace payload types emitted by the tracing gateway. + + Canonical source of truth for payload-type names on the EP side. Mirrors the + gateway's ``rft_tracing.schemas.PayloadType`` but is defined locally so this + package has no dependency on the gateway/mono codebase. Being a ``str`` enum, + members compare and hash equal to their string value, so they can be used + directly against the gateway's string-keyed ``payloads`` JSON. + """ + + PROMPT_TOKEN_IDS = "prompt_token_ids" + LOGPROBS = "logprobs" + ROUTER_REPLAY = "router_replay" + + +@dataclass(frozen=True) +class DecodedPayload: + """A single decoded gateway payload. + + ``value`` shape depends on ``payload_type``: + - ``PROMPT_TOKEN_IDS`` -> ``List[int]`` + - ``LOGPROBS`` -> ``List[float]`` (per completion token); the + optional per-token ids are in ``token_ids`` + - ``ROUTER_REPLAY`` -> ``List[Optional[str]]`` (per-token base64 routing + matrices, ``None`` where absent) + """ + + payload_type: PayloadType + value: Any + metadata: Dict[str, Any] + # LOGPROBS only: per-completion-token ids, or None if any were missing. + token_ids: Optional[List[int]] = None diff --git a/scripts/test_remote_rollout_prompt_token_ids.py b/scripts/test_remote_rollout_prompt_token_ids.py new file mode 100644 index 00000000..a143772f --- /dev/null +++ b/scripts/test_remote_rollout_prompt_token_ids.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +"""E2E check: RemoteRolloutProcessor reads prompt_token_ids trace payloads. + +This starts a tiny local `/init` server, sends one chat completion through the +Fireworks tracing gateway with `return_token_ids`, and verifies that +RemoteRolloutProcessor hydrates `assistant_turn_payloads[*].prompt_token_ids`. +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import os +import sys +import socket +import threading +import time +from pathlib import Path +from typing import Any + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +import uvicorn +from fastapi import FastAPI +from openai import OpenAI + +from eval_protocol import FireworksTracingHttpHandler, InitRequest, RolloutIdFilter, Status +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig + +logger = logging.getLogger("remote_rollout_prompt_token_ids") +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def _message_to_dict(message: Message | dict[str, Any]) -> dict[str, Any]: + if isinstance(message, Message): + return message.dump_mdoel_for_chat_completion_request() + return {k: v for k, v in dict(message).items() if v is not None} + + +def _make_app(gateway_url: str) -> FastAPI: + app = FastAPI() + app_logger = logging.getLogger(f"{__name__}.server") + app_logger.setLevel(logging.INFO) + + @app.get("/") + def health() -> dict[str, str]: + return {"status": "ok"} + + @app.post("/init") + def init(req: InitRequest) -> dict[str, str]: + rollout_logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}") + rollout_logger.addFilter(RolloutIdFilter(req.metadata.rollout_id)) + if not any(isinstance(handler, FireworksTracingHttpHandler) for handler in rollout_logger.handlers): + rollout_logger.addHandler(FireworksTracingHttpHandler(gateway_base_url=gateway_url)) + rollout_logger.setLevel(logging.INFO) + + def _worker() -> None: + try: + conversation = [_message_to_dict(message) for message in (req.messages or [])] + params = dict(req.completion_params or {}) + params.pop("base_url", None) + params["extra_body"] = { + **dict(params.get("extra_body") or {}), + "return_token_ids": True, + } + params.setdefault("temperature", 0) + params.setdefault("max_tokens", 8) + + if not req.model_base_url: + raise ValueError("model_base_url is required") + if not params.get("model"): + raise ValueError("completion_params.model is required") + + client = OpenAI(base_url=req.model_base_url, api_key=req.api_key) + response = client.chat.completions.create(messages=conversation, **params) + content = response.choices[0].message.content or "" + logger.info("remote server generated content=%r", content) + + rollout_logger.info( + "rollout %s finished", + req.metadata.rollout_id, + extra={"status": Status.rollout_finished()}, + ) + except Exception as exc: + rollout_logger.exception( + "rollout %s failed", + req.metadata.rollout_id, + extra={"status": Status.rollout_unknown_error(str(exc))}, + ) + + threading.Thread(target=_worker, daemon=True).start() + return {"status": "started"} + + return app + + +def _wait_ready(url: str, timeout_seconds: float = 30.0) -> None: + import requests + + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url, timeout=2) + if resp.status_code == 200: + return + except Exception: + pass + time.sleep(0.2) + raise TimeoutError(f"server not ready: {url}") + + +async def _run(args: argparse.Namespace) -> None: + api_key = args.api_key or os.getenv("FIREWORKS_DEV_API_KEY") or os.getenv("FIREWORKS_API_KEY") + if not api_key: + raise ValueError("Set FIREWORKS_DEV_API_KEY or FIREWORKS_API_KEY") + + # FireworksTracingHttpHandler reads FIREWORKS_API_KEY. + os.environ["FIREWORKS_API_KEY"] = api_key + os.environ["EP_REMOTE_API_KEY"] = api_key + + port = args.port or _free_port() + remote_base_url = f"http://127.0.0.1:{port}" + app = _make_app(args.gateway_url) + config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning") + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + _wait_ready(f"{remote_base_url}/") + + rollout_id = f"rrp-prompt-ids-{int(time.time())}" + row = EvaluationRow( + messages=[Message(role="user", content="Reply with exactly: ok")], + ) + row.input_metadata.row_id = "row-0" + row.input_metadata.completion_params = { + "model": args.model, + "base_url": args.api_base_url, + "temperature": 0, + "max_tokens": 8, + } + row.execution_metadata.rollout_id = rollout_id + row.execution_metadata.invocation_id = "inv-0" + row.execution_metadata.experiment_id = "fir2-1747-rrp-e2e" + row.execution_metadata.run_id = "run-0" + + processor = RemoteRolloutProcessor( + remote_base_url=remote_base_url, + model_base_url=args.gateway_url, + include_payloads=True, + timeout_seconds=args.timeout_seconds, + poll_interval=args.poll_interval, + ) + try: + task = processor( + [row], + RolloutProcessorConfig( + completion_params=row.input_metadata.completion_params, + mcp_config_path="", + semaphore=asyncio.Semaphore(1), + steps=1, + ), + )[0] + completed = await task + finally: + await processor.acleanup() + server.should_exit = True + thread.join(timeout=5) + + extra = completed.execution_metadata.extra or {} + turn_payloads = extra.get("assistant_turn_payloads") or [] + prompt_ids = None + if turn_payloads: + prompt_ids = turn_payloads[0].get("prompt_token_ids") + if prompt_ids is None: + prompt_ids = extra.get("prompt_token_ids") + + print(f"rollout_id={rollout_id}") + print(f"messages={len(completed.messages)}") + print(f"assistant_turn_payloads={turn_payloads}") + print(f"prompt_token_ids_len={len(prompt_ids) if isinstance(prompt_ids, list) else None}") + print(f"prompt_token_ids_head={prompt_ids[:8] if isinstance(prompt_ids, list) else None}") + + if not isinstance(prompt_ids, list) or not prompt_ids: + raise AssertionError("RemoteRolloutProcessor did not hydrate prompt_token_ids") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--gateway-url", default=os.getenv("EP_MODEL_BASE_URL", "https://litellm-gateway-dev-j4kzagdteq-uc.a.run.app")) + parser.add_argument("--api-base-url", default=os.getenv("FIREWORKS_API_BASE_URL", "https://dev.api.fireworks.ai/inference/v1")) + parser.add_argument("--model", default=os.getenv("TRACING_E2E_MODEL", "accounts/pyroworks-dev/deployments/malaysia2-intended-butterfly")) + parser.add_argument("--api-key", default=None) + parser.add_argument("--port", type=int, default=0) + parser.add_argument("--timeout-seconds", type=float, default=180.0) + parser.add_argument("--poll-interval", type=float, default=2.0) + asyncio.run(_run(parser.parse_args())) + + +if __name__ == "__main__": + main() diff --git a/tests/adapters/test_fireworks_tracing_logprobs.py b/tests/adapters/test_fireworks_tracing_logprobs.py index 08dab60b..c38c7acd 100644 --- a/tests/adapters/test_fireworks_tracing_logprobs.py +++ b/tests/adapters/test_fireworks_tracing_logprobs.py @@ -11,7 +11,7 @@ pytest.importorskip("mcp") from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row -from eval_protocol.adapters.lp_deserializer import ( +from eval_protocol.tracing.logprobs import ( ENTRY_FORMAT, ENTRY_SIZE, HEADER_FORMAT, diff --git a/tests/adapters/test_fireworks_tracing_prompt_token_ids.py b/tests/adapters/test_fireworks_tracing_prompt_token_ids.py new file mode 100644 index 00000000..16bebe3f --- /dev/null +++ b/tests/adapters/test_fireworks_tracing_prompt_token_ids.py @@ -0,0 +1,55 @@ +"""Tests for prompt token ID payload handling in fireworks_tracing adapter.""" + +from __future__ import annotations + +import base64 +import json + +import pytest +import zstandard as zstd + +pytest.importorskip("mcp") + +from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row +from eval_protocol.tracing.prompt_token_ids import decode_prompt_token_ids + + +def _pti_b64(token_ids: list[int]) -> str: + """Build a gateway pti/v1 payload: base64(zstd(json int array)).""" + raw = json.dumps(token_ids).encode("utf-8") + return base64.b64encode(zstd.ZstdCompressor().compress(raw)).decode("ascii") + + +def test_decode_prompt_token_ids_round_trip(): + decoded = decode_prompt_token_ids(_pti_b64([101, 102, 103])) + + assert decoded.value == [101, 102, 103] + assert decoded.metadata["scope"] == "prompt_only" + assert decoded.metadata["token_count"] == 3 + + +def test_trace_adapter_attaches_prompt_token_ids_metadata(): + trace = { + "id": "trace-pti", + "input": { + "messages": [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ], + }, + "output": {"role": "assistant", "content": "hello"}, + "payloads": { + "prompt_token_ids": { + "data": _pti_b64([201, 202, 203]), + "manifest": {"PayloadVersion": "pti/v1"}, + }, + }, + } + + row = convert_trace_dict_to_evaluation_row(trace) + + assert row is not None + extra = row.execution_metadata.extra + assert extra is not None + assert extra["prompt_token_ids"] == [201, 202, 203] + assert extra["prompt_token_ids_metadata"]["token_count"] == 3 diff --git a/tests/adapters/test_lp_deserializer.py b/tests/adapters/test_lp_deserializer.py index 52e04417..c5460614 100644 --- a/tests/adapters/test_lp_deserializer.py +++ b/tests/adapters/test_lp_deserializer.py @@ -8,7 +8,7 @@ import pytest import zstandard as zstd -from eval_protocol.adapters.lp_deserializer import ( +from eval_protocol.tracing.logprobs import ( ENTRY_FORMAT, ENTRY_SIZE, HEADER_FORMAT, diff --git a/tests/adapters/test_r3_deserializer.py b/tests/adapters/test_r3_deserializer.py index 31f058f6..980d3934 100644 --- a/tests/adapters/test_r3_deserializer.py +++ b/tests/adapters/test_r3_deserializer.py @@ -10,7 +10,7 @@ import pytest import zstandard as zstd -from eval_protocol.adapters.r3_deserializer import ( +from eval_protocol.tracing.router_replay import ( HEADER_FORMAT, HEADER_SIZE, MAGIC, diff --git a/tests/pytest/test_tracing_utils.py b/tests/pytest/test_tracing_utils.py index 58ec55c1..98246ab1 100644 --- a/tests/pytest/test_tracing_utils.py +++ b/tests/pytest/test_tracing_utils.py @@ -13,6 +13,7 @@ def test_merge_payloads_into_longest_row_preserves_each_assistant_turn(): execution_metadata=ExecutionMetadata( extra={ "completion_logprobs": [-0.1, -0.2], + "prompt_token_ids": [101, 102], "routing_matrices": ["first-matrix"], "routing_metadata": {"total_token_count": 1}, }, @@ -28,6 +29,7 @@ def test_merge_payloads_into_longest_row_preserves_each_assistant_turn(): execution_metadata=ExecutionMetadata( extra={ "completion_logprobs": [-0.3], + "prompt_token_ids": [101, 102, 103, 104], "routing_matrices": ["second-matrix"], "routing_metadata": {"total_token_count": 1}, }, @@ -45,12 +47,14 @@ def test_merge_payloads_into_longest_row_preserves_each_assistant_turn(): { "assistant_turn_index": 0, "completion_logprobs": [-0.1, -0.2], + "prompt_token_ids": [101, 102], "routing_matrices": ["first-matrix"], "routing_metadata": {"total_token_count": 1}, }, { "assistant_turn_index": 1, "completion_logprobs": [-0.3], + "prompt_token_ids": [101, 102, 103, 104], "routing_matrices": ["second-matrix"], "routing_metadata": {"total_token_count": 1}, }, diff --git a/tests/tracing/test_registry.py b/tests/tracing/test_registry.py new file mode 100644 index 00000000..e32492db --- /dev/null +++ b/tests/tracing/test_registry.py @@ -0,0 +1,124 @@ +"""Tests for the standalone tracing-gateway payload decoder registry.""" + +from __future__ import annotations + +import base64 +import json +import struct + +import zstandard as zstd + +from eval_protocol.tracing import ( + DecodedPayload, + PayloadType, + decode_payload, + decode_payloads, + decode_trace, +) +from eval_protocol.tracing import logprobs as lp_mod +from eval_protocol.tracing import router_replay as r3_mod + + +def _b64_zstd(raw: bytes) -> str: + return base64.b64encode(zstd.ZstdCompressor().compress(raw)).decode("ascii") + + +def _pti_data(token_ids: list[int]) -> str: + return _b64_zstd(json.dumps(token_ids).encode("utf-8")) + + +def _lp_data(tokens: list[tuple[int, float]]) -> str: + body = b"".join(struct.pack(lp_mod.ENTRY_FORMAT, tid, lp) for tid, lp in tokens) + header = struct.pack( + lp_mod.HEADER_FORMAT, lp_mod.MAGIC, 1, 0, 0, len(tokens), len(body), 0 + ) + return _b64_zstd(header + body) + + +def _r3_data_all_mode(matrices: list[bytes]) -> str: + matrix_data = b"".join(matrices) + header = struct.pack( + r3_mod.HEADER_FORMAT, + r3_mod.MAGIC, + 1, # version + r3_mod._SelectorMode.ALL, + r3_mod._RoutingDtype.UINT8, + 0x01, # flags + len(matrices), # total_token_count + len(matrices), # replayed_token_count + 0, # replay_start_token + 0, # selector_byte_length + len(matrix_data), + ) + return _b64_zstd(header + matrix_data) + + +def _all_payloads() -> dict: + return { + "prompt_token_ids": {"manifest": {"PayloadVersion": "pti/v1"}, "data": _pti_data([1, 2, 3])}, + "logprobs": {"manifest": {"PayloadVersion": "lp/v1"}, "data": _lp_data([(7, -0.25), (8, -0.5)])}, + "router_replay": { + "manifest": {"PayloadVersion": "r3/v1"}, + "data": _r3_data_all_mode([b"\x01\x02\x03\x04", b"\x05\x06\x07\x08"]), + }, + } + + +def test_decode_payloads_all_types(): + decoded = decode_payloads(_all_payloads()) + + assert set(decoded) == { + PayloadType.PROMPT_TOKEN_IDS, + PayloadType.LOGPROBS, + PayloadType.ROUTER_REPLAY, + } + assert all(isinstance(dp, DecodedPayload) for dp in decoded.values()) + + assert decoded[PayloadType.PROMPT_TOKEN_IDS].value == [1, 2, 3] + + lp = decoded[PayloadType.LOGPROBS] + assert lp.value == [-0.25, -0.5] + assert lp.token_ids == [7, 8] + + r3 = decoded[PayloadType.ROUTER_REPLAY] + assert len(r3.value) == 2 + assert base64.b64decode(r3.value[0]) == b"\x01\x02\x03\x04" + + +def test_decode_payload_accepts_str_and_enum(): + data = _pti_data([10, 20]) + via_enum = decode_payload(PayloadType.PROMPT_TOKEN_IDS, data) + via_str = decode_payload("prompt_token_ids", data) + assert via_enum.value == via_str.value == [10, 20] + + +def test_decode_trace_reaches_into_payloads(): + trace = {"id": "t1", "payloads": {"prompt_token_ids": {"data": _pti_data([5, 6])}}} + decoded = decode_trace(trace) + assert decoded[PayloadType.PROMPT_TOKEN_IDS].value == [5, 6] + + +def test_unknown_and_empty_types_are_skipped(): + payloads = { + "some_future_type": {"data": "ignored"}, # unknown -> ignored + "logprobs": {"data": ""}, # present but empty -> skipped + "prompt_token_ids": {"data": _pti_data([9])}, + } + decoded = decode_payloads(payloads) + assert set(decoded) == {PayloadType.PROMPT_TOKEN_IDS} + + +def test_on_error_fires_on_bad_data(): + errors = [] + payloads = {"prompt_token_ids": {"data": "not-valid-base64-zstd-json!!"}} + + decoded = decode_payloads(payloads, on_error=lambda pt, e: errors.append((pt, e))) + + assert decoded == {} + assert len(errors) == 1 + assert errors[0][0] == PayloadType.PROMPT_TOKEN_IDS + + +def test_decode_payloads_non_dict_returns_empty(): + assert decode_payloads(None) == {} + assert decode_trace({"id": "no-payloads"}) == {}