From d4247e51f861b5a5eeb69c6e9b7816d33010da5d Mon Sep 17 00:00:00 2001 From: doxav <> Date: Mon, 8 Jun 2026 17:35:42 +0200 Subject: [PATCH 1/2] Implement helper to add token usage tracking and validation in objectives and guides - Added `UsageTrackingLLM` class to wrap LLMs and record token usage. - Introduced `TokenUsageAugmentingGuide` to integrate token metrics into existing guides. - Enhanced `ObjectiveConfig` to include required metrics validation. - Implemented tests for required metrics in objectives and token usage in guides. - Created JSON output for multi-objective token usage results. --- README.md | 3 +- ...ultiobjective_token_usage_gsm8k_demo.ipynb | 451 ++++++++++++++++++ ...ective_token_usage_gsm8k_demo_results.json | 56 +++ opto/trainer/guide.py | 337 +++++++++++++ opto/trainer/objectives.py | 55 ++- .../test_objectives_required_metrics.py | 74 +++ tests/unit_tests/test_token_usage_guide.py | 206 ++++++++ 7 files changed, 1176 insertions(+), 6 deletions(-) create mode 100644 examples/notebooks/multiobjective_token_usage_gsm8k_demo.ipynb create mode 100644 examples/notebooks/notebook_outputs/multiobjective_token_usage_gsm8k_demo_results.json create mode 100644 tests/unit_tests/test_objectives_required_metrics.py create mode 100644 tests/unit_tests/test_token_usage_guide.py diff --git a/README.md b/README.md index e46a2e00..28c6d7c5 100644 --- a/README.md +++ b/README.md @@ -253,7 +253,7 @@ Key features: - **Vector scores** — `Guide.get_score_dict()` returns `Dict[str, float]` with named metrics - **Weighted scalarization** and **Pareto dominance** ranking via `ObjectiveConfig` - Supported in `BasicSearchAlgorithm`, `BeamsearchAlgorithm`, and `BeamsearchHistoryAlgorithm` -- Token-minimization pattern using `UsageTrackingLLM` + `TokenUsageAugmentingGuide` +- **Token-aware objectives** — wrap an agent LLM with `UsageTrackingLLM`, wrap the task guide with `TokenUsageAugmentingGuide`, then set `ObjectiveConfig.required_metrics` so runs fail early if `error`, `tokens_in`, or `tokens_out` are missing. This supports goals such as "solve GSM8K correctly while minimizing prompt and completion tokens." Canonical notebooks: @@ -262,6 +262,7 @@ Canonical notebooks: | [multiobjective_quickstart](examples/notebooks/multiobjective_quickstart.ipynb) | Core vector-score infrastructure and BasicSearch integration | | [multiobjective_trainers](examples/notebooks/multiobjective_trainers.ipynb) | Beamsearch and PrioritySearch multi-objective support | | [multiobjective_bbeh_langgraph](examples/notebooks/multiobjective_bbeh_langgraph.ipynb) | Real LLM task: BBEH boolean expressions with accuracy + execution time | +| [multiobjective_token_usage_gsm8k_demo](examples/notebooks/multiobjective_token_usage_gsm8k_demo.ipynb) | GSM8K-style objective that minimizes answer error, prompt tokens, and completion tokens | Trace-Bench multi-objective benchmarks (in [AgentOpt/Trace-Bench](https://github.com/AgentOpt/Trace-Bench)): diff --git a/examples/notebooks/multiobjective_token_usage_gsm8k_demo.ipynb b/examples/notebooks/multiobjective_token_usage_gsm8k_demo.ipynb new file mode 100644 index 00000000..bde122d8 --- /dev/null +++ b/examples/notebooks/multiobjective_token_usage_gsm8k_demo.ipynb @@ -0,0 +1,451 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1fb6d2f3", + "metadata": {}, + "source": [ + "# Token-aware GSM8K-style objective demo\n", + "\n", + "This notebook demonstrates token-aware multi-objective evaluation on a small GSM8K-style arithmetic problem set. It runs fully offline with a deterministic OpenAI-shaped LLM stub so the saved outputs are reproducible.\n", + "\n", + "Agent goals:\n", + "- answer the math problem correctly;\n", + "- when correctness ties, prefer fewer prompt and completion tokens;\n", + "- fail early if the configured objective requires token metrics that the guide did not emit." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d9dd5be0", + "metadata": { + "execution": { + "iopub.execute_input": "2026-06-08T15:15:52.736139Z", + "iopub.status.busy": "2026-06-08T15:15:52.736035Z", + "iopub.status.idle": "2026-06-08T15:15:53.959336Z", + "shell.execute_reply": "2026-06-08T15:15:53.958652Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "repo root: /home/xav/code/Trace\n" + ] + } + ], + "source": [ + "import json\n", + "import re\n", + "import sys\n", + "from pathlib import Path\n", + "from typing import Any, Dict, List, Optional, Tuple\n", + "\n", + "ROOT = next(candidate for candidate in [Path.cwd(), *Path.cwd().parents] if (candidate / 'opto').exists())\n", + "if str(ROOT) not in sys.path:\n", + " sys.path.insert(0, str(ROOT))\n", + "\n", + "from opto.features.predefined_agents import BasicLearner\n", + "from opto.trainer.evaluators import aggregate_vector_scores, evaluate_vector\n", + "from opto.trainer.guide import Guide, TokenUsageAugmentingGuide, UsageTrackingLLM\n", + "from opto.trainer.objectives import ObjectiveConfig, select_best\n", + "\n", + "OUTPUT_DIR = ROOT / 'examples' / 'notebooks' / 'notebook_outputs'\n", + "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n", + "print('repo root:', ROOT)" + ] + }, + { + "cell_type": "markdown", + "id": "db191ff2", + "metadata": {}, + "source": [ + "## Dataset and deterministic LLM\n", + "\n", + "The questions are GSM8K-style word problems. The LLM stub returns exact arithmetic answers and OpenAI-compatible `usage` metadata so `UsageTrackingLLM` can record tokens without estimating." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8ddcbe30", + "metadata": { + "execution": { + "iopub.execute_input": "2026-06-08T15:15:53.961237Z", + "iopub.status.busy": "2026-06-08T15:15:53.961075Z", + "iopub.status.idle": "2026-06-08T15:15:53.968296Z", + "shell.execute_reply": "2026-06-08T15:15:53.967627Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'inputs': ['Maria has 3 packs of 4 pencils. How many pencils does she have?',\n", + " 'Tom has 10 apples and gives away 6. How many apples are left?',\n", + " 'A sticker sheet has 5 rows with 7 stickers each. How many stickers are there?'],\n", + " 'infos': ['12', '4', '35']}" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class _Usage:\n", + " def __init__(self, prompt_tokens: int, completion_tokens: int) -> None:\n", + " self.prompt_tokens = prompt_tokens\n", + " self.completion_tokens = completion_tokens\n", + "\n", + "\n", + "class _Message:\n", + " def __init__(self, content: str) -> None:\n", + " self.content = content\n", + "\n", + "\n", + "class _Choice:\n", + " def __init__(self, content: str) -> None:\n", + " self.message = _Message(content)\n", + "\n", + "\n", + "class _Response:\n", + " def __init__(self, content: str, prompt_tokens: int, completion_tokens: int) -> None:\n", + " self.choices = [_Choice(content)]\n", + " self.usage = _Usage(prompt_tokens, completion_tokens)\n", + "\n", + "\n", + "class RuleBasedGSM8KLLM:\n", + " \"\"\"Small deterministic LLM replacement for a GSM8K-style token demo.\"\"\"\n", + "\n", + " def __call__(self, *, messages: List[Dict[str, str]], **kwargs: Any) -> _Response:\n", + " system_prompt = messages[0]['content'].lower()\n", + " question = messages[-1]['content']\n", + " answer = self._solve(question)\n", + " if 'final number only' in system_prompt or 'concise' in system_prompt:\n", + " content = str(answer)\n", + " else:\n", + " content = f'We identify the numbers, choose the operation, and compute carefully. The answer is {answer}.'\n", + " prompt_tokens = self._count_words(' '.join(message['content'] for message in messages))\n", + " completion_tokens = self._count_words(content)\n", + " return _Response(content, prompt_tokens, completion_tokens)\n", + "\n", + " @staticmethod\n", + " def _count_words(text: str) -> int:\n", + " return len(text.split())\n", + "\n", + " @staticmethod\n", + " def _solve(question: str) -> int:\n", + " numbers = [int(value) for value in re.findall(r'\\d+', question)]\n", + " lowered = question.lower()\n", + " if any(marker in lowered for marker in ('each', 'packs of', 'rows')):\n", + " return numbers[0] * numbers[1]\n", + " if any(marker in lowered for marker in ('gives away', 'left', 'remain')):\n", + " return numbers[0] - numbers[1]\n", + " return sum(numbers)\n", + "\n", + "\n", + "DATASET = {\n", + " 'inputs': [\n", + " 'Maria has 3 packs of 4 pencils. How many pencils does she have?',\n", + " 'Tom has 10 apples and gives away 6. How many apples are left?',\n", + " 'A sticker sheet has 5 rows with 7 stickers each. How many stickers are there?',\n", + " ],\n", + " 'infos': ['12', '4', '35'],\n", + "}\n", + "DATASET" + ] + }, + { + "cell_type": "markdown", + "id": "b7db5b25", + "metadata": {}, + "source": [ + "## Guide and objective\n", + "\n", + "`ExactNumberGuide` emits the task metric (`error`). `TokenUsageAugmentingGuide` adds `tokens_in` and `tokens_out` from the shared tracked LLM. The objective requires all three metrics and minimizes them." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d0d99acd", + "metadata": { + "execution": { + "iopub.execute_input": "2026-06-08T15:15:53.969956Z", + "iopub.status.busy": "2026-06-08T15:15:53.969820Z", + "iopub.status.idle": "2026-06-08T15:15:53.974220Z", + "shell.execute_reply": "2026-06-08T15:15:53.973566Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "ObjectiveConfig(mode='weighted', weights={'error': 1.0, 'tokens_in': 0.001, 'tokens_out': 0.001}, minimize=frozenset({'error', 'tokens_out', 'tokens_in'}), missing_value=-inf, pareto_metrics=None, tie_break='weighted', required_metrics=frozenset({'error', 'tokens_out', 'tokens_in'}), seed=0, scalarize_dict='score', score_key='score')" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class ExactNumberGuide(Guide):\n", + " \"\"\"Score GSM8K-style answers by exact final integer match.\"\"\"\n", + "\n", + " def get_feedback(\n", + " self,\n", + " query: str,\n", + " response: str,\n", + " reference: Optional[str] = None,\n", + " **kwargs: Any,\n", + " ) -> Tuple[float, str]:\n", + " predicted = self._extract_final_number(response)\n", + " expected = '' if reference is None else str(reference).strip()\n", + " correct = predicted == expected\n", + " return float(correct), f'predicted={predicted!r}; expected={expected!r}'\n", + "\n", + " def get_score_dict(\n", + " self,\n", + " query: str,\n", + " response: str,\n", + " reference: Optional[str] = None,\n", + " **kwargs: Any,\n", + " ) -> Dict[str, float]:\n", + " reward, _ = self.get_feedback(query, response, reference, **kwargs)\n", + " return {'error': 1.0 - reward}\n", + "\n", + " @staticmethod\n", + " def _extract_final_number(text: str) -> str:\n", + " matches = re.findall(r'-?\\d+', text)\n", + " return matches[-1] if matches else ''\n", + "\n", + "\n", + "objective = ObjectiveConfig(\n", + " mode='weighted',\n", + " weights={'error': 1.0, 'tokens_in': 1e-3, 'tokens_out': 1e-3},\n", + " minimize=frozenset({'error', 'tokens_in', 'tokens_out'}),\n", + " required_metrics=frozenset({'error', 'tokens_in', 'tokens_out'}),\n", + ")\n", + "objective" + ] + }, + { + "cell_type": "markdown", + "id": "b8894a2a", + "metadata": {}, + "source": [ + "## Compare verbose and concise agent goals\n", + "\n", + "Both agents solve the examples. The multi-objective selector should choose the concise goal because it has lower token usage at the same error." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d8ff3d10", + "metadata": { + "execution": { + "iopub.execute_input": "2026-06-08T15:15:53.976121Z", + "iopub.status.busy": "2026-06-08T15:15:53.975939Z", + "iopub.status.idle": "2026-06-08T15:15:53.985424Z", + "shell.execute_reply": "2026-06-08T15:15:53.984956Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating: Solve carefully and explain the (Running sequentially).\n", + "Evaluating: Answer with the final number onl (Running sequentially).\n", + "{\n", + " \"runs\": [\n", + " {\n", + " \"system_prompt\": \"Solve carefully and explain the reasoning before the final answer.\",\n", + " \"per_example\": [\n", + " {\n", + " \"error\": 0.0,\n", + " \"tokens_in\": 23.0,\n", + " \"tokens_out\": 14.0\n", + " },\n", + " {\n", + " \"error\": 0.0,\n", + " \"tokens_in\": 23.0,\n", + " \"tokens_out\": 14.0\n", + " },\n", + " {\n", + " \"error\": 0.0,\n", + " \"tokens_in\": 25.0,\n", + " \"tokens_out\": 14.0\n", + " }\n", + " ],\n", + " \"aggregate\": {\n", + " \"error\": 0.0,\n", + " \"tokens_in\": 23.666666666666668,\n", + " \"tokens_out\": 14.0\n", + " }\n", + " },\n", + " {\n", + " \"system_prompt\": \"Answer with the final number only. Be concise.\",\n", + " \"per_example\": [\n", + " {\n", + " \"error\": 0.0,\n", + " \"tokens_in\": 21.0,\n", + " \"tokens_out\": 1.0\n", + " },\n", + " {\n", + " \"error\": 0.0,\n", + " \"tokens_in\": 21.0,\n", + " \"tokens_out\": 1.0\n", + " },\n", + " {\n", + " \"error\": 0.0,\n", + " \"tokens_in\": 23.0,\n", + " \"tokens_out\": 1.0\n", + " }\n", + " ],\n", + " \"aggregate\": {\n", + " \"error\": 0.0,\n", + " \"tokens_in\": 21.666666666666668,\n", + " \"tokens_out\": 1.0\n", + " }\n", + " }\n", + " ],\n", + " \"selected_index\": 1,\n", + " \"selected_goal\": \"Answer with the final number only. Be concise.\"\n", + "}\n" + ] + } + ], + "source": [ + "def evaluate_goal(system_prompt: str) -> Dict[str, Any]:\n", + " \"\"\"Evaluate one agent goal and return per-example and aggregate metrics.\"\"\"\n", + " tracked_llm = UsageTrackingLLM(RuleBasedGSM8KLLM(), estimate_missing=False)\n", + " agent = BasicLearner(\n", + " system_prompt=system_prompt,\n", + " user_prompt_template='{message}',\n", + " llm=tracked_llm,\n", + " )\n", + " guide = TokenUsageAugmentingGuide(ExactNumberGuide(), tracked_llm)\n", + " per_example = evaluate_vector(\n", + " agent,\n", + " guide,\n", + " DATASET['inputs'],\n", + " DATASET['infos'],\n", + " num_threads=1,\n", + " description=f'Evaluating: {system_prompt[:32]}',\n", + " )\n", + " aggregate = aggregate_vector_scores(per_example)\n", + " return {'system_prompt': system_prompt, 'per_example': per_example, 'aggregate': aggregate}\n", + "\n", + "\n", + "runs = [\n", + " evaluate_goal('Solve carefully and explain the reasoning before the final answer.'),\n", + " evaluate_goal('Answer with the final number only. Be concise.'),\n", + "]\n", + "best_index = select_best([(run['aggregate'], run) for run in runs], objective)\n", + "result = {'runs': runs, 'selected_index': best_index, 'selected_goal': runs[best_index]['system_prompt']}\n", + "print(json.dumps(result, indent=2))" + ] + }, + { + "cell_type": "markdown", + "id": "7dc4caa4", + "metadata": {}, + "source": [ + "## Required metric guard\n", + "\n", + "If a token objective is configured but token metrics are missing, selection fails before silently optimizing the wrong scalar." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7030d958", + "metadata": { + "execution": { + "iopub.execute_input": "2026-06-08T15:15:53.986884Z", + "iopub.status.busy": "2026-06-08T15:15:53.986784Z", + "iopub.status.idle": "2026-06-08T15:15:53.988944Z", + "shell.execute_reply": "2026-06-08T15:15:53.988570Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ValueError: Missing required objective metrics: ['tokens_out']. Available metrics: ['error', 'tokens_in']\n" + ] + } + ], + "source": [ + "try:\n", + " select_best([({'error': 0.0, 'tokens_in': 10.0}, 'missing tokens_out')], objective)\n", + "except ValueError as exc:\n", + " print(type(exc).__name__ + ':', exc)" + ] + }, + { + "cell_type": "markdown", + "id": "4a467a6c", + "metadata": {}, + "source": [ + "## Save results" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4f853005", + "metadata": { + "execution": { + "iopub.execute_input": "2026-06-08T15:15:53.990216Z", + "iopub.status.busy": "2026-06-08T15:15:53.990135Z", + "iopub.status.idle": "2026-06-08T15:15:53.992774Z", + "shell.execute_reply": "2026-06-08T15:15:53.992112Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "saved: examples/notebooks/notebook_outputs/multiobjective_token_usage_gsm8k_demo_results.json\n" + ] + } + ], + "source": [ + "output_path = OUTPUT_DIR / 'multiobjective_token_usage_gsm8k_demo_results.json'\n", + "with output_path.open('w', encoding='utf-8') as handle:\n", + " json.dump(result, handle, indent=2)\n", + "print('saved:', output_path.relative_to(ROOT))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/notebook_outputs/multiobjective_token_usage_gsm8k_demo_results.json b/examples/notebooks/notebook_outputs/multiobjective_token_usage_gsm8k_demo_results.json new file mode 100644 index 00000000..e4dd9562 --- /dev/null +++ b/examples/notebooks/notebook_outputs/multiobjective_token_usage_gsm8k_demo_results.json @@ -0,0 +1,56 @@ +{ + "runs": [ + { + "system_prompt": "Solve carefully and explain the reasoning before the final answer.", + "per_example": [ + { + "error": 0.0, + "tokens_in": 23.0, + "tokens_out": 14.0 + }, + { + "error": 0.0, + "tokens_in": 23.0, + "tokens_out": 14.0 + }, + { + "error": 0.0, + "tokens_in": 25.0, + "tokens_out": 14.0 + } + ], + "aggregate": { + "error": 0.0, + "tokens_in": 23.666666666666668, + "tokens_out": 14.0 + } + }, + { + "system_prompt": "Answer with the final number only. Be concise.", + "per_example": [ + { + "error": 0.0, + "tokens_in": 21.0, + "tokens_out": 1.0 + }, + { + "error": 0.0, + "tokens_in": 21.0, + "tokens_out": 1.0 + }, + { + "error": 0.0, + "tokens_in": 23.0, + "tokens_out": 1.0 + } + ], + "aggregate": { + "error": 0.0, + "tokens_in": 21.666666666666668, + "tokens_out": 1.0 + } + } + ], + "selected_index": 1, + "selected_goal": "Answer with the final number only. Be concise." +} \ No newline at end of file diff --git a/opto/trainer/guide.py b/opto/trainer/guide.py index 4906c831..30a9eef9 100644 --- a/opto/trainer/guide.py +++ b/opto/trainer/guide.py @@ -1,4 +1,5 @@ from typing import List, Dict, Any, Union, Tuple, Optional, Callable +import contextvars import json import pickle import re @@ -7,6 +8,7 @@ from opto.utils.llm import LLM, AbstractModel from opto.trainer.suggest import Suggest + def exact_match_metric(question, student_answer, info): """ Exact match metric """ return float(student_answer == info) @@ -97,6 +99,341 @@ def __setstate__(self, state): self.__dict__.update(state) +class UsageTrackingLLM: + """Wrap an LLM and record token usage for the current execution context. + + The wrapper reads OpenAI-compatible usage metadata when available: + ``response.usage.prompt_tokens`` and ``response.usage.completion_tokens``. + Dict-shaped responses and ``input_tokens`` / ``output_tokens`` aliases are + also supported. + + If usage metadata is missing, the wrapper estimates token counts with a + simple whitespace split by default. Set ``estimate_missing=False`` to fail + fast instead. Streaming usage is only captured when the provider returns + usage on the response object passed back from the wrapped LLM. + + The usage store is ``ContextVar``-backed, so concurrent evaluations do not + overwrite each other. ``__deepcopy__`` returns ``self`` intentionally: + trainer code may deep-copy agent and guide independently, and both copies + must keep sharing the same tracker instance. + """ + + def __init__(self, base_llm: Any, *, estimate_missing: bool = True) -> None: + if not callable(base_llm): + raise TypeError("base_llm must be callable") + self._base = base_llm + self.estimate_missing = estimate_missing + self._usage: contextvars.ContextVar[Optional[Dict[str, int]]] = ( + contextvars.ContextVar("llm_token_usage", default=None) + ) + self._usage_estimated: contextvars.ContextVar[bool] = contextvars.ContextVar( + "llm_token_usage_estimated", default=False + ) + + def __deepcopy__(self, memo: Dict[int, Any]) -> "UsageTrackingLLM": + memo[id(self)] = self + return self + + def __getattr__(self, name: str) -> Any: + try: + base = self.__dict__["_base"] + except KeyError as exc: + raise AttributeError(name) from exc + return getattr(base, name) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + messages = kwargs.get("messages", args[0] if args else None) + response = self._base(*args, **kwargs) + tokens_in, tokens_out, estimated = self._read_token_usage( + response=response, + messages=messages, + ) + self._add_usage(tokens_in=tokens_in, tokens_out=tokens_out) + self._usage_estimated.set(self._usage_estimated.get() or estimated) + return response + + def reset_usage(self) -> None: + """Clear token usage for the current execution context.""" + self._usage.set(None) + self._usage_estimated.set(False) + + def has_usage(self) -> bool: + """Return whether any usage was recorded in this execution context.""" + return self._usage.get() is not None + + def last_usage(self, *, reset: bool = False) -> Dict[str, int]: + """Return accumulated token usage for the current execution context. + + Args: + reset: Clear the current context after reading when True. + """ + usage = self._usage.get() + result = ( + {"tokens_in": 0, "tokens_out": 0} + if usage is None + else { + "tokens_in": int(usage.get("tokens_in", 0)), + "tokens_out": int(usage.get("tokens_out", 0)), + } + ) + if reset: + self.reset_usage() + return result + + def last_usage_was_estimated(self) -> bool: + """Return whether any current usage metric used fallback estimation.""" + return bool(self._usage_estimated.get()) + + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + state["_usage"] = None + state["_usage_estimated"] = None + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__dict__.update(state) + self._usage = contextvars.ContextVar("llm_token_usage", default=None) + self._usage_estimated = contextvars.ContextVar( + "llm_token_usage_estimated", default=False + ) + + def _add_usage(self, *, tokens_in: int, tokens_out: int) -> None: + previous = self._usage.get() or {"tokens_in": 0, "tokens_out": 0} + self._usage.set( + { + "tokens_in": int(previous.get("tokens_in", 0)) + tokens_in, + "tokens_out": int(previous.get("tokens_out", 0)) + tokens_out, + } + ) + + def _read_token_usage( + self, + *, + response: Any, + messages: Any, + ) -> Tuple[int, int, bool]: + prompt_tokens, completion_tokens = self._extract_usage(response) + estimated = False + + if prompt_tokens is None or completion_tokens is None: + if not self.estimate_missing: + raise ValueError( + "LLM response did not include complete token usage. " + "Use UsageTrackingLLM(..., estimate_missing=True) to " + "allow estimates, or use a provider/model that returns " + "response.usage." + ) + estimated = True + if prompt_tokens is None: + prompt_tokens = self._estimate_prompt_tokens(messages) + if completion_tokens is None: + completion_tokens = self._estimate_completion_tokens(response) + + return int(prompt_tokens or 0), int(completion_tokens or 0), estimated + + @classmethod + def _extract_usage(cls, response: Any) -> Tuple[Optional[int], Optional[int]]: + usage = cls._get_field(response, "usage") + prompt_tokens = cls._to_optional_int( + cls._get_field(usage, "prompt_tokens", "input_tokens") + ) + completion_tokens = cls._to_optional_int( + cls._get_field(usage, "completion_tokens", "output_tokens") + ) + return prompt_tokens, completion_tokens + + @staticmethod + def _get_field(obj: Any, *names: str) -> Any: + if obj is None: + return None + for name in names: + value = None + if isinstance(obj, dict) and name in obj: + value = obj[name] + elif hasattr(obj, name): + value = getattr(obj, name) + if value is not None: + return value + return None + + @staticmethod + def _to_optional_int(value: Any) -> Optional[int]: + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + @classmethod + def _estimate_prompt_tokens(cls, messages: Any) -> int: + if not isinstance(messages, list): + return 0 + return len( + " ".join(cls._message_to_text(message) for message in messages).split() + ) + + @classmethod + def _message_to_text(cls, message: Any) -> str: + if isinstance(message, dict): + return cls._content_to_text(message.get("content", "")) + return str(message) + + @classmethod + def _content_to_text(cls, content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + return " ".join(cls._content_to_text(part) for part in content) + if isinstance(content, dict): + parts = [] + for key in ("text", "content"): + if key in content: + parts.append(cls._content_to_text(content[key])) + return " ".join(parts) + return str(content) + + @classmethod + def _estimate_completion_tokens(cls, response: Any) -> int: + text = cls._completion_text(response) + return len(text.split()) if text else 0 + + @classmethod + def _completion_text(cls, response: Any) -> str: + if isinstance(response, str): + return response + + choices = cls._get_field(response, "choices") + if not choices: + return "" + + choice = choices[0] + message = cls._get_field(choice, "message") + content = cls._get_field(message, "content") + if content is None: + content = cls._get_field(choice, "text") + return "" if content is None else str(content) + + +class TokenUsageAugmentingGuide(Guide): + """Add token usage metrics from a tracker to another guide's scores. + + The tracker must expose ``last_usage()``. ``UsageTrackingLLM`` is the default + implementation, but custom LLM wrappers can be used if they follow the same + duck-typed interface. + """ + + def __init__( + self, + base_guide: Guide, + token_llm: Any, + *, + tokens_in_key: str = "tokens_in", + tokens_out_key: str = "tokens_out", + require_usage: bool = True, + reset_after_read: bool = True, + ) -> None: + if not isinstance(base_guide, Guide): + raise TypeError("base_guide must be a Guide") + if not callable(getattr(token_llm, "last_usage", None)): + raise TypeError("token_llm must expose a callable last_usage() method") + if not tokens_in_key or not tokens_out_key: + raise ValueError("token metric keys must be non-empty strings") + self._base = base_guide + self._token_llm = token_llm + self.tokens_in_key = tokens_in_key + self.tokens_out_key = tokens_out_key + self.require_usage = require_usage + self.reset_after_read = reset_after_read + + def __deepcopy__(self, memo: Dict[int, Any]) -> "TokenUsageAugmentingGuide": + cls = type(self) + new = cls.__new__(cls) + memo[id(self)] = new + new._base = copy.deepcopy(self._base, memo) + new._token_llm = self._token_llm + new.tokens_in_key = self.tokens_in_key + new.tokens_out_key = self.tokens_out_key + new.require_usage = self.require_usage + new.reset_after_read = self.reset_after_read + return new + + def get_feedback( + self, + query: str, + response: str, + reference: Optional[str] = None, + **kwargs: Any, + ) -> Tuple[float, str]: + reward, feedback = self._base.get_feedback( + query, + response, + reference=reference, + **kwargs, + ) + usage = self._usage_metrics() + return ( + float(reward), + ( + f"{feedback} " + f"{self.tokens_in_key}={usage[self.tokens_in_key]:.0f} " + f"{self.tokens_out_key}={usage[self.tokens_out_key]:.0f}" + ), + ) + + def get_score_dict( + self, + query: str, + response: str, + reference: Optional[str] = None, + **kwargs: Any, + ) -> Dict[str, float]: + score_dict = dict( + self._base.get_score_dict( + query, + response, + reference=reference, + **kwargs, + ) + ) + usage = self._usage_metrics() + collisions = sorted(set(score_dict).intersection(usage)) + if collisions: + raise ValueError( + "Base guide already emitted token metric keys: " + f"{collisions}. Use custom tokens_in_key/tokens_out_key." + ) + score_dict.update(usage) + return score_dict + + def _usage_metrics(self) -> Dict[str, float]: + has_usage = getattr(self._token_llm, "has_usage", None) + if self.require_usage and callable(has_usage) and not has_usage(): + raise RuntimeError( + "No token usage was recorded before reading guide metrics. " + "Ensure the evaluated agent uses the same UsageTrackingLLM " + "instance passed to TokenUsageAugmentingGuide." + ) + + try: + usage = self._token_llm.last_usage(reset=self.reset_after_read) + except TypeError: + usage = self._token_llm.last_usage() + reset_usage = getattr(self._token_llm, "reset_usage", None) + if self.reset_after_read and callable(reset_usage): + reset_usage() + + return { + self.tokens_in_key: float( + usage.get("tokens_in", usage.get(self.tokens_in_key, 0)) + ), + self.tokens_out_key: float( + usage.get("tokens_out", usage.get(self.tokens_out_key, 0)) + ), + } + class LLMJudge(Guide): """ diff --git a/opto/trainer/objectives.py b/opto/trainer/objectives.py index 5d8e6bd4..ad2a601e 100644 --- a/opto/trainer/objectives.py +++ b/opto/trainer/objectives.py @@ -40,6 +40,8 @@ class ObjectiveConfig: - "weighted": fall back to weighted scalarization. - "lexicographic": sort by metric names alphabetically. - "random_seeded": seeded random shuffle. + required_metrics: Metrics that must be present in every dict score. + Empty frozenset disables this check. seed: Random seed for deterministic tie-breaking. scalarize_dict: How to reduce dict scores to a scalar (when mode="scalar"). @@ -54,13 +56,18 @@ class ObjectiveConfig: missing_value: float = float("-inf") pareto_metrics: Optional[Tuple[str, ...]] = None tie_break: str = "weighted" + required_metrics: frozenset = field(default_factory=frozenset) seed: int = 0 scalarize_dict: str = "score" score_key: str = "score" - def __post_init__(self): + def __post_init__(self) -> None: if isinstance(self.minimize, set): object.__setattr__(self, 'minimize', frozenset(self.minimize)) + if isinstance(self.required_metrics, set): + object.__setattr__( + self, 'required_metrics', frozenset(self.required_metrics) + ) if self.mode not in ("scalar", "weighted", "pareto"): raise ValueError( f"mode must be 'scalar', 'weighted', or 'pareto', got '{self.mode}'" @@ -84,6 +91,9 @@ def __post_init__(self): raise ValueError( "pareto_metrics must be None (auto) or non-empty tuple" ) + for metric in self.required_metrics: + if not isinstance(metric, str) or not metric: + raise ValueError("required_metrics must contain only non-empty strings") # --------------------------------------------------------------------------- @@ -126,6 +136,20 @@ def to_score_dict(score: ScoreLike) -> Dict[str, float]: normalize_score = to_score_dict +def validate_required_metrics( + score_dict: Dict[str, float], config: ObjectiveConfig +) -> None: + """Fail fast when a score dict omits required objective metrics.""" + if not config.required_metrics: + return + missing = sorted(set(config.required_metrics) - set(score_dict)) + if missing: + raise ValueError( + "Missing required objective metrics: " + f"{missing}. Available metrics: {sorted(score_dict)}" + ) + + def score_dict_to_scalar(score_dict: Dict[str, float], config: ObjectiveConfig) -> float: """Reduce a score dict to a scalar according to ObjectiveConfig. @@ -138,6 +162,7 @@ def score_dict_to_scalar(score_dict: Dict[str, float], This exists to avoid hard-coding any dict->scalar behavior in Guide/Evaluator. """ sd = to_score_dict(score_dict) + validate_required_metrics(sd, config) sd = apply_minimize(sd, config.minimize) if config.scalarize_dict == "score": @@ -173,6 +198,11 @@ def to_scalar_score(score: ScoreLike, "to define reduction." ) return score_dict_to_scalar(score, config) + if config is not None and config.required_metrics: + raise ValueError( + "ObjectiveConfig.required_metrics requires dict scores; " + f"got scalar score {score!r}" + ) return float(score) @@ -254,6 +284,23 @@ def pareto_rank(candidates: List[Dict[str, float]], return ranks +def _prepare_score_dicts( + candidates: List[Tuple[ScoreLike, Any]], config: ObjectiveConfig +) -> List[Dict[str, float]]: + """Convert, validate, and higher-is-better-normalize candidate scores.""" + score_dicts: List[Dict[str, float]] = [] + for score, _ in candidates: + if config.required_metrics and not isinstance(score, dict): + raise ValueError( + "ObjectiveConfig.required_metrics requires dict scores; " + f"got scalar score {score!r}" + ) + score_dict = to_score_dict(score) + validate_required_metrics(score_dict, config) + score_dicts.append(score_dict) + return [apply_minimize(score_dict, config.minimize) for score_dict in score_dicts] + + def select_best(candidates: List[Tuple[ScoreLike, Any]], config: Optional[ObjectiveConfig] = None) -> int: """Select index of the single best candidate. @@ -275,8 +322,7 @@ def select_best(candidates: List[Tuple[ScoreLike, Any]], scores = [to_scalar_score(score, config) for score, _ in candidates] return int(np.argmax(scores)) - score_dicts = [to_score_dict(s) for s, _ in candidates] - score_dicts = [apply_minimize(sd, config.minimize) for sd in score_dicts] + score_dicts = _prepare_score_dicts(candidates, config) if config.mode == "weighted": weighted = [ @@ -337,8 +383,7 @@ def select_top_k(candidates: List[Tuple[ScoreLike, Any]], scores = [to_scalar_score(score, config) for score, _ in candidates] return list(np.argsort(scores)[::-1][:k]) - score_dicts = [to_score_dict(s) for s, _ in candidates] - score_dicts = [apply_minimize(sd, config.minimize) for sd in score_dicts] + score_dicts = _prepare_score_dicts(candidates, config) if config.mode == "weighted": weighted = [ diff --git a/tests/unit_tests/test_objectives_required_metrics.py b/tests/unit_tests/test_objectives_required_metrics.py new file mode 100644 index 00000000..1bc1307a --- /dev/null +++ b/tests/unit_tests/test_objectives_required_metrics.py @@ -0,0 +1,74 @@ +from typing import Dict + +import pytest + +from opto.trainer.objectives import ObjectiveConfig, select_best, select_top_k + + +def _token_objective_config() -> ObjectiveConfig: + """Return a reusable objective that minimizes error and token usage.""" + return ObjectiveConfig( + mode="weighted", + weights={"error": 1.0, "tokens_in": 1e-3, "tokens_out": 1e-3}, + minimize=frozenset({"error", "tokens_in", "tokens_out"}), + required_metrics=frozenset({"error", "tokens_in", "tokens_out"}), + ) + + +def test_weighted_objective_can_minimize_token_metrics() -> None: + config = _token_objective_config() + candidates = [ + ({"error": 0.0, "tokens_in": 100.0, "tokens_out": 100.0}, "long"), + ({"error": 0.0, "tokens_in": 10.0, "tokens_out": 10.0}, "short"), + ] + + assert select_best(candidates, config) == 1 + + +def test_top_k_validates_required_metrics() -> None: + config = _token_objective_config() + candidates = [ + ({"error": 0.0, "tokens_in": 10.0, "tokens_out": 10.0}, "ok"), + ({"error": 0.0, "tokens_in": 10.0}, "missing-tokens-out"), + ] + + with pytest.raises(ValueError, match="Missing required objective metrics"): + select_top_k(candidates, config, k=1) + + +def test_required_metrics_rejects_missing_token_metrics() -> None: + config = _token_objective_config() + candidates = [ + ({"error": 0.0, "tokens_in": 10.0}, "missing-tokens-out"), + ] + + with pytest.raises(ValueError, match="Missing required objective metrics"): + select_best(candidates, config) + + +def test_required_metrics_rejects_scalar_scores() -> None: + config = _token_objective_config() + + with pytest.raises(ValueError, match="requires dict scores"): + select_best([(1.0, "scalar-score")], config) + + +def test_required_metrics_accepts_sets_and_rejects_bad_names() -> None: + config = ObjectiveConfig(required_metrics={"score"}) + + assert config.required_metrics == frozenset({"score"}) + + with pytest.raises(ValueError, match="required_metrics"): + ObjectiveConfig(required_metrics={""}) + + +def test_required_metrics_with_scalar_mode_validates_dict_scores() -> None: + config = ObjectiveConfig( + mode="scalar", + scalarize_dict="score", + required_metrics=frozenset({"score", "tokens_in"}), + ) + score: Dict[str, float] = {"score": 1.0} + + with pytest.raises(ValueError, match="Missing required objective metrics"): + select_best([(score, "missing-token-count")], config) diff --git a/tests/unit_tests/test_token_usage_guide.py b/tests/unit_tests/test_token_usage_guide.py new file mode 100644 index 00000000..07d1f2d5 --- /dev/null +++ b/tests/unit_tests/test_token_usage_guide.py @@ -0,0 +1,206 @@ +import copy +from typing import Any, Dict, Optional, Tuple + +import pytest + +from opto.trainer.guide import Guide, TokenUsageAugmentingGuide, UsageTrackingLLM + + +class _Usage: + def __init__( + self, + prompt_tokens: Optional[int] = None, + completion_tokens: Optional[int] = None, + ) -> None: + self.prompt_tokens = prompt_tokens + self.completion_tokens = completion_tokens + + +class _Message: + def __init__(self, content: str) -> None: + self.content = content + + +class _Choice: + def __init__(self, content: str) -> None: + self.message = _Message(content) + + +class _Response: + def __init__(self, content: str = "ok", usage: Optional[_Usage] = None) -> None: + self.choices = [_Choice(content)] + if usage is not None: + self.usage = usage + + +class _BaseGuide(Guide): + def get_feedback( + self, + query: str, + response: str, + reference: Optional[str] = None, + **kwargs: Any, + ) -> Tuple[float, str]: + return 1.0, "base-feedback" + + def get_score_dict( + self, + query: str, + response: str, + reference: Optional[str] = None, + **kwargs: Any, + ) -> Dict[str, float]: + return {"error": 0.0} + + +class _TokenCollisionGuide(_BaseGuide): + def get_score_dict( + self, + query: str, + response: str, + reference: Optional[str] = None, + **kwargs: Any, + ) -> Dict[str, float]: + return {"error": 0.0, "tokens_in": 1.0} + + +def test_usage_tracking_reads_openai_style_usage() -> None: + tracker = UsageTrackingLLM( + lambda **kwargs: _Response( + content="two words", + usage=_Usage(prompt_tokens=7, completion_tokens=2), + ), + estimate_missing=False, + ) + + tracker(messages=[{"role": "user", "content": "ignored"}]) + + assert tracker.last_usage() == {"tokens_in": 7, "tokens_out": 2} + assert tracker.has_usage() + assert tracker.last_usage_was_estimated() is False + + +def test_usage_tracking_accepts_dict_usage_aliases() -> None: + tracker = UsageTrackingLLM( + lambda **kwargs: { + "choices": [{"message": {"content": "answer text"}}], + "usage": {"input_tokens": 9, "output_tokens": 3}, + }, + estimate_missing=False, + ) + + tracker(messages=[{"role": "user", "content": "ignored"}]) + + assert tracker.last_usage() == {"tokens_in": 9, "tokens_out": 3} + + +def test_usage_tracking_estimates_missing_usage_when_allowed() -> None: + tracker = UsageTrackingLLM( + lambda **kwargs: _Response(content="two words"), + estimate_missing=True, + ) + + tracker(messages=[{"role": "user", "content": "one two three"}]) + + assert tracker.last_usage() == {"tokens_in": 3, "tokens_out": 2} + assert tracker.last_usage_was_estimated() is True + + +def test_usage_tracking_strict_mode_rejects_missing_usage() -> None: + tracker = UsageTrackingLLM( + lambda **kwargs: _Response(content="two words"), + estimate_missing=False, + ) + + with pytest.raises(ValueError, match="did not include complete token usage"): + tracker(messages=[{"role": "user", "content": "one two three"}]) + + +def test_token_usage_augmenting_guide_requires_tracker_interface() -> None: + with pytest.raises(TypeError, match="last_usage"): + TokenUsageAugmentingGuide(_BaseGuide(), object()) + + +def test_token_usage_augmenting_guide_adds_metrics_and_resets() -> None: + tracker = UsageTrackingLLM( + lambda **kwargs: _Response( + content="ok", + usage=_Usage(prompt_tokens=5, completion_tokens=1), + ) + ) + guide = TokenUsageAugmentingGuide(_BaseGuide(), tracker) + + tracker(messages=[{"role": "user", "content": "ignored"}]) + score_dict = guide.get_score_dict("q", "r", "ref") + + assert score_dict == {"error": 0.0, "tokens_in": 5.0, "tokens_out": 1.0} + assert tracker.has_usage() is False + + +def test_token_usage_augmenting_guide_feedback_includes_metrics() -> None: + tracker = UsageTrackingLLM( + lambda **kwargs: _Response( + content="ok", + usage=_Usage(prompt_tokens=4, completion_tokens=2), + ) + ) + guide = TokenUsageAugmentingGuide(_BaseGuide(), tracker) + + tracker(messages=[{"role": "user", "content": "ignored"}]) + reward, feedback = guide.get_feedback("q", "r", "ref") + + assert reward == 1.0 + assert "base-feedback" in feedback + assert "tokens_in=4" in feedback + assert "tokens_out=2" in feedback + + +def test_token_usage_augmenting_guide_fails_when_no_usage_was_recorded() -> None: + tracker = UsageTrackingLLM(lambda **kwargs: _Response()) + guide = TokenUsageAugmentingGuide(_BaseGuide(), tracker) + + with pytest.raises(RuntimeError, match="No token usage was recorded"): + guide.get_score_dict("q", "r", "ref") + + +def test_token_usage_augmenting_guide_rejects_metric_collisions() -> None: + tracker = UsageTrackingLLM( + lambda **kwargs: _Response( + content="ok", + usage=_Usage(prompt_tokens=5, completion_tokens=1), + ) + ) + guide = TokenUsageAugmentingGuide(_TokenCollisionGuide(), tracker) + + tracker(messages=[{"role": "user", "content": "ignored"}]) + + with pytest.raises(ValueError, match="already emitted token metric keys"): + guide.get_score_dict("q", "r", "ref") + + +def test_token_usage_survives_independent_agent_and_guide_deepcopy() -> None: + tracker = UsageTrackingLLM( + lambda **kwargs: _Response( + content="ok", + usage=_Usage(prompt_tokens=11, completion_tokens=4), + ) + ) + + class _Agent: + def __init__(self, llm: UsageTrackingLLM) -> None: + self.llm = llm + + def __call__(self, x: str) -> _Response: + return self.llm(messages=[{"role": "user", "content": x}]) + + agent = _Agent(tracker) + guide = TokenUsageAugmentingGuide(_BaseGuide(), tracker) + + agent_copy = copy.deepcopy(agent) + guide_copy = copy.deepcopy(guide) + + agent_copy("question") + score_dict = guide_copy.get_score_dict("q", "r", "ref") + + assert score_dict["tokens_in"] == 11.0 + assert score_dict["tokens_out"] == 4.0 From b098b95a9e6f091d9a819680040f955b0fc2a61f Mon Sep 17 00:00:00 2001 From: doxav <> Date: Wed, 10 Jun 2026 22:10:29 +0200 Subject: [PATCH 2/2] addressed ChingAn review --- opto/trainer/guide.py | 33 ++++++++++++++++++---- opto/trainer/objectives.py | 8 +++--- tests/unit_tests/test_token_usage_guide.py | 10 ++++++- 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/opto/trainer/guide.py b/opto/trainer/guide.py index 30a9eef9..c5261dfa 100644 --- a/opto/trainer/guide.py +++ b/opto/trainer/guide.py @@ -5,6 +5,7 @@ import re import copy import os +import warnings from opto.utils.llm import LLM, AbstractModel from opto.trainer.suggest import Suggest @@ -107,10 +108,11 @@ class UsageTrackingLLM: Dict-shaped responses and ``input_tokens`` / ``output_tokens`` aliases are also supported. - If usage metadata is missing, the wrapper estimates token counts with a - simple whitespace split by default. Set ``estimate_missing=False`` to fail - fast instead. Streaming usage is only captured when the provider returns - usage on the response object passed back from the wrapped LLM. + If usage metadata is missing, the wrapper warns once and estimates token + counts with a simple whitespace split by default. Set + ``estimate_missing=False`` to fail fast instead. Streaming usage is only + captured when the provider returns usage on the response object passed back + from the wrapped LLM. The usage store is ``ContextVar``-backed, so concurrent evaluations do not overwrite each other. ``__deepcopy__`` returns ``self`` intentionally: @@ -129,6 +131,7 @@ def __init__(self, base_llm: Any, *, estimate_missing: bool = True) -> None: self._usage_estimated: contextvars.ContextVar[bool] = contextvars.ContextVar( "llm_token_usage_estimated", default=False ) + self._warned_missing_usage = False def __deepcopy__(self, memo: Dict[int, Any]) -> "UsageTrackingLLM": memo[id(self)] = self @@ -196,6 +199,9 @@ def __setstate__(self, state: Dict[str, Any]) -> None: self._usage_estimated = contextvars.ContextVar( "llm_token_usage_estimated", default=False ) + self._warned_missing_usage = bool( + getattr(self, "_warned_missing_usage", False) + ) def _add_usage(self, *, tokens_in: int, tokens_out: int) -> None: previous = self._usage.get() or {"tokens_in": 0, "tokens_out": 0} @@ -224,6 +230,7 @@ def _read_token_usage( "response.usage." ) estimated = True + self._warn_missing_usage_once() if prompt_tokens is None: prompt_tokens = self._estimate_prompt_tokens(messages) if completion_tokens is None: @@ -231,6 +238,21 @@ def _read_token_usage( return int(prompt_tokens or 0), int(completion_tokens or 0), estimated + def _warn_missing_usage_once(self) -> None: + """Warn once when the wrapped backend cannot provide exact token usage.""" + if self._warned_missing_usage: + return + warnings.warn( + "The wrapped LLM backend returned a response without complete token " + "usage; UsageTrackingLLM is estimating token counts with whitespace " + "splitting. For exact token-aware objectives, use a backend that " + "returns OpenAI-compatible response.usage or set " + "estimate_missing=False to fail fast.", + UserWarning, + stacklevel=3, + ) + self._warned_missing_usage = True + @classmethod def _extract_usage(cls, response: Any) -> Tuple[Optional[int], Optional[int]]: usage = cls._get_field(response, "usage") @@ -322,7 +344,8 @@ class TokenUsageAugmentingGuide(Guide): The tracker must expose ``last_usage()``. ``UsageTrackingLLM`` is the default implementation, but custom LLM wrappers can be used if they follow the same - duck-typed interface. + duck-typed interface. This stays composition-based so existing guides can be + augmented without inheritance changes. """ def __init__( diff --git a/opto/trainer/objectives.py b/opto/trainer/objectives.py index ad2a601e..8fa87e85 100644 --- a/opto/trainer/objectives.py +++ b/opto/trainer/objectives.py @@ -284,10 +284,10 @@ def pareto_rank(candidates: List[Dict[str, float]], return ranks -def _prepare_score_dicts( +def _prepare_score_dicts_for_objective_selection( candidates: List[Tuple[ScoreLike, Any]], config: ObjectiveConfig ) -> List[Dict[str, float]]: - """Convert, validate, and higher-is-better-normalize candidate scores.""" + """Prepare candidate scores so objective selection can maximize all metrics.""" score_dicts: List[Dict[str, float]] = [] for score, _ in candidates: if config.required_metrics and not isinstance(score, dict): @@ -322,7 +322,7 @@ def select_best(candidates: List[Tuple[ScoreLike, Any]], scores = [to_scalar_score(score, config) for score, _ in candidates] return int(np.argmax(scores)) - score_dicts = _prepare_score_dicts(candidates, config) + score_dicts = _prepare_score_dicts_for_objective_selection(candidates, config) if config.mode == "weighted": weighted = [ @@ -383,7 +383,7 @@ def select_top_k(candidates: List[Tuple[ScoreLike, Any]], scores = [to_scalar_score(score, config) for score, _ in candidates] return list(np.argsort(scores)[::-1][:k]) - score_dicts = _prepare_score_dicts(candidates, config) + score_dicts = _prepare_score_dicts_for_objective_selection(candidates, config) if config.mode == "weighted": weighted = [ diff --git a/tests/unit_tests/test_token_usage_guide.py b/tests/unit_tests/test_token_usage_guide.py index 07d1f2d5..91e9d536 100644 --- a/tests/unit_tests/test_token_usage_guide.py +++ b/tests/unit_tests/test_token_usage_guide.py @@ -1,4 +1,5 @@ import copy +import warnings from typing import Any, Dict, Optional, Tuple import pytest @@ -100,11 +101,18 @@ def test_usage_tracking_estimates_missing_usage_when_allowed() -> None: estimate_missing=True, ) - tracker(messages=[{"role": "user", "content": "one two three"}]) + with pytest.warns(UserWarning, match="estimating token counts"): + tracker(messages=[{"role": "user", "content": "one two three"}]) assert tracker.last_usage() == {"tokens_in": 3, "tokens_out": 2} assert tracker.last_usage_was_estimated() is True + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always") + tracker(messages=[{"role": "user", "content": "one two three"}]) + + assert captured == [] + def test_usage_tracking_strict_mode_rejects_missing_usage() -> None: tracker = UsageTrackingLLM(