From 331783a8a36a8a9adb44c9a935f6c5020e518b26 Mon Sep 17 00:00:00 2001 From: "agent-kurodo[bot]" <268466204+agent-kurodo[bot]@users.noreply.github.com> Date: Sat, 27 Jun 2026 00:24:41 +0000 Subject: [PATCH 1/6] docs(itl-432): add design spec for pydantic/dataclass pipeline column fixes Co-Authored-By: Claude Sonnet 4.6 --- .zed/rules | 3 +- CLAUDE.md | 2 +- ...antic-dataclass-pipeline-columns-design.md | 215 ++++++++++++++++++ 3 files changed, 217 insertions(+), 3 deletions(-) create mode 100644 superpowers/specs/2026-06-27-itl-432-pydantic-dataclass-pipeline-columns-design.md diff --git a/.zed/rules b/.zed/rules index b12d03e2..06b3eb32 100644 --- a/.zed/rules +++ b/.zed/rules @@ -101,8 +101,7 @@ Remove any optional sections that don't apply rather than leaving them empty. When working on a feature, create and checkout a git branch using the gitBranchName returned by the primary Linear issue (e.g. eywalker/plt-911-add-documentation-for-orcapod-python). -Feature branch PRs always target the "dev" branch. The dev → main PR is used -for versioning/releases only. +Feature branch PRs always target "main". Create a feature branch from "main" and open PRs against "main". If a feature branch / PR corresponds to multiple Linear issues, list all of them in the PR description body so that Linear's GitHub integration auto-tracks the PR against each diff --git a/CLAUDE.md b/CLAUDE.md index 0ec257b8..624086de 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -107,7 +107,7 @@ Remove any optional sections that don't apply rather than leaving them empty. When working on a feature, create and checkout a git branch using the `gitBranchName` returned by the primary Linear issue (e.g. `eywalker/plt-911-add-documentation-for-orcapod-python`). -**Feature branch PRs always target the `extension-type-system` branch.** The `extension-type-system` → `dev` → `main` PRs are used for integration and releases. 
+**Feature branch PRs always target `main`.** Create a feature branch from `main` and open PRs against `main`. If a feature branch / PR corresponds to multiple Linear issues, list all of them in the PR description body so that Linear's GitHub integration auto-tracks the PR against each diff --git a/superpowers/specs/2026-06-27-itl-432-pydantic-dataclass-pipeline-columns-design.md b/superpowers/specs/2026-06-27-itl-432-pydantic-dataclass-pipeline-columns-design.md new file mode 100644 index 00000000..6037e527 --- /dev/null +++ b/superpowers/specs/2026-06-27-itl-432-pydantic-dataclass-pipeline-columns-design.md @@ -0,0 +1,215 @@ +# ITL-432: Pydantic/Dataclass Models as Pipeline Columns + +**Issue:** ITL-432 +**Date:** 2026-06-27 + +## Overview + +Pydantic models and dataclasses cannot flow through Orcapod pipelines as column +types. Two independent defects in the Arrow/Polars extension type machinery are +responsible. Both bugs are self-contained and addressed by surgical changes to +the hashing handler registry and Polars extension type construction. + +--- + +## Bug A — Extension type reaching `ArrowDigester` + +### Symptom + +``` +TypeError: unhashable type: '_ArrowExt___main___Cfg' +``` + +### Root cause + +No semantic handlers are registered for `pydantic.BaseModel` subclasses or +dataclasses. In `SemanticHashingVisitor.visit_extension()`, when +`type_handler_registry.has_handler(python_type)` returns False, the visitor +returns the extension type and its storage value unchanged (the passthrough +path). That extension-typed column flows directly into `ArrowDigester.hash_table`. + +`ArrowDigester` is a pure-Python starfix implementation. Its +`_primitive_data_type_string` function builds a lookup dict with Arrow +primitive type instances as keys (`if dt in _simple:`). Performing a dict +membership test on an unhashable extension type raises `TypeError`. 
Even if +`__hash__` were added to the extension type, `ArrowDigester` would raise +`NotImplementedError` because it has no handler for extension types — the +only correct fix is ensuring extension types never reach it. + +### Fix: register semantic handlers for pydantic models and dataclasses + +**`PythonTypeHandlerProtocol` — add optional `supports_type`** + +Add an optional `supports_type(target_type: type) -> bool` method, +mirroring the `supports_class` pattern already used by +`LogicalTypeFactoryProtocol`. When defined, the handler registry calls it +after finding a handler via MRO walk; if it returns False the walk continues. +Handlers without `supports_type` are treated as unconditional matches +(existing behaviour unchanged). + +**`PythonTypeHandlerRegistry` — respect `supports_type` in MRO walk** + +Update `get_handler_for_type` to apply `supports_type` at every lookup point — +both the initial exact-match check and the MRO walk: + +```python +def _try_handler(handler, target_type): + """Return handler if it accepts target_type, else None.""" + if handler is None: + return None + if hasattr(handler, "supports_type") and not handler.supports_type(target_type): + return None + return handler + +# exact match +handler = _try_handler(self._handlers.get(target_type), target_type) +if handler is not None: + return handler +# MRO walk +for base in target_type.__mro__[1:]: + handler = _try_handler(self._handlers.get(base), target_type) + if handler is not None: + return handler +return None +``` + +`get_handler(obj)` delegates to `get_handler_for_type(type(obj))` and +`has_handler(target_type)` delegates to `get_handler_for_type`, so both +inherit the fix automatically. 
+ +**New handlers in `builtin_handlers.py`** + +```python +class PydanticModelHandler: + """Handler for pydantic BaseModel instances — delegates to model_dump().""" + + def handle(self, obj: Any, hasher: SemanticHasherProtocol) -> Any: + return obj.model_dump() + + +class DataclassModelHandler: + """Handler for dataclass instances — delegates to dataclasses.asdict().""" + + def supports_type(self, target_type: type) -> bool: + import dataclasses + return dataclasses.is_dataclass(target_type) + + def handle(self, obj: Any, hasher: SemanticHasherProtocol) -> Any: + import dataclasses + return dataclasses.asdict(obj) +``` + +`model_dump()` and `dataclasses.asdict()` both return plain dicts that +accurately reflect the model's content. The recursive semantic hasher hashes +the returned dict, producing a stable content-based hash. + +**Registration in `register_builtin_python_type_handlers`** + +```python +from pydantic import BaseModel +registry.register(BaseModel, PydanticModelHandler()) + +import dataclasses as _dc +registry.register(object, DataclassModelHandler()) +``` + +`PydanticModelHandler` is registered against `pydantic.BaseModel` — MRO +lookup finds it for any subclass. `DataclassModelHandler` is registered +against `object` (matching the pattern used by `DataclassLogicalTypeFactory` +in `v0.1.json`) and gated by `supports_type`, which returns True only for +actual dataclass types. + +--- + +## Bug B — Metadata loss on Polars round-trip + +### Symptom + +``` +ValueError: Arrow extension type '__main__.Cfg': expected metadata +b'{"category": "orcapod.pydantic"}' but got b'' +``` + +### Root cause + +`PydanticLogicalType.__init__` (line 93) and `DataclassLogicalType.__init__` +(line 93) both call `make_polars_extension_type(logical_name, storage_type)` +without passing `metadata`. The Arrow extension type is built with category +metadata (`b'{"category": "orcapod.pydantic"}'`), but the Polars extension +type carries no metadata. 
+ +When `pl.DataFrame(table).to_arrow()` reconstructs the Arrow column, Polars +calls `__arrow_ext_deserialize__` with the Polars extension's metadata — which +is empty bytes (`b''`). The strict equality check in `_deserialize` fails +because `b'' != b'{"category": "orcapod.pydantic"}'`. + +### Fix: pass category metadata to `make_polars_extension_type` + +In `PydanticLogicalType.__init__`: + +```python +# Before: +self._polars_ext_class = make_polars_extension_type(logical_name, storage_type) + +# After: +self._polars_ext_class = make_polars_extension_type( + logical_name, + storage_type, + metadata=json.dumps({"category": PYDANTIC_CATEGORY}), +) +``` + +Same change in `DataclassLogicalType.__init__`, using `DATACLASS_CATEGORY`. + +After this fix, `pl.DataFrame(table).to_arrow()` passes +`b'{"category": "orcapod.pydantic"}'` to `__arrow_ext_deserialize__`, which +matches `_metadata` and succeeds. + +--- + +## Files changed + +| File | Change | +|------|--------| +| `src/orcapod/protocols/hashing_protocols.py` | Add optional `supports_type` to `PythonTypeHandlerProtocol` docstring and protocol stub | +| `src/orcapod/hashing/semantic_hashing/type_handler_registry.py` | Update `get_handler_for_type` to call `supports_type` when present; `get_handler` inherits the fix via delegation | +| `src/orcapod/hashing/semantic_hashing/builtin_handlers.py` | Add `PydanticModelHandler`, `DataclassModelHandler`; register both in `register_builtin_python_type_handlers` | +| `src/orcapod/extension_types/pydantic_logical_type_factory.py` | Pass `metadata` to `make_polars_extension_type()` in `PydanticLogicalType.__init__` | +| `src/orcapod/extension_types/dataclass_logical_type_factory.py` | Pass `metadata` to `make_polars_extension_type()` in `DataclassLogicalType.__init__` | +| `tests/test_hashing/test_pydantic_dataclass_hashing.py` | New regression tests (see below) | + +--- + +## Tests + +New file: `tests/test_hashing/test_pydantic_dataclass_hashing.py` + +**Bug A regression — 
pydantic:** +Build a table with a pydantic model column (registered via the default +context), call `arrow_hasher.hash_table(table)`. Assert no `TypeError` is +raised and a `ContentHash` is returned. + +**Bug A regression — dataclass:** +Same as above with a dataclass column. + +**Bug B regression — pydantic Polars round-trip:** +Build a table with a pydantic model column, round-trip via +`pl.DataFrame(table).to_arrow()`, call `arrow_hasher.hash_table(round_tripped)`. +Assert no `ValueError` is raised and the hash equals that of the original table. + +**Bug B regression — dataclass Polars round-trip:** +Same as above with a dataclass column. + +**Handler unit tests:** +- `PydanticModelHandler.handle` returns `model.model_dump()` for a flat model +- `DataclassModelHandler.handle` returns `dataclasses.asdict(obj)` for a flat dataclass +- `DataclassModelHandler.supports_type` returns True for dataclasses, False for pydantic models and plain classes +- Registry MRO walk respects `supports_type`: registering `DataclassModelHandler` against `object` does not intercept non-dataclass lookups + +--- + +## Out of scope + +- Adding `__hash__` to synthesized extension types (tracked as a follow-up) +- Schema cleaner changes (no longer needed — the Polars metadata fix resolves the underlying cause) +- Deserialization relaxation (no backward-compatibility shims; greenfield project) From 284cce49be85e60def07a277ef75eaa769ad25b5 Mon Sep 17 00:00:00 2001 From: "agent-kurodo[bot]" <268466204+agent-kurodo[bot]@users.noreply.github.com> Date: Sat, 27 Jun 2026 01:09:58 +0000 Subject: [PATCH 2/6] fix(hashing): pydantic/dataclass columns now flow through pipelines as Arrow extension types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug A: StarfixArrowHasher._process_table_columns left live pa.ExtensionType columns intact after the SemanticHashingVisitor passthrough. 
ArrowDigester crashed with TypeError (unhashable type) because extension types are used as dict keys in starfix's _primitive_data_type_string. Fix: add normalize_extension_columns() to arrow_utils.py — converts any top-level extension-typed column to IPC storage representation (storage type + ARROW:extension:* field metadata) using ExtensionArray.storage with no Python materialization. _process_table_columns calls it after the visitor loop so ArrowDigester never sees a live pa.ExtensionType. Bug B: PydanticLogicalType and DataclassLogicalType both called make_polars_extension_type() without the metadata= argument. The Polars extension type carried empty metadata, so pl.DataFrame(table).to_arrow() passed b'' to __arrow_ext_deserialize__, which expected the category bytes and raised ValueError. Fix: pass metadata=json.dumps({"category": ...}) to make_polars_extension_type() in both __init__ methods so the Polars type carries the same metadata string as the Arrow extension type. Also logs the efficiency follow-up (ITL-433) in DESIGN_ISSUES.md: the to_pylist() roundtrip is wasteful for passthrough extension columns and should be short-circuited in a future refactor. Closes ITL-432 Co-Authored-By: Claude Sonnet 4.6 --- DESIGN_ISSUES.md | 21 ++ .../dataclass_logical_type_factory.py | 6 +- .../pydantic_logical_type_factory.py | 6 +- src/orcapod/hashing/arrow_hashers.py | 21 +- src/orcapod/utils/arrow_utils.py | 69 +++++++ ...antic-dataclass-pipeline-columns-design.md | 195 ++++++++---------- .../test_pydantic_dataclass_hashing.py | 182 ++++++++++++++++ 7 files changed, 388 insertions(+), 112 deletions(-) create mode 100644 tests/test_hashing/test_pydantic_dataclass_hashing.py diff --git a/DESIGN_ISSUES.md b/DESIGN_ISSUES.md index 8c8572c9..841b2c5d 100644 --- a/DESIGN_ISSUES.md +++ b/DESIGN_ISSUES.md @@ -853,6 +853,27 @@ Should be removed in the next breaking release. 
Consider adding deprecation warn --- +### H5 — `_process_table_columns` materializes extension columns to Python even with no handler +**Status:** open +**Severity:** medium +**Issue:** ITL-433 + +`StarfixArrowHasher._process_table_columns` calls `to_pylist()` on every extension-typed +column before running `SemanticHashingVisitor`. For columns whose Python type has no registered +semantic handler (pydantic models, dataclasses, any unhandled extension type), the visitor +returns the value unchanged and the data is immediately re-serialized back to Arrow. The Python +roundtrip serves no purpose in this case and is O(rows) deserialization work wasted. + +The fix is to short-circuit at the column level: before calling `to_pylist()`, check whether +`type_handler_registry.has_handler(python_type)` would be True for this column's extension +type. If not, call `normalize_extension_columns()` directly on the Arrow column (uses +`ExtensionArray.storage`, no Python materialization) and skip the visitor loop entirely. +Columns with a registered handler (e.g. `Path`) continue through the existing path unchanged. + +The `normalize_extension_columns` utility landed in ITL-432. + +--- + ## `src/orcapod/utils/` ### U1 — Source-info column type hard-coded to `large_string` diff --git a/src/orcapod/extension_types/dataclass_logical_type_factory.py b/src/orcapod/extension_types/dataclass_logical_type_factory.py index 5633ffd7..6964d882 100644 --- a/src/orcapod/extension_types/dataclass_logical_type_factory.py +++ b/src/orcapod/extension_types/dataclass_logical_type_factory.py @@ -90,7 +90,11 @@ def __init__( # top-level extension type from each field's Arrow type before inserting it into the # struct. On the read path, ``reconstruct_from_arrow`` receives a ``storage_type`` # already guaranteed storage-safe by ``register_storage_type``. 
- self._polars_ext_class = make_polars_extension_type(logical_name, storage_type) + self._polars_ext_class = make_polars_extension_type( + logical_name, + storage_type, + metadata=json.dumps({"category": DATACLASS_CATEGORY}), + ) self._polars_ext: pl.BaseExtension | None = None @property diff --git a/src/orcapod/extension_types/pydantic_logical_type_factory.py b/src/orcapod/extension_types/pydantic_logical_type_factory.py index e40baf4f..63261587 100644 --- a/src/orcapod/extension_types/pydantic_logical_type_factory.py +++ b/src/orcapod/extension_types/pydantic_logical_type_factory.py @@ -90,7 +90,11 @@ def __init__( # top-level extension type from each field's Arrow type before inserting it into the # struct. On the read path, ``reconstruct_from_arrow`` receives a ``storage_type`` # already guaranteed storage-safe by ``register_storage_type``. - self._polars_ext_class = make_polars_extension_type(logical_name, storage_type) + self._polars_ext_class = make_polars_extension_type( + logical_name, + storage_type, + metadata=json.dumps({"category": PYDANTIC_CATEGORY}), + ) self._polars_ext: pl.BaseExtension | None = None @property diff --git a/src/orcapod/hashing/arrow_hashers.py b/src/orcapod/hashing/arrow_hashers.py index d5ce6a7c..876e3f5a 100644 --- a/src/orcapod/hashing/arrow_hashers.py +++ b/src/orcapod/hashing/arrow_hashers.py @@ -7,6 +7,7 @@ from orcapod.hashing.schema_cleaner import clean_schema_for_hashing, has_extension_metadata from orcapod.hashing.visitors import SemanticHashingVisitor +from orcapod.utils.arrow_utils import normalize_extension_columns from orcapod.types import ContentHash if TYPE_CHECKING: @@ -57,7 +58,16 @@ def hasher_id(self) -> str: return self._hasher_id def _process_table_columns(self, table: "pa.Table | pa.RecordBatch") -> "pa.Table": - """Replace semantic-typed columns with their content-hash bytes.""" + """Replace semantic-typed columns with content-hash bytes; normalize extension columns. 
+ + For columns whose Python type has a registered semantic handler (e.g. ``Path``), + the extension-typed column is replaced by a ``pa.large_binary()`` column of + content-hash tokens. For all other extension-typed columns (visitor passthrough), + the column is normalized to IPC storage representation via + ``normalize_extension_columns`` — storage type for the data, extension identity + in field metadata — so that ``ArrowDigester`` can hash them without encountering + a live ``pa.ExtensionType``, which is unhashable. + """ new_columns: list[pa.Array] = [] new_fields: list[pa.Field] = [] @@ -91,6 +101,7 @@ def _process_table_columns(self, table: "pa.Table | pa.RecordBatch") -> "pa.Tabl if new_type is None: new_type = field.type + new_columns.append(pa.array(processed_data, type=new_type)) new_fields.append(field.with_type(new_type)) @@ -99,10 +110,16 @@ def _process_table_columns(self, table: "pa.Table | pa.RecordBatch") -> "pa.Tabl f"Failed to process column '{field.name}': {exc}" ) from exc - return pa.table( + intermediate = pa.table( new_columns, schema=pa.schema(new_fields, metadata=table.schema.metadata), ) + # Normalize any remaining extension-typed columns to their IPC storage + # representation (storage type + ARROW:extension:* field metadata). + # This handles the visitor passthrough case — extension types with no + # registered semantic handler — so that ArrowDigester never receives a + # live pa.ExtensionType, which is unhashable and would crash starfix. 
+ return normalize_extension_columns(intermediate) def hash_schema(self, schema: "pa.Schema") -> ContentHash: """Hash an Arrow schema using the starfix canonical algorithm.""" diff --git a/src/orcapod/utils/arrow_utils.py b/src/orcapod/utils/arrow_utils.py index df7ba4c4..fc87c373 100644 --- a/src/orcapod/utils/arrow_utils.py +++ b/src/orcapod/utils/arrow_utils.py @@ -285,6 +285,75 @@ def normalize_view_types(arrow_type: "pa.DataType") -> "pa.DataType": return arrow_type +def normalize_extension_columns(table: "pa.Table") -> "pa.Table": + """Return a copy of ``table`` with all extension-typed columns converted to + their IPC/Parquet storage representation. + + For each top-level column whose type is a ``pa.ExtensionType``, the column + data is replaced with the underlying storage array (via + ``ExtensionArray.storage`` — no Python-level materialization) and the field + gains ``ARROW:extension:name`` and ``ARROW:extension:metadata`` keys in its + metadata, exactly matching the on-disk Arrow IPC/Parquet format. + + Non-extension columns are returned unchanged. Schema-level metadata and + existing per-field metadata are preserved; the two ``ARROW:extension:*`` + keys are merged in (or added) without touching any other metadata already + on the field. + + This is a fast path: for tables with no extension columns the original + table object is returned immediately. For tables that do have extension + columns a new table is constructed; ``ExtensionArray.storage`` is itself a + zero-copy view of the underlying buffers, but ``combine_chunks()`` may copy + them when a column has more than one chunk. + + Note: only **top-level** extension columns are handled. Extension types + nested inside struct fields or list element types are not supported by the + orcapod type system (see ET1 in DESIGN_ISSUES.md) and are left unchanged. + + Args: + table: Input Arrow table, may contain extension-typed columns. 
+ + Returns: + A ``pa.Table`` where every top-level extension-typed column has been + replaced by its storage-typed equivalent with extension identity + preserved in field metadata. + """ + if not any(isinstance(field.type, pa.ExtensionType) for field in table.schema): + return table + + new_columns = [] + new_fields = [] + for i, field in enumerate(table.schema): + if isinstance(field.type, pa.ExtensionType): + ext_type = field.type + # combine_chunks() returns a single ExtensionArray; .storage is a + # zero-copy view of the underlying buffers with the storage type. + storage_arr = table.column(i).combine_chunks().storage + serialized = ext_type.__arrow_ext_serialize__() + # Merge extension identity into existing field metadata (if any) + # so that non-extension keys already on the field are preserved. + existing_meta = dict(field.metadata) if field.metadata else {} + existing_meta[b"ARROW:extension:name"] = ( + ext_type.extension_name.encode("utf-8") + ) + existing_meta[b"ARROW:extension:metadata"] = serialized + new_fields.append(pa.field( + field.name, + ext_type.storage_type, + nullable=field.nullable, + metadata=existing_meta, + )) + new_columns.append(storage_arr) + else: + new_columns.append(table.column(i)) + new_fields.append(field) + + return pa.table( + new_columns, + schema=pa.schema(new_fields, metadata=table.schema.metadata), + ) + + def normalize_table_view_types(table: "pa.Table") -> "pa.Table": """Cast a table's view-typed columns to their large variants. 
diff --git a/superpowers/specs/2026-06-27-itl-432-pydantic-dataclass-pipeline-columns-design.md b/superpowers/specs/2026-06-27-itl-432-pydantic-dataclass-pipeline-columns-design.md index 6037e527..f024fd92 100644 --- a/superpowers/specs/2026-06-27-itl-432-pydantic-dataclass-pipeline-columns-design.md +++ b/superpowers/specs/2026-06-27-itl-432-pydantic-dataclass-pipeline-columns-design.md @@ -6,9 +6,30 @@ ## Overview Pydantic models and dataclasses cannot flow through Orcapod pipelines as column -types. Two independent defects in the Arrow/Polars extension type machinery are -responsible. Both bugs are self-contained and addressed by surgical changes to -the hashing handler registry and Polars extension type construction. +types. Two independent defects are responsible: + +- **Bug A** — Extension-typed columns crash `ArrowDigester` because no + normalization step converts live `pa.ExtensionType` columns to the + storage-type-plus-Arrow-metadata representation that starfix expects. +- **Bug B** — `pl.DataFrame(table).to_arrow()` raises `ValueError` because + the synthesized Polars extension types carry no metadata, so + `__arrow_ext_deserialize__` receives `b''` instead of the expected category + bytes. + +Both bugs are self-contained and addressed by surgical changes to +`StarfixArrowHasher` and the two logical type `__init__` methods. + +--- + +## Why the extension type data is already fully captured + +Pydantic models and dataclasses are stored as Arrow extension types whose +storage type is a `pa.struct` of the model/dataclass fields, with each field +recursively resolved to an Arrow type. The entirety of the model's data content +is therefore already captured in the extension type's storage value. No +separate semantic handler is needed to hash pydantic/dataclass columns — the +Arrow hashing layer already handles the underlying struct data once the +extension type wrapper is stripped. 
--- @@ -22,102 +43,63 @@ TypeError: unhashable type: '_ArrowExt___main___Cfg' ### Root cause -No semantic handlers are registered for `pydantic.BaseModel` subclasses or -dataclasses. In `SemanticHashingVisitor.visit_extension()`, when -`type_handler_registry.has_handler(python_type)` returns False, the visitor -returns the extension type and its storage value unchanged (the passthrough -path). That extension-typed column flows directly into `ArrowDigester.hash_table`. +`SemanticHashingVisitor.visit_extension()` has a passthrough path: when no +semantic handler is registered for the resolved Python type, it returns the +extension type and storage value unchanged. The extension-typed column then +flows through `StarfixArrowHasher._process_table_columns` as-is and reaches +`ArrowDigester.hash_table`. -`ArrowDigester` is a pure-Python starfix implementation. Its -`_primitive_data_type_string` function builds a lookup dict with Arrow -primitive type instances as keys (`if dt in _simple:`). Performing a dict -membership test on an unhashable extension type raises `TypeError`. Even if -`__hash__` were added to the extension type, `ArrowDigester` would raise -`NotImplementedError` because it has no handler for extension types — the -only correct fix is ensuring extension types never reach it. +`ArrowDigester` has no `pa.types.is_extension()` branch. Extension types fall +through all type guards in `_data_type_to_value` and crash at `if dt in +_simple:` in `_primitive_data_type_string` because `pa.ExtensionType` +instances are not hashable (they override `__eq__` without `__hash__`). -### Fix: register semantic handlers for pydantic models and dataclasses - -**`PythonTypeHandlerProtocol` — add optional `supports_type`** - -Add an optional `supports_type(target_type: type) -> bool` method, -mirroring the `supports_class` pattern already used by -`LogicalTypeFactoryProtocol`. 
When defined, the handler registry calls it -after finding a handler via MRO walk; if it returns False the walk continues. -Handlers without `supports_type` are treated as unconditional matches -(existing behaviour unchanged). - -**`PythonTypeHandlerRegistry` — respect `supports_type` in MRO walk** - -Update `get_handler_for_type` to apply `supports_type` at every lookup point — -both the initial exact-match check and the MRO walk: - -```python -def _try_handler(handler, target_type): - """Return handler if it accepts target_type, else None.""" - if handler is None: - return None - if hasattr(handler, "supports_type") and not handler.supports_type(target_type): - return None - return handler - -# exact match -handler = _try_handler(self._handlers.get(target_type), target_type) -if handler is not None: - return handler -# MRO walk -for base in target_type.__mro__[1:]: - handler = _try_handler(self._handlers.get(base), target_type) - if handler is not None: - return handler -return None -``` +The correct representation for an extension-typed column when stored outside +Python memory (IPC/Parquet) is: the **storage type** for the data, plus +`ARROW:extension:name` and `ARROW:extension:metadata` in **field metadata**. +This is exactly what `ArrowDigester` knows how to process. -`get_handler(obj)` delegates to `get_handler_for_type(type(obj))` and -`has_handler(target_type)` delegates to `get_handler_for_type`, so both -inherit the fix automatically. +### Fix: normalize extension columns to storage type + metadata in `StarfixArrowHasher` -**New handlers in `builtin_handlers.py`** +After `SemanticHashingVisitor` processes each column, any column whose +resulting type is still a `pa.ExtensionType` was not handled by the visitor. +At that point, `StarfixArrowHasher._process_table_columns` normalizes it to +the IPC representation: storage type for the array, extension identity in +field metadata. 
```python -class PydanticModelHandler: - """Handler for pydantic BaseModel instances — delegates to model_dump().""" - - def handle(self, obj: Any, hasher: SemanticHasherProtocol) -> Any: - return obj.model_dump() - - -class DataclassModelHandler: - """Handler for dataclass instances — delegates to dataclasses.asdict().""" - - def supports_type(self, target_type: type) -> bool: - import dataclasses - return dataclasses.is_dataclass(target_type) - - def handle(self, obj: Any, hasher: SemanticHasherProtocol) -> Any: - import dataclasses - return dataclasses.asdict(obj) +if isinstance(new_type, pa.ExtensionType): + # Extension type was not converted by the visitor. + # Normalize to storage type + Arrow extension metadata (IPC representation) + # so that ArrowDigester can hash it correctly. + ext_type = new_type + serialized = ext_type.__arrow_ext_serialize__() + new_columns.append(pa.array(processed_data, type=ext_type.storage_type)) + new_fields.append(pa.field( + field.name, + ext_type.storage_type, + nullable=field.nullable, + metadata={ + b"ARROW:extension:name": ext_type.extension_name.encode("utf-8"), + b"ARROW:extension:metadata": serialized, + }, + )) +else: + new_columns.append(pa.array(processed_data, type=new_type)) + new_fields.append(field.with_type(new_type)) ``` -`model_dump()` and `dataclasses.asdict()` both return plain dicts that -accurately reflect the model's content. The recursive semantic hasher hashes -the returned dict, producing a stable content-based hash. +The existing `has_extension_metadata` check in `hash_table` already detects +`ARROW:extension:name` in field metadata (not live `pa.ExtensionType` objects), +so `clean_schema_for_hashing` and `ArrowDigester(include_metadata=True)` are +invoked correctly after this change. 
-**Registration in `register_builtin_python_type_handlers`** - -```python -from pydantic import BaseModel -registry.register(BaseModel, PydanticModelHandler()) - -import dataclasses as _dc -registry.register(object, DataclassModelHandler()) -``` - -`PydanticModelHandler` is registered against `pydantic.BaseModel` — MRO -lookup finds it for any subclass. `DataclassModelHandler` is registered -against `object` (matching the pattern used by `DataclassLogicalTypeFactory` -in `v0.1.json`) and gated by `supports_type`, which returns True only for -actual dataclass types. +**Why not semantic handlers for pydantic/dataclass?** The model data is already +captured in the extension type storage value. Adding semantic handlers that call +`model_dump()` / `dataclasses.asdict()` and re-hash the dict would be redundant. +The storage struct value IS the dict. The IPC normalization approach is simpler +and more general — it handles any extension type with a passthrough, not only +pydantic/dataclass. --- @@ -136,12 +118,12 @@ b'{"category": "orcapod.pydantic"}' but got b'' (line 93) both call `make_polars_extension_type(logical_name, storage_type)` without passing `metadata`. The Arrow extension type is built with category metadata (`b'{"category": "orcapod.pydantic"}'`), but the Polars extension -type carries no metadata. +type carries no metadata string. When `pl.DataFrame(table).to_arrow()` reconstructs the Arrow column, Polars -calls `__arrow_ext_deserialize__` with the Polars extension's metadata — which -is empty bytes (`b''`). The strict equality check in `_deserialize` fails -because `b'' != b'{"category": "orcapod.pydantic"}'`. +calls `__arrow_ext_deserialize__` with the Polars extension's metadata string +serialized to bytes — which is empty (`b''`). The strict equality check in +`_deserialize` fails because `b'' != b'{"category": "orcapod.pydantic"}'`. 
### Fix: pass category metadata to `make_polars_extension_type` @@ -161,9 +143,12 @@ self._polars_ext_class = make_polars_extension_type( Same change in `DataclassLogicalType.__init__`, using `DATACLASS_CATEGORY`. -After this fix, `pl.DataFrame(table).to_arrow()` passes -`b'{"category": "orcapod.pydantic"}'` to `__arrow_ext_deserialize__`, which -matches `_metadata` and succeeds. +After this fix, the Polars extension type's `metadata_str` is +`'{"category": "orcapod.pydantic"}'`. When `to_arrow()` calls +`__arrow_ext_deserialize__`, it passes +`b'{"category": "orcapod.pydantic"}'`, which matches `_metadata` and +succeeds. The resulting table has live extension type columns, which Bug A's +normalization then processes correctly before hashing. --- @@ -171,9 +156,7 @@ matches `_metadata` and succeeds. | File | Change | |------|--------| -| `src/orcapod/protocols/hashing_protocols.py` | Add optional `supports_type` to `PythonTypeHandlerProtocol` docstring and protocol stub | -| `src/orcapod/hashing/semantic_hashing/type_handler_registry.py` | Update `get_handler_for_type` to call `supports_type` when present; `get_handler` inherits the fix via delegation | -| `src/orcapod/hashing/semantic_hashing/builtin_handlers.py` | Add `PydanticModelHandler`, `DataclassModelHandler`; register both in `register_builtin_python_type_handlers` | +| `src/orcapod/hashing/arrow_hashers.py` | Normalize extension type columns to storage type + field metadata after visitor passthrough | | `src/orcapod/extension_types/pydantic_logical_type_factory.py` | Pass `metadata` to `make_polars_extension_type()` in `PydanticLogicalType.__init__` | | `src/orcapod/extension_types/dataclass_logical_type_factory.py` | Pass `metadata` to `make_polars_extension_type()` in `DataclassLogicalType.__init__` | | `tests/test_hashing/test_pydantic_dataclass_hashing.py` | New regression tests (see below) | @@ -200,16 +183,12 @@ Assert no `ValueError` is raised and the hash equals that of the original table. 
**Bug B regression — dataclass Polars round-trip:** Same as above with a dataclass column. -**Handler unit tests:** -- `PydanticModelHandler.handle` returns `model.model_dump()` for a flat model -- `DataclassModelHandler.handle` returns `dataclasses.asdict(obj)` for a flat dataclass -- `DataclassModelHandler.supports_type` returns True for dataclasses, False for pydantic models and plain classes -- Registry MRO walk respects `supports_type`: registering `DataclassModelHandler` against `object` does not intercept non-dataclass lookups - --- ## Out of scope -- Adding `__hash__` to synthesized extension types (tracked as a follow-up) -- Schema cleaner changes (no longer needed — the Polars metadata fix resolves the underlying cause) +- Adding `__hash__` to synthesized extension types (tracked as follow-up in starfix) +- Semantic handlers for pydantic/dataclass (not needed; storage value IS the data) +- Schema cleaner changes (not needed; Polars metadata fix resolves the underlying cause) - Deserialization relaxation (no backward-compatibility shims; greenfield project) +- Official starfix extension type support (tracked as separate Linear issue in Starfix v0.4.0) diff --git a/tests/test_hashing/test_pydantic_dataclass_hashing.py b/tests/test_hashing/test_pydantic_dataclass_hashing.py new file mode 100644 index 00000000..f231f597 --- /dev/null +++ b/tests/test_hashing/test_pydantic_dataclass_hashing.py @@ -0,0 +1,182 @@ +"""Regression tests for ITL-432: pydantic/dataclass models as pipeline columns. + +Bug A — extension type reaching ArrowDigester: + Before the fix, hashing a table with a pydantic or dataclass column raised + ``TypeError: unhashable type: '_ArrowExt_...'`` inside starfix because + ``StarfixArrowHasher._process_table_columns`` left live ``pa.ExtensionType`` + columns intact, and ``ArrowDigester._primitive_data_type_string`` uses the + type as a dict key. 
+ +Bug B — metadata loss on Polars round-trip: + Before the fix, ``pl.DataFrame(table).to_arrow()`` raised + ``ValueError: Arrow extension type '...': expected metadata ... but got b''`` + because the synthesized Polars extension types were built without the + ``metadata`` argument, so ``__arrow_ext_deserialize__`` received empty bytes. +""" + +from __future__ import annotations + +import dataclasses + +import pyarrow as pa +import polars as pl +import pytest +from pydantic import BaseModel + +from orcapod.contexts import get_default_context +from orcapod.hashing.arrow_hashers import StarfixArrowHasher +from orcapod.types import ContentHash + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def ctx(): + return get_default_context() + + +@pytest.fixture +def hasher(ctx): + return StarfixArrowHasher( + type_converter=ctx.type_converter, + semantic_hasher=ctx.semantic_hasher, + hasher_id="test_v0", + ) + + +# --------------------------------------------------------------------------- +# Model definitions +# --------------------------------------------------------------------------- + + +class _Point(BaseModel): + x: int + y: int + + +@dataclasses.dataclass +class _Vec: + a: float + b: float + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_pydantic_table(ctx) -> pa.Table: + """Return a two-row Arrow table with a ``_Point`` pydantic model column.""" + arrow_type = ctx.type_converter.register_python_class(_Point) + storage_vals = [ + ctx.type_converter.python_to_storage(_Point(x=1, y=2), _Point), + ctx.type_converter.python_to_storage(_Point(x=3, y=4), _Point), + ] + ext_arr = pa.ExtensionArray.from_storage( + arrow_type, pa.array(storage_vals, type=arrow_type.storage_type) + ) + 
return pa.table({"pt": ext_arr, "id": pa.array([1, 2], type=pa.int64())}) + + +def _make_dataclass_table(ctx) -> pa.Table: + """Return a two-row Arrow table with a ``_Vec`` dataclass column.""" + arrow_type = ctx.type_converter.register_python_class(_Vec) + storage_vals = [ + ctx.type_converter.python_to_storage(_Vec(a=1.0, b=2.0), _Vec), + ctx.type_converter.python_to_storage(_Vec(a=3.0, b=4.0), _Vec), + ] + ext_arr = pa.ExtensionArray.from_storage( + arrow_type, pa.array(storage_vals, type=arrow_type.storage_type) + ) + return pa.table({"v": ext_arr, "id": pa.array([1, 2], type=pa.int64())}) + + +# --------------------------------------------------------------------------- +# Bug A regressions — extension type reaching ArrowDigester +# --------------------------------------------------------------------------- + + +class TestBugAExtensionTypeHashable: + def test_pydantic_column_does_not_raise(self, ctx, hasher): + """hash_table on a table with a pydantic model column must not raise TypeError.""" + table = _make_pydantic_table(ctx) + result = hasher.hash_table(table) + assert isinstance(result, ContentHash) + + def test_dataclass_column_does_not_raise(self, ctx, hasher): + """hash_table on a table with a dataclass column must not raise TypeError.""" + table = _make_dataclass_table(ctx) + result = hasher.hash_table(table) + assert isinstance(result, ContentHash) + + def test_pydantic_hash_is_deterministic(self, ctx, hasher): + """Hashing the same pydantic table twice produces identical hashes.""" + table = _make_pydantic_table(ctx) + assert hasher.hash_table(table) == hasher.hash_table(table) + + def test_dataclass_hash_is_deterministic(self, ctx, hasher): + """Hashing the same dataclass table twice produces identical hashes.""" + table = _make_dataclass_table(ctx) + assert hasher.hash_table(table) == hasher.hash_table(table) + + def test_pydantic_different_values_different_hash(self, ctx, hasher): + """Tables with different pydantic model values produce different 
hashes.""" + arrow_type = ctx.type_converter.register_python_class(_Point) + + def _table(x, y): + s = ctx.type_converter.python_to_storage(_Point(x=x, y=y), _Point) + arr = pa.ExtensionArray.from_storage( + arrow_type, pa.array([s], type=arrow_type.storage_type) + ) + return pa.table({"pt": arr}) + + assert hasher.hash_table(_table(1, 2)) != hasher.hash_table(_table(9, 9)) + + def test_dataclass_different_values_different_hash(self, ctx, hasher): + """Tables with different dataclass values produce different hashes.""" + arrow_type = ctx.type_converter.register_python_class(_Vec) + + def _table(a, b): + s = ctx.type_converter.python_to_storage(_Vec(a=a, b=b), _Vec) + arr = pa.ExtensionArray.from_storage( + arrow_type, pa.array([s], type=arrow_type.storage_type) + ) + return pa.table({"v": arr}) + + assert hasher.hash_table(_table(1.0, 2.0)) != hasher.hash_table(_table(9.0, 9.0)) + + +# --------------------------------------------------------------------------- +# Bug B regressions — Polars round-trip metadata loss +# --------------------------------------------------------------------------- + + +class TestBugBPolarsRoundtrip: + def test_pydantic_polars_roundtrip_does_not_raise(self, ctx, hasher): + """pl.DataFrame(table).to_arrow() must not raise ValueError for pydantic columns.""" + table = _make_pydantic_table(ctx) + round_tripped = pl.DataFrame(table).to_arrow() + result = hasher.hash_table(round_tripped) + assert isinstance(result, ContentHash) + + def test_dataclass_polars_roundtrip_does_not_raise(self, ctx, hasher): + """pl.DataFrame(table).to_arrow() must not raise ValueError for dataclass columns.""" + table = _make_dataclass_table(ctx) + round_tripped = pl.DataFrame(table).to_arrow() + result = hasher.hash_table(round_tripped) + assert isinstance(result, ContentHash) + + def test_pydantic_roundtrip_hash_equals_original(self, ctx, hasher): + """Polars round-trip preserves hash — data content is unchanged.""" + table = _make_pydantic_table(ctx) + 
round_tripped = pl.DataFrame(table).to_arrow() + assert hasher.hash_table(table) == hasher.hash_table(round_tripped) + + def test_dataclass_roundtrip_hash_equals_original(self, ctx, hasher): + """Polars round-trip preserves hash — data content is unchanged.""" + table = _make_dataclass_table(ctx) + round_tripped = pl.DataFrame(table).to_arrow() + assert hasher.hash_table(table) == hasher.hash_table(round_tripped) From d70fe7278fc8ff7d7aa9fdcd8c9dbc46166e0ec7 Mon Sep 17 00:00:00 2001 From: "agent-kurodo[bot]" <268466204+agent-kurodo[bot]@users.noreply.github.com> Date: Sat, 27 Jun 2026 01:27:15 +0000 Subject: [PATCH 3/6] test(hashing): fix DictSource and PolarsFilter test cases for ITL-432 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Pre-register _Cfg and _Run with the type converter before creating DictSource instances; DictSource uses the default context's converter to build Arrow schemas from data_schema, so types must be registered first. - Replace incorrect content_hash() equality assertion in the no-op filter test with a row-count and column-presence check; filtered and raw streams differ in identity_structure (different producers), so their content hashes are intentionally different even with identical data. - Fix src.process() → src.content_hash() in DictSource tests; DictSource implements StreamProtocol directly and has no process() method. All 23 regression tests now pass. Co-Authored-By: Claude Sonnet 4.6 --- .../test_pydantic_dataclass_hashing.py | 257 +++++++++++++++++- 1 file changed, 256 insertions(+), 1 deletion(-) diff --git a/tests/test_hashing/test_pydantic_dataclass_hashing.py b/tests/test_hashing/test_pydantic_dataclass_hashing.py index f231f597..cfa15b5c 100644 --- a/tests/test_hashing/test_pydantic_dataclass_hashing.py +++ b/tests/test_hashing/test_pydantic_dataclass_hashing.py @@ -1,5 +1,10 @@ """Regression tests for ITL-432: pydantic/dataclass models as pipeline columns. 
+These tests cover the exact scenarios described in bug report #184: + + "Pydantic and dataclass models cannot flow through Orcapod pipelines as + columns, even though Parquet/IPC serialization works correctly." + Bug A — extension type reaching ArrowDigester: Before the fix, hashing a table with a pydantic or dataclass column raised ``TypeError: unhashable type: '_ArrowExt_...'`` inside starfix because @@ -7,11 +12,24 @@ columns intact, and ``ArrowDigester._primitive_data_type_string`` uses the type as a dict key. + Impact from the bug report: "Building any source that carries a Pydantic or + dataclass column crashes, because starfix requires hashable types for schema + operations." + Bug B — metadata loss on Polars round-trip: Before the fix, ``pl.DataFrame(table).to_arrow()`` raised ``ValueError: Arrow extension type '...': expected metadata ... but got b''`` because the synthesized Polars extension types were built without the ``metadata`` argument, so ``__arrow_ext_deserialize__`` received empty bytes. + + Impact from the bug report: "Join operations that round-trip through + pl.DataFrame(table).to_arrow() fail when processing model columns." + +Test coverage: + 1. Low-level: direct ``StarfixArrowHasher.hash_table`` on extension-type tables. + 2. End-to-end pipeline: ``DictSource``, ``ArrowTableStream.content_hash()``, + ``PolarsFilter``, and ``Join`` — all operators that trigger the two bugs in + real usage. 
""" from __future__ import annotations @@ -24,6 +42,9 @@ from pydantic import BaseModel from orcapod.contexts import get_default_context +from orcapod.core.operators import Join, PolarsFilter +from orcapod.core.sources import DictSource +from orcapod.core.streams import ArrowTableStream from orcapod.hashing.arrow_hashers import StarfixArrowHasher from orcapod.types import ContentHash @@ -48,7 +69,7 @@ def hasher(ctx): # --------------------------------------------------------------------------- -# Model definitions +# Model definitions — must be at module level so their FQCNs are importable # --------------------------------------------------------------------------- @@ -63,6 +84,19 @@ class _Vec: b: float +# Models for DictSource pipeline tests. Separate names to keep registrations +# independent of the hashing-level model registrations above. +class _Cfg(BaseModel): + lr: float + epochs: int + + +@dataclasses.dataclass +class _Run: + seed: int + batch_size: int + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -94,6 +128,34 @@ def _make_dataclass_table(ctx) -> pa.Table: return pa.table({"v": ext_arr, "id": pa.array([1, 2], type=pa.int64())}) +def _make_pydantic_stream(ctx) -> ArrowTableStream: + """Return a stream with tag ``id`` and pydantic model data column ``pt``.""" + arrow_type = ctx.type_converter.register_python_class(_Point) + storage_vals = [ + ctx.type_converter.python_to_storage(_Point(x=1, y=2), _Point), + ctx.type_converter.python_to_storage(_Point(x=3, y=4), _Point), + ] + ext_arr = pa.ExtensionArray.from_storage( + arrow_type, pa.array(storage_vals, type=arrow_type.storage_type) + ) + table = pa.table({"id": pa.array([1, 2], type=pa.int64()), "pt": ext_arr}) + return ArrowTableStream(table, tag_columns=["id"]) + + +def _make_dataclass_stream(ctx) -> ArrowTableStream: + """Return a stream with tag ``id`` and dataclass data 
column ``v``.""" + arrow_type = ctx.type_converter.register_python_class(_Vec) + storage_vals = [ + ctx.type_converter.python_to_storage(_Vec(a=1.0, b=2.0), _Vec), + ctx.type_converter.python_to_storage(_Vec(a=3.0, b=4.0), _Vec), + ] + ext_arr = pa.ExtensionArray.from_storage( + arrow_type, pa.array(storage_vals, type=arrow_type.storage_type) + ) + table = pa.table({"id": pa.array([1, 2], type=pa.int64()), "v": ext_arr}) + return ArrowTableStream(table, tag_columns=["id"]) + + # --------------------------------------------------------------------------- # Bug A regressions — extension type reaching ArrowDigester # --------------------------------------------------------------------------- @@ -180,3 +242,196 @@ def test_dataclass_roundtrip_hash_equals_original(self, ctx, hasher): table = _make_dataclass_table(ctx) round_tripped = pl.DataFrame(table).to_arrow() assert hasher.hash_table(table) == hasher.hash_table(round_tripped) + + +# --------------------------------------------------------------------------- +# End-to-end pipeline tests — replicating the exact usage scenario from #184 +# --------------------------------------------------------------------------- + + +class TestEndToEndPipelineWithModelColumns: + """End-to-end pipeline tests matching the bug report scenarios. + + The bug report states: + - "Building any source that carries a Pydantic or dataclass column crashes, + because starfix requires hashable types for schema operations." (Bug A) + - "Join operations that round-trip through pl.DataFrame(table).to_arrow() + fail when processing model columns." (Bug B) + + These tests replicate those exact paths through the real pipeline API. 
+ """ + + # ------------------------------------------------------------------ + # DictSource — the natural way to build a source with model columns + # ------------------------------------------------------------------ + + def test_dict_source_pydantic_column_content_hash(self, ctx): + """DictSource with pydantic column: content_hash must not crash (Bug A). + + This is the primary bug-report scenario: a user puts pydantic model + instances into a source and tries to hash it. + """ + # Register _Cfg with the default context's type converter so that + # DictSource can resolve it when building the Arrow schema. + ctx.type_converter.register_python_class(_Cfg) + src = DictSource( + data=[ + {"run_id": 1, "cfg": _Cfg(lr=0.01, epochs=10)}, + {"run_id": 2, "cfg": _Cfg(lr=0.001, epochs=20)}, + ], + tag_columns=["run_id"], + data_schema={"run_id": int, "cfg": _Cfg}, + ) + # DictSource IS the stream — content_hash() is called directly on it. + result = src.content_hash() + assert isinstance(result, ContentHash) + + def test_dict_source_dataclass_column_content_hash(self, ctx): + """DictSource with dataclass column: content_hash must not crash (Bug A).""" + # Register _Run with the default context's type converter so that + # DictSource can resolve it when building the Arrow schema. + ctx.type_converter.register_python_class(_Run) + src = DictSource( + data=[ + {"run_id": 1, "run": _Run(seed=42, batch_size=32)}, + {"run_id": 2, "run": _Run(seed=7, batch_size=64)}, + ], + tag_columns=["run_id"], + data_schema={"run_id": int, "run": _Run}, + ) + result = src.content_hash() + assert isinstance(result, ContentHash) + + def test_dict_source_pydantic_hash_reflects_model_values(self, ctx): + """Different model values produce different content hashes.""" + # Register _Cfg so DictSource's Arrow schema conversion succeeds. 
+ ctx.type_converter.register_python_class(_Cfg) + + def _src(lr): + return DictSource( + data=[{"run_id": 1, "cfg": _Cfg(lr=lr, epochs=10)}], + tag_columns=["run_id"], + data_schema={"run_id": int, "cfg": _Cfg}, + ) + + assert _src(0.01).content_hash() != _src(0.1).content_hash() + + # ------------------------------------------------------------------ + # ArrowTableStream.content_hash — stream-level Bug A trigger + # ------------------------------------------------------------------ + + def test_stream_with_pydantic_column_content_hash(self, ctx): + """ArrowTableStream.content_hash on a pydantic column must not crash (Bug A).""" + stream = _make_pydantic_stream(ctx) + result = stream.content_hash() + assert isinstance(result, ContentHash) + + def test_stream_with_dataclass_column_content_hash(self, ctx): + """ArrowTableStream.content_hash on a dataclass column must not crash (Bug A).""" + stream = _make_dataclass_stream(ctx) + result = stream.content_hash() + assert isinstance(result, ContentHash) + + def test_stream_content_hash_is_deterministic(self, ctx): + """content_hash is stable across repeated calls.""" + stream = _make_pydantic_stream(ctx) + assert stream.content_hash() == stream.content_hash() + + # ------------------------------------------------------------------ + # PolarsFilter — Bug B trigger via pl.DataFrame(table).filter().to_arrow() + # ------------------------------------------------------------------ + + def test_pydantic_column_through_polars_filter(self, ctx): + """PolarsFilter on a stream with a pydantic column must not crash (Bug B). + + PolarsFilter calls pl.DataFrame(table).filter(...).to_arrow() internally. + Before the fix this raised ValueError from __arrow_ext_deserialize__. 
+ """ + stream = _make_pydantic_stream(ctx) + filtered = PolarsFilter().process(stream) + result = filtered.as_table() + assert len(result) == 2 + + def test_dataclass_column_through_polars_filter(self, ctx): + """PolarsFilter on a stream with a dataclass column must not crash (Bug B).""" + stream = _make_dataclass_stream(ctx) + filtered = PolarsFilter().process(stream) + result = filtered.as_table() + assert len(result) == 2 + + def test_polars_filter_with_constraint_on_pydantic_stream(self, ctx): + """PolarsFilter with an id constraint correctly filters rows and preserves model column.""" + stream = _make_pydantic_stream(ctx) + # Filter to only keep id == 1 + filtered = PolarsFilter(constraints={"id": 1}).process(stream) + result = filtered.as_table() + assert len(result) == 1 + + def test_polars_filter_no_op_preserves_all_rows(self, ctx): + """A no-op PolarsFilter (no constraints) returns all rows with all columns intact. + + Note: content_hash() will differ between the raw stream and the filtered + stream because they have different producers (different identity_structure), + but the underlying data must be identical. + """ + stream = _make_pydantic_stream(ctx) + filtered = PolarsFilter().process(stream) + original_table = stream.as_table() + filtered_table = filtered.as_table() + assert len(filtered_table) == len(original_table) + assert "pt" in filtered_table.column_names + assert "id" in filtered_table.column_names + + # ------------------------------------------------------------------ + # Join — Bug B trigger via pl.DataFrame(table).join(...).to_arrow() + # ------------------------------------------------------------------ + + def test_pydantic_column_through_join(self, ctx): + """Join on a stream with a pydantic column must not crash (Bug B). + + Join calls pl.DataFrame(table).join(...).to_arrow() internally. + Before the fix this raised ValueError from __arrow_ext_deserialize__. 
+ """ + pydantic_stream = _make_pydantic_stream(ctx) + # Second stream shares the "id" tag but has a plain score column. + plain_table = pa.table({ + "id": pa.array([1, 2], type=pa.int64()), + "score": pa.array([0.9, 0.8], type=pa.float64()), + }) + plain_stream = ArrowTableStream(plain_table, tag_columns=["id"]) + + out = Join().process(pydantic_stream, plain_stream) + result = out.as_table() + assert len(result) == 2 + # Both the pydantic column and the score column should be present + assert "pt" in result.column_names + assert "score" in result.column_names + + def test_dataclass_column_through_join(self, ctx): + """Join on a stream with a dataclass column must not crash (Bug B).""" + dataclass_stream = _make_dataclass_stream(ctx) + plain_table = pa.table({ + "id": pa.array([1, 2], type=pa.int64()), + "score": pa.array([0.9, 0.8], type=pa.float64()), + }) + plain_stream = ArrowTableStream(plain_table, tag_columns=["id"]) + + out = Join().process(dataclass_stream, plain_stream) + result = out.as_table() + assert len(result) == 2 + assert "v" in result.column_names + assert "score" in result.column_names + + def test_join_partial_overlap_with_pydantic_column(self, ctx): + """Join with partial tag overlap correctly returns only matched rows.""" + pydantic_stream = _make_pydantic_stream(ctx) + # Second stream only has id=1 — join result should be 1 row. 
+ plain_table = pa.table({ + "id": pa.array([1], type=pa.int64()), + "score": pa.array([0.9], type=pa.float64()), + }) + plain_stream = ArrowTableStream(plain_table, tag_columns=["id"]) + + out = Join().process(pydantic_stream, plain_stream) + result = out.as_table() + assert len(result) == 1 From 5fa019afd35f2e18fba7465cc67e6c0ec871a8b8 Mon Sep 17 00:00:00 2001 From: "agent-kurodo[bot]" <268466204+agent-kurodo[bot]@users.noreply.github.com> Date: Sat, 27 Jun 2026 01:29:33 +0000 Subject: [PATCH 4/6] perf(arrow_utils): preserve chunking in normalize_extension_columns Replace combine_chunks().storage with per-chunk .storage iteration to rebuild a ChunkedArray. combine_chunks() allocates new buffers when a column has more than one chunk, contradicting the zero-copy guarantee documented in the function's docstring. Each ExtensionArray chunk's .storage property is a true zero-copy view of the underlying Arrow buffers; building a ChunkedArray from those storage chunks avoids any buffer allocation while preserving the original chunk structure. Update the docstring to accurately reflect the chunked approach. Co-Authored-By: Claude Sonnet 4.6 --- src/orcapod/utils/arrow_utils.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/orcapod/utils/arrow_utils.py b/src/orcapod/utils/arrow_utils.py index fc87c373..00878025 100644 --- a/src/orcapod/utils/arrow_utils.py +++ b/src/orcapod/utils/arrow_utils.py @@ -302,9 +302,9 @@ def normalize_extension_columns(table: "pa.Table") -> "pa.Table": This is a fast path: for tables with no extension columns the original table object is returned immediately. For tables that do have extension - columns a new table is constructed; the column data itself is not copied — - ``ExtensionArray.storage`` returns a zero-copy view of the underlying - buffers. 
+ columns a new table is constructed; chunking is preserved and the column + data itself is not copied — each chunk's ``ExtensionArray.storage`` + property returns a zero-copy view of the underlying buffers. Note: only **top-level** extension columns are handled. Extension types nested inside struct fields or list element types are not supported by the @@ -326,9 +326,16 @@ def normalize_extension_columns(table: "pa.Table") -> "pa.Table": for i, field in enumerate(table.schema): if isinstance(field.type, pa.ExtensionType): ext_type = field.type - # combine_chunks() returns a single ExtensionArray; .storage is a - # zero-copy view of the underlying buffers with the storage type. - storage_arr = table.column(i).combine_chunks().storage + # Preserve chunking: convert each ExtensionArray chunk to its + # .storage chunk (zero-copy view of the underlying buffers) and + # rebuild a ChunkedArray. Calling combine_chunks() first would + # allocate new buffers for multi-chunk columns, defeating the + # zero-copy guarantee. + col = table.column(i) + storage_arr = pa.chunked_array( + [chunk.storage for chunk in col.chunks], + type=ext_type.storage_type, + ) serialized = ext_type.__arrow_ext_serialize__() # Merge extension identity into existing field metadata (if any) # so that non-extension keys already on the field are preserved. From 5b8f6fd63d2d97fd94356805063110dc2b70a47c Mon Sep 17 00:00:00 2001 From: "agent-kurodo[bot]" <268466204+agent-kurodo[bot]@users.noreply.github.com> Date: Sat, 27 Jun 2026 01:32:30 +0000 Subject: [PATCH 5/6] fix(types): correct new_columns type annotations in extension column processing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In StarfixArrowHasher._process_table_columns, the short-circuit path appends table.column(i) which is pa.ChunkedArray, not pa.Array — fix the annotation to list[pa.Array | pa.ChunkedArray]. 
In normalize_extension_columns, all appended items are pa.ChunkedArray (both the storage ChunkedArray built from extension chunks and the passthrough table.column(i) call) — add explicit annotations list[pa.ChunkedArray] and list[pa.Field] to make the types clear. Co-Authored-By: Claude Sonnet 4.6 --- src/orcapod/hashing/arrow_hashers.py | 2 +- src/orcapod/utils/arrow_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/orcapod/hashing/arrow_hashers.py b/src/orcapod/hashing/arrow_hashers.py index 876e3f5a..e5f5f1c1 100644 --- a/src/orcapod/hashing/arrow_hashers.py +++ b/src/orcapod/hashing/arrow_hashers.py @@ -68,7 +68,7 @@ def _process_table_columns(self, table: "pa.Table | pa.RecordBatch") -> "pa.Tabl in field metadata — so that ``ArrowDigester`` can hash them without encountering a live ``pa.ExtensionType``, which is unhashable. """ - new_columns: list[pa.Array] = [] + new_columns: list[pa.Array | pa.ChunkedArray] = [] new_fields: list[pa.Field] = [] for i, field in enumerate(table.schema): diff --git a/src/orcapod/utils/arrow_utils.py b/src/orcapod/utils/arrow_utils.py index 00878025..369da869 100644 --- a/src/orcapod/utils/arrow_utils.py +++ b/src/orcapod/utils/arrow_utils.py @@ -321,8 +321,8 @@ def normalize_extension_columns(table: "pa.Table") -> "pa.Table": if not any(isinstance(field.type, pa.ExtensionType) for field in table.schema): return table - new_columns = [] - new_fields = [] + new_columns: list[pa.ChunkedArray] = [] + new_fields: list[pa.Field] = [] for i, field in enumerate(table.schema): if isinstance(field.type, pa.ExtensionType): ext_type = field.type From 687ec637179b9a0c70a7f3b8fe0ca3540317e520 Mon Sep 17 00:00:00 2001 From: "agent-kurodo[bot]" <268466204+agent-kurodo[bot]@users.noreply.github.com> Date: Sat, 27 Jun 2026 01:56:01 +0000 Subject: [PATCH 6/6] test(arrow_utils): add thorough unit tests for normalize_extension_columns 14 tests covering every documented property of the function: - fast-path identity 
return when no extension columns are present - storage type substitution for extension columns - ARROW:extension:name and ARROW:extension:metadata written to field metadata - __arrow_ext_serialize__ output round-trips correctly through the metadata - data values preserved after normalization - non-extension columns pass through unchanged in mixed tables - column count stability - schema-level metadata preserved - pre-existing per-field metadata preserved alongside new ARROW:extension:* keys - nullable=False and nullable=True both preserved - multi-chunk column: data values correct after normalization - multi-chunk column: chunk count preserved (verifies the zero-copy guarantee) - multiple extension columns of different types all normalized independently Uses two self-contained test-only pa.ExtensionType subclasses (_TestIntExt, _TestBinaryExt) registered at module import time to keep these tests free of the orcapod type-converter machinery. Co-Authored-By: Claude Sonnet 4.6 --- tests/test_utils/test_arrow_utils.py | 281 +++++++++++++++++++++++++++ 1 file changed, 281 insertions(+) diff --git a/tests/test_utils/test_arrow_utils.py b/tests/test_utils/test_arrow_utils.py index 68ec445e..e77b193a 100644 --- a/tests/test_utils/test_arrow_utils.py +++ b/tests/test_utils/test_arrow_utils.py @@ -9,12 +9,60 @@ apply_column_config, infer_schema_nullable, make_schema_non_nullable, + normalize_extension_columns, normalize_table_view_types, normalize_view_types, prepare_prefixed_columns, ) +# --------------------------------------------------------------------------- +# Minimal extension types for normalize_extension_columns tests. +# These are self-contained and do not depend on the orcapod type-converter. 
+# ---------------------------------------------------------------------------
+
+
+class _TestIntExt(pa.ExtensionType):
+    """Extension type wrapping int32 storage, used in normalize tests."""
+
+    def __init__(self):
+        super().__init__(pa.int32(), "orcapod.test.int_ext")
+
+    def __arrow_ext_serialize__(self):
+        return b'{"category":"test_int"}'
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+        return cls()
+
+
+class _TestBinaryExt(pa.ExtensionType):
+    """Extension type wrapping large_binary storage, used in normalize tests."""
+
+    def __init__(self):
+        super().__init__(pa.large_binary(), "orcapod.test.binary_ext")
+
+    def __arrow_ext_serialize__(self):
+        return b'{"category":"test_binary"}'
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+        return cls()
+
+
+# Register once at module import time; guard against re-registration when the
+# test module is reloaded in the same process (e.g. under pytest-xdist or
+# repeated runs inside an interactive session).
+for _ext_instance in (_TestIntExt(), _TestBinaryExt()):
+    try:
+        pa.register_extension_type(_ext_instance)
+    except KeyError:
+        pass  # already registered; existing registration is still valid
+
+_INT_EXT = _TestIntExt()
+_BINARY_EXT = _TestBinaryExt()
+
+
 class TestPreparePrefixedColumnsPreservesNullable:
     """prepare_prefixed_columns must preserve nullable flags from the source table."""
 
