POLCA + Trace trainer #16

Open
allenanie wants to merge 9 commits into main from dspy-trainer

@allenanie

Member

Xuanfei's new change to make POLCA runnable

all_tasks = _load_hf_data(cfg, train_split, n_load)
train_tasks = all_tasks[:num_train]
val_tasks = all_tasks[:num_validate]
test_tasks = all_tasks[:num_test] if num_test > 0 else []

Member Author

So num_train, num_validate, and num_test all overlap each other, since every slice starts at index 0? I see.

Collaborator

yes
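
To make the overlap concrete, here is a minimal sketch that uses a plain list of integers as a stand-in for the loaded tasks. The sizes are illustrative assumptions, not values from this PR. Because every slice starts at index 0, the validation and test sets end up as prefixes of the training set:

```python
# Stand-in for the list returned by _load_hf_data; the sizes are made up.
all_tasks = list(range(10))
num_train, num_validate, num_test = 5, 3, 2

# Same slicing pattern as the snippet under review: every slice starts at 0.
train_tasks = all_tasks[:num_train]
val_tasks = all_tasks[:num_validate]
test_tasks = all_tasks[:num_test] if num_test > 0 else []

# Validation and test examples are all contained in the training set.
val_in_train = set(val_tasks) <= set(train_tasks)
test_in_train = set(test_tasks) <= set(train_tasks)
print(val_in_train, test_in_train)  # True True
```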

Collaborator

I suggest this generic fix:

diff --git a/benchmarks/hf_qa/hf_qa_loader.py b/benchmarks/hf_qa/hf_qa_loader.py
index be2f57e..d8f6f0a 100644
--- a/benchmarks/hf_qa/hf_qa_loader.py
+++ b/benchmarks/hf_qa/hf_qa_loader.py
@@ -41,7 +41,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import yaml
 
@@ -186,6 +186,52 @@ def _load_hf_data(cfg: Dict[str, Any], split: str, n: int) -> List[HFQATask]:
     return tasks
 
 
+def _resolve_count(
+    name: str,
+    override: Optional[int],
+    cfg: Dict[str, Any],
+    default: int,
+) -> int:
+    """Resolve a dataset partition size and validate it."""
+    value = override if override is not None else cfg.get(name, default)
+    if not isinstance(value, int) or isinstance(value, bool):
+        raise TypeError(f"{name} must be a non-negative integer, got {value!r}")
+    if value < 0:
+        raise ValueError(f"{name} must be non-negative, got {value}")
+    return value
+
+
+def _load_hf_partitions(
+    cfg: Dict[str, Any],
+    specs: List[Tuple[str, str, int]],
+) -> Dict[str, List[HFQATask]]:
+    """Load named HF dataset partitions without index overlap.
+
+    ``specs`` is a list of ``(partition_name, hf_split, count)`` tuples.
+
+    Partitions that use the same HF split are allocated contiguous, disjoint
+    slices in the order they appear in ``specs``.  Partitions that use different
+    HF splits each start from the beginning of their own split.
+
+    Example:
+        [
+            ("train", "train", 100),
+            ("validate", "train", 100),
+            ("test", "train", 100),
+        ]
+
+    loads 300 examples from HF split ``train`` once, then slices them as
+    ``[0:100]``, ``[100:200]``, and ``[200:300]``.
+    """
+    partitions: Dict[str, List[HFQATask]] = {name: [] for name, _, _ in specs}
+    by_split: Dict[str, List[Tuple[str, int]]] = {}
+
+    for name, split, count in specs:
+        if not isinstance(count, int) or isinstance(count, bool):
+            raise TypeError(f"Partition {name!r} count must be an integer, got {count!r}")
+        if count < 0:
+            raise ValueError(f"Partition {name!r} count must be non-negative, got {count}")
+        by_split.setdefault(split, []).append((name, count))
+
+    for split, split_specs in by_split.items():
+        total = sum(count for _, count in split_specs)
+        loaded = _load_hf_data(cfg, split, total) if total > 0 else []
+        start = 0
+        for name, count in split_specs:
+            partitions[name] = loaded[start:start + count]
+            start += count
+
+    return partitions
+
+
 # ---------------------------------------------------------------------------
 # Internal helpers
 # ---------------------------------------------------------------------------
@@ -249,29 +295,43 @@ def build_trace_problem(
     cfg = _load_task_config(task_id)
 
     # Resolve dataset sizes: eval_kwargs > hf_tasks.yaml > hardcoded fallback.
-    num_train    = num_train    if num_train    is not None else cfg.get("num_train",    10)
-    num_validate = num_validate if num_validate is not None else cfg.get("num_validate",  0)
-    num_test     = num_test     if num_test     is not None else cfg.get("num_test",      0)
+    num_train = _resolve_count("num_train", num_train, cfg, 10)
+    num_validate = _resolve_count("num_validate", num_validate, cfg, 0)
+    num_test = _resolve_count("num_test", num_test, cfg, 0)
 
     if subtask:
         # BBEH-style multi-split tasks: every subset (train/val/test) comes from
         # the same split, which is the subtask name (e.g. "boolean_expressions").
         split = subtask
-        total = num_train + num_validate + num_test
-        all_tasks = _load_hf_data(cfg, split, total)
-        train_tasks = all_tasks[:num_train]
-        val_tasks   = all_tasks[num_train:num_train + num_validate]
-        test_tasks  = all_tasks[num_train + num_validate:]
+        train_split_name = split
+        validate_split_name = split
+        test_split_name = split
+        partitions = _load_hf_partitions(
+            cfg,
+            [
+                ("train", split, num_train),
+                ("validate", split, num_validate),
+                ("test", split, num_test),
+            ],
+        )
     else:
         # Tasks with explicit train/test HF splits (e.g. HotpotQA):
         #   train + val  → loaded from `train_split`  (default: "train")
         #   test         → loaded from `test_split`   (default: same as train_split)
+        # If train/val/test share the same HF split, they are disjoint
+        # contiguous slices, not independent "first n examples" loads.
         # Falls back to the legacy `split` field if neither is set.
         fallback = cfg.get("split", "train")
         train_split = cfg.get("train_split", fallback)
         test_split  = cfg.get("test_split",  train_split)
+        train_split_name = train_split
+        validate_split_name = train_split
+        test_split_name = test_split
+        partitions = _load_hf_partitions(
+            cfg,
+            [
+                ("train", train_split, num_train),
+                ("validate", train_split, num_validate),
+                ("test", test_split, num_test),
+            ],
+        )
 
-        train_val_tasks = _load_hf_data(cfg, train_split, num_train + num_validate)
-        train_tasks = train_val_tasks[:num_train]
-        val_tasks   = train_val_tasks[num_train:]
-        test_tasks  = _load_hf_data(cfg, test_split, num_test) if num_test > 0 else []
+    train_tasks = partitions["train"]
+    val_tasks = partitions["validate"]
+    test_tasks = partitions["test"]
 
     objective = cfg.get("objective", "Optimize the agent's instructions for this QA task.")
     agent_class = cfg.get("agent_class", "hfqa")
@@ -307,6 +367,9 @@ def build_trace_problem(
             num_train=num_train,
             num_validate=num_validate,
             num_test=num_test,
+            train_split=train_split_name,
+            validate_split=validate_split_name,
+            test_split=test_split_name,
             framework=framework,
         ),
     )
diff --git a/benchmarks/hf_qa/hf_tasks.yaml b/benchmarks/hf_qa/hf_tasks.yaml
index e8c549e..2d5fae1 100644
--- a/benchmarks/hf_qa/hf_tasks.yaml
+++ b/benchmarks/hf_qa/hf_tasks.yaml
@@ -22,8 +22,9 @@
 # Optional fields:
 #   dataset_config  : HuggingFace config name (e.g. "distractor", "plain_text")
 #   train_split     : HF split used for train + val examples (default: "train")
-#   test_split      : HF split used for test examples (default: same as train_split)
-#   split           : fallback split when train_split/test_split are not set
+#   test_split      : HF split used for test examples (default: same as train_split;
+#                     when the same split is reused, partitions are disjoint slices)
+#   split           : fallback split when train_split/test_split are not set
 #   description     : short human-readable description
 #   objective       : optimizer objective injected into optimizer_kwargs
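
As a rough illustration of what the proposed partitioning does, the sketch below copies the grouping and slicing logic into a standalone function. `fake_load` and `load_partitions` are hypothetical names introduced here; `fake_load` replaces `_load_hf_data` so the example runs without a dataset, and the count validation from the patch is omitted for brevity.

```python
from typing import Dict, List, Tuple


def fake_load(split: str, n: int) -> List[str]:
    """Hypothetical stand-in for _load_hf_data: returns n labelled items."""
    return [f"{split}[{i}]" for i in range(n)]


def load_partitions(specs: List[Tuple[str, str, int]]) -> Dict[str, List[str]]:
    """Simplified copy of the proposed _load_hf_partitions logic."""
    partitions: Dict[str, List[str]] = {name: [] for name, _, _ in specs}
    by_split: Dict[str, List[Tuple[str, int]]] = {}
    # Group partitions by the HF split they draw from, preserving order.
    for name, split, count in specs:
        by_split.setdefault(split, []).append((name, count))

    # Load each split once, then hand out contiguous, disjoint slices.
    for split, split_specs in by_split.items():
        total = sum(count for _, count in split_specs)
        loaded = fake_load(split, total) if total > 0 else []
        start = 0
        for name, count in split_specs:
            partitions[name] = loaded[start:start + count]
            start += count
    return partitions


# train and validate share the "train" split; test uses "validation".
parts = load_partitions([
    ("train", "train", 3),
    ("validate", "train", 2),
    ("test", "validation", 2),
])
print(parts["train"])     # ['train[0]', 'train[1]', 'train[2]']
print(parts["validate"])  # ['train[3]', 'train[4]']
print(parts["test"])      # ['validation[0]', 'validation[1]']
```

Partitions that share a split never overlap, while a partition on a different split starts again from index 0 of that split.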

Collaborator

@allenanie @xuanfeiren it would be great if we could fix this and then merge. Once it is merged, I would like to import some more HF QA datasets.

Member Author

Wait, what exactly is this fixing?

For the POLCA setup, we actually need train/val/test to all be the same examples!

So I guess hf_tasks needs the flexibility to do both (separate train/val/test and overlapping train/val/test).

@doxav, Jun 11, 2026

Collaborator

OMG :-), that looks strange. I should look at how POLCA works.

I checked the latest main, and the current default HF loader does not make train/val/test identical: train and val are sequential slices, and test is loaded from test_split. For HF QA, main uses train_split: train and test_split: validation.

So I think the right fix is to support both modes explicitly:

- default / benchmark mode: keep the current main behavior, with disjoint train/val and a separate test split;
- POLCA mode: opt into same_examples / shared_split, where train/val/test are intentionally the same examples.

Should we add a config flag for this instead of hardcoding the POLCA behavior globally? That should satisfy POLCA while avoiding surprising data leakage in normal HF QA benchmark runs.
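
One possible shape for such a flag is sketched below. The flag name `same_examples` comes from the suggestion above; the function `split_tasks` and its exact semantics are assumptions made for illustration, and for simplicity all three partitions are drawn from a single already-loaded list rather than from separate HF splits.

```python
from typing import Dict, List, Sequence, TypeVar

T = TypeVar("T")


def split_tasks(
    tasks: Sequence[T],
    num_train: int,
    num_validate: int,
    num_test: int,
    same_examples: bool = False,
) -> Dict[str, List[T]]:
    """Partition already-loaded tasks in one of two modes (hypothetical helper).

    same_examples=False: disjoint, contiguous slices (benchmark mode).
    same_examples=True: every partition starts at index 0 (POLCA mode).
    """
    counts = {"train": num_train, "validate": num_validate, "test": num_test}
    parts: Dict[str, List[T]] = {}
    start = 0
    for name, count in counts.items():
        if same_examples:
            # Intentional overlap: each partition is a prefix of the data.
            parts[name] = list(tasks[:count])
        else:
            # No overlap: advance the offset after each partition.
            parts[name] = list(tasks[start:start + count])
            start += count
    return parts


tasks = list(range(6))
benchmark = split_tasks(tasks, 2, 2, 2)
polca = split_tasks(tasks, 2, 2, 2, same_examples=True)
print(benchmark)  # {'train': [0, 1], 'validate': [2, 3], 'test': [4, 5]}
print(polca)      # {'train': [0, 1], 'validate': [0, 1], 'test': [0, 1]}
```

Defaulting the flag to False keeps the leak-free behavior for ordinary benchmark runs, so the overlap has to be requested explicitly.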

Member Author

Yeah -- I think that's probably the best solution!