@@ -450,3 +498,236 @@ def test_preserves_field_nullability(self):
         result = normalize_table_view_types(tbl)
         assert result.schema.field("s").type == pa.large_string()
         assert result.schema.field("s").nullable is False
+
+
+# ---------------------------------------------------------------------------
+# normalize_extension_columns
+# ---------------------------------------------------------------------------
+
+
+class TestNormalizeExtensionColumns:
+    """normalize_extension_columns: pa.ExtensionType columns → IPC storage form.
+
+    Covers:
+    * fast-path identity return when no extension columns are present
+    * storage type substitution for extension columns
+    * correct ARROW:extension:name and ARROW:extension:metadata field metadata
+    * data-value preservation
+    * non-extension column passthrough in mixed tables
+    * column count stability
+    * schema-level metadata preservation
+    * per-field metadata preservation alongside the new ARROW:extension:* keys
+    * nullable flag preservation (both True and False)
+    * multi-chunk column handling: data correctness and chunk-count preservation
+    * multiple extension columns of different types in the same table
+    """
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+
+    def _int_ext_col(self, values: list[int]) -> pa.ChunkedArray:
+        """Build a single-chunk _INT_EXT array over int32 storage."""
+        storage = pa.array(values, type=pa.int32())
+        arr = pa.ExtensionArray.from_storage(_INT_EXT, storage)
+        return pa.chunked_array([arr])
+
+    def _binary_ext_col(self, values: list[bytes]) -> pa.ChunkedArray:
+        """Build a single-chunk _BINARY_EXT array over large_binary storage."""
+        storage = pa.array(values, type=pa.large_binary())
+        arr = pa.ExtensionArray.from_storage(_BINARY_EXT, storage)
+        return pa.chunked_array([arr])
+
+    # ------------------------------------------------------------------
+    # Fast-path: no extension columns
+    # ------------------------------------------------------------------
+
+    def test_no_extension_columns_returns_same_object(self):
+        """Table with no extension columns is returned as the exact same object."""
+        table = pa.table({"x": pa.array([1, 2], pa.int64()), "y": pa.array([3.0, 4.0])})
+        result = normalize_extension_columns(table)
+        assert result is table
+
+    # ------------------------------------------------------------------
+    # Type conversion
+    # ------------------------------------------------------------------
+
+    def test_extension_column_type_becomes_storage_type(self):
+        """Normalized field has the extension type's storage type, not the extension type."""
+        table = pa.Table.from_arrays(
+            [self._int_ext_col([10, 20])],
+            schema=pa.schema([pa.field("v", _INT_EXT)]),
+        )
+        result = normalize_extension_columns(table)
+        assert result.schema.field("v").type == pa.int32()
+        assert not isinstance(result.schema.field("v").type, pa.ExtensionType)
+
+    # ------------------------------------------------------------------
+    # Field metadata — extension identity
+    # ------------------------------------------------------------------
+
+    def test_extension_name_written_to_field_metadata(self):
+        """ARROW:extension:name equals the extension type's registered name."""
+        table = pa.Table.from_arrays(
+            [self._int_ext_col([1])],
+            schema=pa.schema([pa.field("v", _INT_EXT)]),
+        )
+        result = normalize_extension_columns(table)
+        meta = result.schema.field("v").metadata
+        assert b"ARROW:extension:name" in meta
+        assert meta[b"ARROW:extension:name"] == b"orcapod.test.int_ext"
+
+    def test_extension_metadata_matches_arrow_ext_serialize(self):
+        """ARROW:extension:metadata equals the output of __arrow_ext_serialize__."""
+        table = pa.Table.from_arrays(
+            [self._int_ext_col([1])],
+            schema=pa.schema([pa.field("v", _INT_EXT)]),
+        )
+        result = normalize_extension_columns(table)
+        meta = result.schema.field("v").metadata
+        assert b"ARROW:extension:metadata" in meta
+        assert meta[b"ARROW:extension:metadata"] == b'{"category":"test_int"}'
+
+    # ------------------------------------------------------------------
+    # Data preservation
+    # ------------------------------------------------------------------
+
+    def test_data_values_preserved(self):
+        """Storage values in the normalized column match the original extension values."""
+        table = pa.Table.from_arrays(
+            [self._int_ext_col([7, 8, 9])],
+            schema=pa.schema([pa.field("v", _INT_EXT)]),
+        )
+        result = normalize_extension_columns(table)
+        assert result.column("v").to_pylist() == [7, 8, 9]
+
+    # ------------------------------------------------------------------
+    # Non-extension columns (mixed table)
+    # ------------------------------------------------------------------
+
+    def test_non_extension_columns_pass_through_unchanged(self):
+        """Plain columns in a mixed table keep their type and values unchanged."""
+        schema = pa.schema([pa.field("ext", _INT_EXT), pa.field("plain", pa.int64())])
+        table = pa.Table.from_arrays(
+            [self._int_ext_col([1, 2]), pa.array([100, 200], type=pa.int64())],
+            schema=schema,
+        )
+        result = normalize_extension_columns(table)
+        assert result.schema.field("plain").type == pa.int64()
+        assert result.column("plain").to_pylist() == [100, 200]
+
+    def test_column_count_unchanged(self):
+        """The result table has the same number of columns as the input."""
+        schema = pa.schema([pa.field("e", _INT_EXT), pa.field("p", pa.int64())])
+        table = pa.Table.from_arrays(
+            [self._int_ext_col([1]), pa.array([99], type=pa.int64())],
+            schema=schema,
+        )
+        result = normalize_extension_columns(table)
+        assert result.num_columns == 2
+
+    # ------------------------------------------------------------------
+    # Metadata preservation
+    # ------------------------------------------------------------------
+
+    def test_schema_level_metadata_preserved(self):
+        """Schema-level metadata is carried through to the result unchanged."""
+        schema = pa.schema(
+            [pa.field("v", _INT_EXT)],
+            metadata={b"schema_key": b"schema_val"},
+        )
+        table = pa.Table.from_arrays([self._int_ext_col([1])], schema=schema)
+        result = normalize_extension_columns(table)
+        assert result.schema.metadata[b"schema_key"] == b"schema_val"
+
+    def test_existing_field_metadata_preserved(self):
+        """Pre-existing per-field metadata survives alongside the new ARROW:extension:* keys."""
+        field = pa.field("v", _INT_EXT, metadata={b"custom_key": b"custom_val"})
+        table = pa.Table.from_arrays(
+            [self._int_ext_col([1])], schema=pa.schema([field])
+        )
+        result = normalize_extension_columns(table)
+        meta = result.schema.field("v").metadata
+        # Pre-existing key preserved
+        assert meta[b"custom_key"] == b"custom_val"
+        # Extension identity also added
+        assert b"ARROW:extension:name" in meta
+
+    # ------------------------------------------------------------------
+    # Field attributes
+    # ------------------------------------------------------------------
+
+    def test_nullable_false_preserved(self):
+        """Extension column with nullable=False keeps nullable=False after normalization."""
+        field = pa.field("v", _INT_EXT, nullable=False)
+        table = pa.Table.from_arrays(
+            [self._int_ext_col([1, 2])], schema=pa.schema([field])
+        )
+        result = normalize_extension_columns(table)
+        assert result.schema.field("v").nullable is False
+
+    def test_nullable_true_preserved(self):
+        """Extension column with nullable=True keeps nullable=True after normalization."""
+        field = pa.field("v", _INT_EXT, nullable=True)
+        table = pa.Table.from_arrays(
+            [self._int_ext_col([1, 2])], schema=pa.schema([field])
+        )
+        result = normalize_extension_columns(table)
+        assert result.schema.field("v").nullable is True
+
+    # ------------------------------------------------------------------
+    # Multi-chunk columns (chunk layout must be preserved)
+    # ------------------------------------------------------------------
+
+    def _multi_chunk_int_ext_col(self) -> pa.ChunkedArray:
+        """Two-chunk ChunkedArray of _INT_EXT values [1, 2] | [3, 4]."""
+        arr1 = pa.ExtensionArray.from_storage(_INT_EXT, pa.array([1, 2], pa.int32()))
+        arr2 = pa.ExtensionArray.from_storage(_INT_EXT, pa.array([3, 4], pa.int32()))
+        return pa.chunked_array([arr1, arr2])
+
+    def test_multi_chunk_data_values_preserved(self):
+        """Multi-chunk extension column: all data values are correct after normalization."""
+        table = pa.Table.from_arrays(
+            [self._multi_chunk_int_ext_col()],
+            schema=pa.schema([pa.field("v", _INT_EXT)]),
+        )
+        result = normalize_extension_columns(table)
+        assert result.column("v").to_pylist() == [1, 2, 3, 4]
+
+    def test_multi_chunk_column_chunk_count_preserved(self):
+        """Multi-chunk extension column: chunk count is preserved (no combine_chunks copy)."""
+        table = pa.Table.from_arrays(
+            [self._multi_chunk_int_ext_col()],
+            schema=pa.schema([pa.field("v", _INT_EXT)]),
+        )
+        result = normalize_extension_columns(table)
+        # Original has 2 chunks; normalization must not collapse them into 1.
+        assert result.column("v").num_chunks == 2
+
+    # ------------------------------------------------------------------
+    # Multiple extension columns
+    # ------------------------------------------------------------------
+
+    def test_multiple_extension_columns_all_normalized(self):
+        """All extension-typed columns in the same table are independently normalized."""
+        schema = pa.schema([pa.field("i", _INT_EXT), pa.field("b", _BINARY_EXT)])
+        table = pa.Table.from_arrays(
+            [self._int_ext_col([1, 2]), self._binary_ext_col([b"x", b"y"])],
+            schema=schema,
+        )
+        result = normalize_extension_columns(table)
+        # Both columns have storage types
+        assert result.schema.field("i").type == pa.int32()
+        assert result.schema.field("b").type == pa.large_binary()
+        # Both carry correct extension names
+        assert (
+            result.schema.field("i").metadata[b"ARROW:extension:name"]
+            == b"orcapod.test.int_ext"
+        )
+        assert (
+            result.schema.field("b").metadata[b"ARROW:extension:name"]
+            == b"orcapod.test.binary_ext"
+        )
+        # Data values correct for both
+        assert result.column("i").to_pylist() == [1, 2]
+        assert result.column("b").to_pylist() == [b"x", b"y"]