diff --git a/.github/workflows/run-objective-tests.yml b/.github/workflows/run-objective-tests.yml index b245f6de..76eb19d3 100644 --- a/.github/workflows/run-objective-tests.yml +++ b/.github/workflows/run-objective-tests.yml @@ -2,9 +2,8 @@ name: Run Objective Tests on: push: - branches: [main, dev] + branches: [main, dev, extension-type-system] pull_request: - branches: [main, dev] workflow_dispatch: # Allows manual triggering jobs: diff --git a/.github/workflows/run-postgres-tests.yml b/.github/workflows/run-postgres-tests.yml index 72dcd3b9..65544873 100644 --- a/.github/workflows/run-postgres-tests.yml +++ b/.github/workflows/run-postgres-tests.yml @@ -2,9 +2,8 @@ name: Run PostgreSQL Tests on: push: - branches: [main, dev] + branches: [main, dev, extension-type-system] pull_request: - branches: [main, dev] workflow_dispatch: # Allows manual triggering jobs: diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 0f8fe9c5..a29e8526 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -2,9 +2,8 @@ name: Run Tests on: push: - branches: [main, dev] + branches: [main, dev, extension-type-system] pull_request: - branches: [main, dev] workflow_dispatch: # Allows manual triggering jobs: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e20b5573..1a2b2214 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -2,9 +2,8 @@ name: Tests on: push: - branches: [main, dev] + branches: [main, dev, extension-type-system] pull_request: - branches: [main, dev] jobs: test: diff --git a/CLAUDE.md b/CLAUDE.md index cd9dbe72..0ec257b8 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -107,8 +107,7 @@ Remove any optional sections that don't apply rather than leaving them empty. When working on a feature, create and checkout a git branch using the `gitBranchName` returned by the primary Linear issue (e.g. `eywalker/plt-911-add-documentation-for-orcapod-python`). -**Feature branch PRs always target the `dev` branch.** The `dev` → `main` PR is used -for versioning/releases only. +**Feature branch PRs always target the `extension-type-system` branch.** The `extension-type-system` → `dev` → `main` PRs are used for integration and releases. If a feature branch / PR corresponds to multiple Linear issues, list all of them in the PR description body so that Linear's GitHub integration auto-tracks the PR against each diff --git a/DESIGN_ISSUES.md b/DESIGN_ISSUES.md index 0c47a613..8c8572c9 100644 --- a/DESIGN_ISSUES.md +++ b/DESIGN_ISSUES.md @@ -999,6 +999,124 @@ Open questions: --- +## `src/orcapod/extension_types/` + +### ET1 — `make_polars_extension_type` cannot accept a storage type containing nested extension types +**Status:** open +**Severity:** medium + +`make_polars_extension_type` computes the Polars storage dtype by calling: +```python +pl.from_arrow(pa.array([], type=arrow_storage_type)).dtype +``` +This fails with `ArrowNotImplementedError: extension` when `arrow_storage_type` is a struct +(or list) whose fields include any `pa.ExtensionType` node — for example, a dataclass whose +fields include `uuid.UUID` (stored as `orcapod.uuid` extension over `pa.large_binary()`). + +Polars's Arrow IPC bridge handles top-level extension types via `pl.BaseExtension`, but has no +path for extension types *nested inside* a struct at dtype-inference time. + +**Workaround:** `register_python_class` and `register_storage_type` both uphold a +*storage-safe* invariant: the returned type may be a `pa.ExtensionType` at the top level, +but struct fields and list value types at any depth are always plain (non-extension) types. +`DataclassLogicalTypeFactory.create_for_python_type` strips the top-level extension type +with a one-liner (`if isinstance(arrow_type, pa.ExtensionType): arrow_type = arrow_type.storage_type`) +before inserting it into the struct, so the struct passed to `make_polars_extension_type` +and `pa.Table.from_pylist` never contains nested extension types. The private +`_strip_ext_to_storage` recursive helper was removed in PLT-1720; the stripping is now +trivially correct because the storage-safe invariant guarantees `.storage_type` is always +already clean. + +**Also affects `pa.Table.from_pylist`:** the same restriction applies to PyArrow's +`pa.Table.from_pylist` (and `pa.array`) — neither can build an array from a struct type +whose fields are `pa.ExtensionType` nodes, for the same underlying reason. The stripping +in `create_for_python_type` fixes both issues simultaneously. + +**Polars round-trip fidelity:** once the storage struct contains only plain types (no +nested extension types), the full Arrow → Polars → Arrow round-trip for the *outermost* +extension type is faithful: extension name, metadata bytes, and storage struct are all +preserved. Only the inner field schema (already stripped) is absent. + +**Fix needed:** Once PyArrow (and Polars) support nested extension types natively in struct +construction and Arrow↔Polars conversion, the stripping one-liner in `create_for_python_type` +can be removed and `make_polars_extension_type` can accept extension-typed storage directly. +Track upstream PyArrow / Polars issues. + +### ET2 — Top-level `list[T]` / `dict[K, V]` columns lose extension-type schema metadata when `T`/`V` is a logical type +**Status:** open +**Severity:** medium +**Issue:** PLT-1732 + +When a logical type (e.g. `UUID`, a dataclass) appears as the element type of a `list[T]` +or `dict[K, V]` annotation, `register_python_class` now raises `ValueError` at +schema-construction time rather than silently stripping the extension type. The underlying +cause is that PyArrow does not allow extension types inside list value fields or struct +fields (ET1): `pa.array([], type=pa.large_list(extension_type))` raises +`ArrowNotImplementedError: extension`. If a caller manually strips to storage types and +writes `large_list(large_binary)` for `list[UUID]`, the stored Arrow schema carries no +`orcapod.uuid` marker; on a fresh read `register_storage_type` finds nothing to register, +and value conversion with `storage_to_python(..., list[UUID])` fails unless `UUID` was +registered manually beforehand. + +**This does NOT affect logical types that are fields of a registered outer dataclass.** +Those are discovered and registered transitively: `register_discovered_extensions` finds +the outer dataclass extension type → `reconstruct_from_arrow` → `register_python_class` +per field annotation → inner type registered. The limitation applies only when the +outermost container (`list[T]`, `dict[K, V]`) is the top-level column type with no outer +dataclass wrapper. + +**Empirically confirmed** (2026-06-17): `pa.array([], type=pa.large_list(extension_type))` +raises `ArrowNotImplementedError: extension` — identical to the ET1 struct-field +restriction. The `replace_logical_type` flag approach (preserving extension type inside +list value field) is therefore infeasible at the PyArrow level. + +**Current behaviour:** `register_python_class(list[T])` raises `ValueError` when `T` +resolves to a logical type, pointing to this entry and PLT-1732. Use a direct `T` column +(no list wrapper) or wrap the list inside a dataclass field — the outer dataclass extension +type carries the annotation into the schema, and `reconstruct_from_arrow` re-registers `T` +transitively on read. + +**Planned fix (PLT-1732, target v0.2):** Introduce `ListLogicalType` / +`ListLogicalTypeFactory` and `StructLogicalType` / `StructLogicalTypeFactory`. A +`list[UUID]` top-level column would be wrapped as a new extension type +`orcapod.list[orcapod.uuid]` with storage `large_list(large_binary)`. The extension type +sits at the outermost (list) level, not inside the list value field, so it satisfies ET1. +`register_storage_type` would dispatch to the new factory on read, auto-registering the +element type. See PLT-1732 for full design. + +--- + +## `src/orcapod/databases/connector_arrow_database.py` + +### CA1 — SQL connectors silently lose Arrow extension-type field metadata on round-trip +**Status:** in progress +**Severity:** high +**Issue:** PLT-1795 + +`SQLiteConnector` (and any `DBConnectorProtocol` implementation that maps Arrow → SQL types) +does not preserve `ARROW:extension:name` / `ARROW:extension:metadata` field metadata. When a +column whose Arrow type is a `pa.ExtensionType` (e.g. `orcapod.path`, `orcapod.uuid`, or any +dataclass extension type) is written via `ConnectorArrowDatabase.add_records()` and then read +back, the column is returned as the raw storage type (e.g. `large_string`, `large_binary`, +`struct`) with no extension marker. This makes SQL connector round-trips impossible and causes silent data-type loss. + +**Interim fix (PLT-1659):** `ConnectorArrowDatabase.add_records()` now raises `ValueError` +immediately when any column is extension-typed, surfacing the issue at write +time rather than on a confusing read. Two representations are rejected: +- In-memory extension types: `isinstance(field.type, pa.ExtensionType)`. +- Metadata-only columns: plain storage type whose field metadata contains + `b"ARROW:extension:name"` (the representation produced when reading a Parquet/IPC file + with an unregistered extension type). + +**Full fix (PLT-1795, target v0.2):** Preserve extension-type metadata in the SQL schema via +a companion metadata table (one row per column: `table_name`, `column_name`, +`extension_name`, `extension_metadata`). On `create_table_if_not_exists`, write rows for any +extension-typed columns; on `iter_batches`, join the metadata table and reconstruct the +`pa.ExtensionType` for affected columns before returning the batch. Once implemented, the +`ValueError` guard in `add_records()` can be lifted. + +--- + ## `src/orcapod/semantic_types/universal_converter.py` ### UC1 — `python_type_to_arrow_type` raised on `typing.Any` from empty-container inference diff --git a/pyproject.toml b/pyproject.toml index 71e4d276..523fc7ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "pandas>=2.2.3", "pyyaml>=6.0.2", "pyarrow>=20.0.0", - "polars>=1.31.0", + "polars>=1.36.0", "beartype>=0.21.0", "deltalake>=1.0.2", "graphviz>=0.21", @@ -27,6 +27,7 @@ dependencies = [ "s3fs>=2025.12.0", "pymongo>=4.15.5", "basedpyright>=1.38.1", + "pydantic>=2.0", ] readme = "README.md" requires-python = ">=3.11.0" diff --git a/src/orcapod/__init__.py b/src/orcapod/__init__.py index 9d30caa7..aff39341 100644 --- a/src/orcapod/__init__.py +++ b/src/orcapod/__init__.py @@ -11,8 +11,6 @@ ) from .core.nodes.source_node import SourceNode from .pipeline import Pipeline, PipelineJob -from .semantic_types.dataclass_encoding import register_dataclass - # Subpackage re-exports for clean public API from . import databases # noqa: F401 from . import nodes # noqa: F401 @@ -21,6 +19,18 @@ from . import streams # noqa: F401 from . import types # noqa: F401 +# Stable type aliases — preferred over importing directly from pathlib/upath/uuid. +# +# These aliases are the recommended way to reference these types in orcapod user code. +# Even if an upstream library is renamed or restructured, these symbols remain stable +# at ``orcapod.Path``, ``orcapod.UPath``, and ``orcapod.UUID``. Their Arrow extension +# types are registered under the ``orcapod.*`` namespace (``"orcapod.path"``, +# ``"orcapod.upath"``, ``"orcapod.uuid"``), so on-disk identity is also decoupled +# from upstream module paths. +from pathlib import Path +from upath import UPath +from uuid import UUID + __all__ = [ "DEFAULT_CONFIG", "DisplayConfig", @@ -32,13 +42,16 @@ "Pipeline", "PipelineJob", "SourceNode", - "register_dataclass", "databases", "nodes", "operators", "sources", "streams", "types", + # Stable type aliases + "Path", + "UPath", + "UUID", ] diff --git a/src/orcapod/contexts/core.py b/src/orcapod/contexts/core.py index cd6b1cf5..d84ae67f 100644 --- a/src/orcapod/contexts/core.py +++ b/src/orcapod/contexts/core.py @@ -1,13 +1,7 @@ -""" -Core data structures and exceptions for the OrcaPod context system. - -This module defines the basic types and exceptions used throughout -the context management system. -""" +"""Core data structures and exceptions for the OrcaPod context system.""" from dataclasses import dataclass -from orcapod.hashing.semantic_hashing.type_handler_registry import TypeHandlerRegistry from orcapod.protocols.hashing_protocols import ( ArrowHasherProtocol, SemanticHasherProtocol, @@ -17,21 +11,18 @@ @dataclass class DataContext: - """ - Data context containing all versioned components needed for data interpretation. - - A DataContext represents a specific version of the OrcaPod system configuration, - including semantic type registries, hashers, and other components that affect - how data is processed and interpreted. + """Data context containing all versioned components needed for data interpretation. Attributes: context_key: Unique identifier (e.g., "std:v0.1:default") version: Version string (e.g., "v0.1") - description: Human-readable description of this context - semantic_type_registry: Registry of semantic type converters + description: Human-readable description + type_converter: Type converter for Python ↔ Arrow conversion and + registration. This is the single public API for all type operations. arrow_hasher: Arrow table hasher for this context - semantic_hasher: General semantic hasher for this context - type_handler_registry: Registry of TypeHandlerProtocol instances for SemanticHasherProtocol + semantic_hasher: General semantic hasher for this context. The + ``PythonTypeHandlerRegistry`` used for hashing is accessible via + ``semantic_hasher.type_handler_registry``. """ context_key: str @@ -39,9 +30,7 @@ class DataContext: description: str type_converter: TypeConverterProtocol arrow_hasher: ArrowHasherProtocol - semantic_hasher: SemanticHasherProtocol # this is the currently the JSON hasher - type_handler_registry: TypeHandlerRegistry - + semantic_hasher: SemanticHasherProtocol class ContextValidationError(Exception): """Raised when context validation fails.""" diff --git a/src/orcapod/contexts/data/schemas/context_schema.json b/src/orcapod/contexts/data/schemas/context_schema.json index b2380124..366ce12f 100644 --- a/src/orcapod/contexts/data/schemas/context_schema.json +++ b/src/orcapod/contexts/data/schemas/context_schema.json @@ -8,11 +8,9 @@ "required": [ "context_key", "version", - "semantic_registry", "type_converter", "arrow_hasher", - "semantic_hasher", - "type_handler_registry" + "semantic_hasher" ], "properties": { "context_key": { @@ -43,10 +41,6 @@ "Enhanced version with timestamp support and improved hashing" ] }, - "semantic_registry": { - "$ref": "#/$defs/objectspec", - "description": "ObjectSpec for the semantic registry" - }, "type_converter": { "$ref": "#/$defs/objectspec", "description": "ObjectSpec for the python-arrow type converter" @@ -59,17 +53,17 @@ "$ref": "#/$defs/objectspec", "description": "ObjectSpec for the semantic hasher component" }, - "type_handler_registry": { + "python_type_handler_registry": { "$ref": "#/$defs/objectspec", - "description": "ObjectSpec for the TypeHandlerRegistry used by the semantic hasher" + "description": "ObjectSpec for the PythonTypeHandlerRegistry used by the semantic hasher" }, "file_hasher": { "$ref": "#/$defs/objectspec", - "description": "ObjectSpec for the file content hasher (used by PathContentHandler)" + "description": "ObjectSpec for the file content hasher (used by PathHandler)" }, - "function_info_extractor": { + "function_semantic_hasher": { "$ref": "#/$defs/objectspec", - "description": "ObjectSpec for the function info extractor (used by FunctionHandler)" + "description": "ObjectSpec for the function semantic hasher (used by FunctionHandler)" }, "metadata": { "type": "object", @@ -169,51 +163,32 @@ { "context_key": "std:v0.1:default", "version": "v0.1", - "description": "Initial stable release with basic Path semantic type support", - "semantic_type_registry": { - "_class": "orcapod.types.semantic_types.SemanticTypeRegistry", - "_config": { - "converters": [ - { - "_class": "orcapod.types.semantic_types.PythonPathStructConverter", - "_config": {} - } - ] - } + "description": "Initial stable release with extension type hashing support", + "type_converter": { + "_class": "orcapod.semantic_types.universal_converter.UniversalTypeConverter", + "_config": {} }, "arrow_hasher": { - "_class": "orcapod.hashing.arrow_hashers.SemanticArrowHasher", + "_class": "orcapod.hashing.arrow_hashers.StarfixArrowHasher", "_config": { "hasher_id": "arrow_v0.1", - "hash_algorithm": "sha256", - "serialization_method": "logical", - "semantic_type_hashers": { - "path": { - "_class": "orcapod.hashing.semantic_type_hashers.PathHasher", - "_config": { - "file_hasher": { - "_class": "orcapod.hashing.file_hashers.BasicFileHasher", - "_config": { - "algorithm": "sha256" - } - } - } - } - } + "type_converter": {"_ref": "type_converter"}, + "semantic_hasher": {"_ref": "semantic_hasher"} } }, "semantic_hasher": { - "_class": "orcapod.hashing.semantic_hashing.semantic_hasher.BaseSemanticHasher", + "_class": "orcapod.hashing.semantic_hashing.semantic_hasher.SemanticAwarePythonHasher", "_config": { - "hasher_id": "semantic_v0.1" + "hasher_id": "semantic_v0.1", + "type_handler_registry": {"_ref": "python_type_handler_registry"} } }, "metadata": { - "created_date": "2025-08-01", + "created_date": "2026-06-24", "author": "OrcaPod Team", "changelog": [ - "Initial release with semantic type registry", - "Basic Arrow and object hashing capabilities" + "Initial release with extension type hashing support", + "StarfixArrowHasher for cross-language-compatible Arrow hashing" ] } } diff --git a/src/orcapod/contexts/data/v0.1.json b/src/orcapod/contexts/data/v0.1.json index 2fb31a70..75da5243 100644 --- a/src/orcapod/contexts/data/v0.1.json +++ b/src/orcapod/contexts/data/v0.1.json @@ -1,95 +1,110 @@ { "context_key": "std:v0.1:default", "version": "v0.1", - "description": "Initial stable release with basic Path semantic type support", + "description": "Initial stable release with extension type hashing support", "file_hasher": { "_class": "orcapod.hashing.file_hashers.BasicFileHasher", "_config": { "algorithm": "sha256" } }, - "semantic_registry": { - "_class": "orcapod.semantic_types.semantic_registry.SemanticTypeRegistry", - "_config": { - "converters": { - "upath": { - "_class": "orcapod.semantic_types.semantic_struct_converters.UPathStructConverter", - "_config": { - "file_hasher": {"_ref": "file_hasher"} - } - }, - "path": { - "_class": "orcapod.semantic_types.semantic_struct_converters.PythonPathStructConverter", - "_config": { - "file_hasher": {"_ref": "file_hasher"} - } - } - } - } - }, - "arrow_hasher": { - "_class": "orcapod.hashing.arrow_hashers.StarfixArrowHasher", - "_config": { - "hasher_id": "arrow_v0.1", - "semantic_registry": { - "_ref": "semantic_registry" - } - } - }, "type_converter": { "_class": "orcapod.semantic_types.universal_converter.UniversalTypeConverter", "_config": { - "semantic_registry": { - "_ref": "semantic_registry" + "logical_type_registry": { + "_class": "orcapod.extension_types.registry.LogicalTypeRegistry", + "_config": { + "logical_types": [ + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalPath", + "_config": {} + }, + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalUPath", + "_config": {} + }, + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalUUID", + "_config": {} + } + ], + "factories": [ + { + "factory": { + "_class": "orcapod.extension_types.dataclass_logical_type_factory.DataclassLogicalTypeFactory", + "_config": {} + }, + "category": "orcapod.dataclass", + "python_bases": [{"_type": "builtins.object"}] + }, + { + "factory": { + "_class": "orcapod.extension_types.pydantic_logical_type_factory.PydanticLogicalTypeFactory", + "_config": {} + }, + "category": "orcapod.pydantic", + "python_bases": [{"_type": "pydantic.BaseModel"}] + } + ] + } } } }, - "function_info_extractor": { + "function_semantic_hasher": { "_class": "orcapod.hashing.semantic_hashing.function_info_extractors.FunctionSignatureExtractor", "_config": { "include_module": true, "include_defaults": true } }, - "type_handler_registry": { - "_class": "orcapod.hashing.semantic_hashing.type_handler_registry.TypeHandlerRegistry", + "python_type_handler_registry": { + "_class": "orcapod.hashing.semantic_hashing.type_handler_registry.PythonTypeHandlerRegistry", "_config": { "handlers": [ - [{"_type": "builtins.bytes"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.BytesHandler", "_config": {}}], - [{"_type": "builtins.bytearray"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.BytesHandler", "_config": {}}], - [{"_type": "pathlib.Path"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.PathContentHandler", "_config": {"file_hasher": {"_ref": "file_hasher"}}}], - [{"_type": "upath.core.UPath"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.UPathContentHandler", "_config": {"file_hasher": {"_ref": "file_hasher"}}}], - [{"_type": "uuid.UUID"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.UUIDHandler", "_config": {}}], - [{"_type": "types.FunctionType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionHandler", "_config": {"function_info_extractor": {"_ref": "function_info_extractor"}}}], - [{"_type": "types.BuiltinFunctionType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionHandler", "_config": {"function_info_extractor": {"_ref": "function_info_extractor"}}}], - [{"_type": "types.MethodType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionHandler", "_config": {"function_info_extractor": {"_ref": "function_info_extractor"}}}], - [{"_type": "builtins.type"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.TypeObjectHandler", "_config": {}}], - [{"_type": "types.GenericAlias"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.GenericAliasHandler", "_config": {}}], - [{"_type": "types.UnionType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.UnionTypeHandler", "_config": {}}], - [{"_type": "typing._GenericAlias"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.GenericAliasHandler", "_config": {}}], - [{"_type": "typing._SpecialForm"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.SpecialFormHandler", "_config": {}}], - [{"_type": "pyarrow.Table"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.ArrowTableHandler", "_config": {"arrow_hasher": {"_ref": "arrow_hasher"}}}], - [{"_type": "pyarrow.RecordBatch"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.ArrowTableHandler", "_config": {"arrow_hasher": {"_ref": "arrow_hasher"}}}] + [{"_type": "builtins.bytes"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.BytesHandler", "_config": {}}], + [{"_type": "builtins.bytearray"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.BytesHandler", "_config": {}}], + [{"_type": "pathlib.Path"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.PathHandler", "_config": {"file_hasher": {"_ref": "file_hasher"}}}], + [{"_type": "upath.core.UPath"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.UPathHandler", "_config": {"file_hasher": {"_ref": "file_hasher"}}}], + [{"_type": "uuid.UUID"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.UUIDHandler", "_config": {}}], + [{"_type": "types.FunctionType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionHandler", "_config": {"function_info_extractor": {"_ref": "function_semantic_hasher"}}}], + [{"_type": "types.BuiltinFunctionType"},{"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionHandler", "_config": {"function_info_extractor": {"_ref": "function_semantic_hasher"}}}], + [{"_type": "types.MethodType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionHandler", "_config": {"function_info_extractor": {"_ref": "function_semantic_hasher"}}}], + [{"_type": "builtins.type"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.TypeObjectHandler", "_config": {}}], + [{"_type": "types.GenericAlias"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.GenericAliasHandler", "_config": {}}], + [{"_type": "types.UnionType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.UnionTypeHandler", "_config": {}}], + [{"_type": "typing._GenericAlias"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.GenericAliasHandler", "_config": {}}], + [{"_type": "typing._SpecialForm"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.SpecialFormHandler", "_config": {}}], + [{"_type": "pyarrow.Table"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.ArrowTableHandler", "_config": {}}], + [{"_type": "pyarrow.RecordBatch"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.ArrowTableHandler", "_config": {}}] ] } }, "semantic_hasher": { - "_class": "orcapod.hashing.semantic_hashing.semantic_hasher.BaseSemanticHasher", + "_class": "orcapod.hashing.semantic_hashing.semantic_hasher.SemanticAwarePythonHasher", "_config": { "hasher_id": "semantic_v0.1", "type_handler_registry": { - "_ref": "type_handler_registry" + "_ref": "python_type_handler_registry" } } }, + "arrow_hasher": { + "_class": "orcapod.hashing.arrow_hashers.StarfixArrowHasher", + "_config": { + "hasher_id": "arrow_v0.1", + "type_converter": {"_ref": "type_converter"}, + "semantic_hasher": {"_ref": "semantic_hasher"} + } + }, "metadata": { - "created_date": "2025-08-01", + "created_date": "2026-06-24", "author": "OrcaPod Core Team", "changelog": [ "Initial release with Path semantic type support", "Basic SHA-256 hashing for files and objects", "Arrow logical serialization method", - "Introduced arrow_v0.1 StarfixArrowHasher using starfix ArrowDigester for cross-language-compatible Arrow hashing" + "Introduced arrow_v0.1 StarfixArrowHasher using starfix ArrowDigester for cross-language-compatible Arrow hashing", + "Hard cut: replaced shape-based SemanticTypeRegistry with extension-type hashing; renamed all hashing classes to cleaner names" ] } } diff --git a/src/orcapod/contexts/registry.py b/src/orcapod/contexts/registry.py index 80182ac3..9607ed34 100644 --- a/src/orcapod/contexts/registry.py +++ b/src/orcapod/contexts/registry.py @@ -151,7 +151,6 @@ def _load_spec_file(self, json_file: Path) -> None: "type_converter", "arrow_hasher", "semantic_hasher", - "type_handler_registry", ] missing_fields = [field for field in required_fields if field not in spec] if missing_fields: @@ -300,7 +299,6 @@ def _create_context_from_spec(self, spec: dict[str, Any]) -> DataContext: type_converter=ref_lut["type_converter"], arrow_hasher=ref_lut["arrow_hasher"], semantic_hasher=ref_lut["semantic_hasher"], - type_handler_registry=ref_lut["type_handler_registry"], ) except Exception as e: diff --git a/src/orcapod/core/function_pod.py b/src/orcapod/core/function_pod.py index ebc25f69..be2fee48 100644 --- a/src/orcapod/core/function_pod.py +++ b/src/orcapod/core/function_pod.py @@ -56,6 +56,7 @@ def _executor_supports_concurrent( return executor is not None and executor.supports_concurrent_execution + class _FunctionPodBase(TraceableBase): """Base pod that applies a data function to each input data.""" @@ -74,6 +75,10 @@ def __init__( ) self.tracker_manager = tracker_manager or DEFAULT_TRACKER_MANAGER self._data_function = data_function + self.data_context.type_converter.ensure_types_registered_for_schemas( + data_function.input_data_schema, + data_function.output_data_schema, + ) def computed_label(self) -> str | None: """Use the data function's canonical name as the default label.""" @@ -651,7 +656,7 @@ def as_table( return output_table -class CallableWithPod(Protocol): +class CallableWithPodProtocol(Protocol): @property def pod(self) -> _FunctionPodBase: """Return the associated function pod.""" @@ -671,7 +676,7 @@ def function_pod( pod_cache_database: ArrowDatabaseProtocol | None = None, executor: DataFunctionExecutorProtocol | None = None, **kwargs, -) -> Callable[..., CallableWithPod]: +) -> Callable[..., CallableWithPodProtocol]: """Decorator that attaches a ``FunctionPod`` as a ``pod`` attribute. Args: @@ -691,7 +696,7 @@ def function_pod( A decorator that adds a ``pod`` attribute to the wrapped function. """ - def decorator(func: Callable) -> CallableWithPod: + def decorator(func: Callable) -> CallableWithPodProtocol: if func.__name__ == "": raise ValueError("Lambda functions cannot be used with function_pod") @@ -731,7 +736,7 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) setattr(wrapper, "pod", pod) - return cast(CallableWithPod, wrapper) + return cast(CallableWithPodProtocol, wrapper) return decorator diff --git a/src/orcapod/databases/__init__.py b/src/orcapod/databases/__init__.py index 8a393dd5..864aecc3 100644 --- a/src/orcapod/databases/__init__.py +++ b/src/orcapod/databases/__init__.py @@ -1,5 +1,6 @@ from .connector_arrow_database import ConnectorArrowDatabase from .delta_lake_databases import DeltaTableDatabase +from .extension_aware_database import ExtensionAwareDatabase from .in_memory_databases import InMemoryArrowDatabase from .noop_database import NoOpArrowDatabase from .spiraldb_connector import SpiralDBConnector @@ -9,6 +10,7 @@ __all__ = [ "ConnectorArrowDatabase", "DeltaTableDatabase", + "ExtensionAwareDatabase", "InMemoryArrowDatabase", "NoOpArrowDatabase", "SpiralDBConnector", diff --git a/src/orcapod/databases/connector_arrow_database.py b/src/orcapod/databases/connector_arrow_database.py index ab6928ed..6e289c5a 100644 --- a/src/orcapod/databases/connector_arrow_database.py +++ b/src/orcapod/databases/connector_arrow_database.py @@ -244,6 +244,34 @@ def add_records( f"got {rid_type}. Encode the column to bytes before calling add_records()." ) + # Reject Arrow extension-typed columns: SQL connectors do not preserve + # ARROW:extension:* field metadata, so extension types would be silently + # dropped on read, making round-trips impossible. Use DeltaTableDatabase + # or write directly to Parquet instead. See PLT-1795 for the planned fix. + # + # Two representations are checked: + # 1. In-memory extension types: isinstance(field.type, pa.ExtensionType). + # 2. Metadata-only extension columns: a plain Arrow type whose field metadata + # contains the b"ARROW:extension:name" key. This arises when reading a + # Parquet/IPC file with an unregistered extension type — the array is + # decoded as its storage type but the metadata is preserved on the field. + _EXT_NAME_KEY = b"ARROW:extension:name" + ext_fields: list[tuple[str, str]] = [] + for field in records.schema: + if isinstance(field.type, pa.ExtensionType): + ext_fields.append((field.name, field.type.extension_name)) + elif field.metadata and _EXT_NAME_KEY in field.metadata: + ext_fields.append((field.name, field.metadata[_EXT_NAME_KEY].decode("utf-8", errors="replace"))) + if ext_fields: + ext_info = ", ".join(f"{name!r}: {ext_name!r}" for name, ext_name in ext_fields) + raise ValueError( + f"ConnectorArrowDatabase does not support Arrow extension-typed columns " + f"({ext_info}). SQL connectors do not preserve ARROW:extension:* field " + f"metadata, so extension types would be silently dropped on read. " + f"Use DeltaTableDatabase or write directly to Parquet instead. " + f"See PLT-1795 for the planned fix." + ) + records = self._deduplicate_within_table(records) record_key = self._get_record_key(record_path) input_ids = set(cast(list[bytes], records[self.RECORD_ID_COLUMN].to_pylist())) diff --git a/src/orcapod/databases/extension_aware_database.py b/src/orcapod/databases/extension_aware_database.py new file mode 100644 index 00000000..93a86bf2 --- /dev/null +++ b/src/orcapod/databases/extension_aware_database.py @@ -0,0 +1,183 @@ +"""ExtensionAwareDatabase — ArrowDatabaseProtocol wrapper that handles extension type registration. + +Wraps any ``ArrowDatabaseProtocol`` backend and transparently applies the +register → cast pattern on every read result: + +1. Call ``register_discovered_extensions(converter, table.schema)`` to ensure + all Arrow extension types found in the returned table's field metadata are + registered with the converter. +2. Call ``converter.apply_extension_types(table)`` to re-wrap columns that + were loaded as plain storage types into their correct extension types. + This operation is zero-copy (``pa.ExtensionArray.from_storage`` per chunk). + +Write operations pass through to the underlying database unchanged. + +Example:: + + db = DeltaTableDatabase("/path/to/store") + ext_db = ExtensionAwareDatabase(db, converter=type_converter) + table = ext_db.get_all_records(("results", "my_fn")) + # table columns have proper extension types applied +""" +from __future__ import annotations + +from collections.abc import Collection, Mapping +from typing import TYPE_CHECKING, Any + +from orcapod.extension_types.database_hooks import register_discovered_extensions +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol + +if TYPE_CHECKING: + import pyarrow as pa + from orcapod.extension_types.protocols import TypeConverterProtocol + + +class ExtensionAwareDatabase: + """``ArrowDatabaseProtocol`` wrapper that auto-registers and applies extension types. + + All read methods delegate to the wrapped *db*, then: + + 1. Walk the returned table's schema to find any extension types (from + preserved ``ARROW:extension:*`` field metadata). + 2. Register any newly discovered types with *converter* via + ``register_discovered_extensions``. + 3. Re-wrap columns that were loaded as plain storage types into their + correct Arrow extension types via ``converter.apply_extension_types`` + (zero-copy). + + Write methods and ``flush`` delegate directly without modification. + + Args: + db: Any ``ArrowDatabaseProtocol`` backend. + converter: The ``TypeConverterProtocol`` to use for registration and + lookup. + """ + + def __init__( + self, + db: ArrowDatabaseProtocol, + converter: TypeConverterProtocol, + ) -> None: + self._db = db + self._converter = converter + + # ── Internal helper ─────────────────────────────────────────────────────── + + def _process(self, table: pa.Table | None) -> pa.Table | None: + """Register extension types and re-wrap columns, or return None unchanged.""" + if table is None: + return None + register_discovered_extensions(self._converter, table.schema) + return self._converter.apply_extension_types(table) + + # ── Read methods ────────────────────────────────────────────────────────── + + def get_record_by_id( + self, + record_path: tuple[str, ...], + record_id: bytes, + record_id_column: str | None = None, + flush: bool = False, + ) -> pa.Table | None: + return self._process( + self._db.get_record_by_id( + record_path, + record_id, + record_id_column=record_id_column, + flush=flush, + ) + ) + + def get_all_records( + self, + record_path: tuple[str, ...], + record_id_column: str | None = None, + ) -> pa.Table | None: + return self._process( + self._db.get_all_records(record_path, record_id_column=record_id_column) + ) + + def get_records_by_ids( + self, + record_path: tuple[str, ...], + record_ids: Collection[bytes], + record_id_column: str | None = None, + flush: bool = False, + ) -> pa.Table | None: + return self._process( + self._db.get_records_by_ids( + record_path, + record_ids, + record_id_column=record_id_column, + flush=flush, + ) + ) + + def get_records_with_column_value( + self, + record_path: tuple[str, ...], + column_values: Collection[tuple[str, Any]] | Mapping[str, Any], + record_id_column: str | None = None, + flush: bool = False, + ) -> pa.Table | None: + return self._process( + self._db.get_records_with_column_value( + record_path, + column_values, + record_id_column=record_id_column, + flush=flush, + ) + ) + + # ── Write methods (pass-through) ────────────────────────────────────────── + + def add_record( + self, + record_path: tuple[str, ...], + record_id: bytes, + record: pa.Table, + skip_duplicates: bool = False, + flush: bool = False, + ) -> None: + self._db.add_record( + record_path, + record_id, + record, + skip_duplicates=skip_duplicates, + flush=flush, + ) + + def add_records( + self, + record_path: tuple[str, ...], + records: pa.Table, + record_id_column: str | None = None, + skip_duplicates: bool = False, + flush: bool = False, + ) -> None: + self._db.add_records( + record_path, + records, + record_id_column=record_id_column, + skip_duplicates=skip_duplicates, + flush=flush, + ) + + def flush(self) -> None: + self._db.flush() + + # ── Structural delegation ───────────────────────────────────────────────── + + @property + def base_path(self) -> tuple[str, ...]: + return self._db.base_path + + def at(self, *path_components: str) -> ExtensionAwareDatabase: + """Return a scoped view, preserving the extension-aware wrapper.""" + return ExtensionAwareDatabase( + self._db.at(*path_components), + converter=self._converter, + ) + + def to_config(self) -> dict[str, Any]: + return self._db.to_config() diff --git a/src/orcapod/extension_types/__init__.py b/src/orcapod/extension_types/__init__.py index e69de29b..bdbd4c29 100644 --- a/src/orcapod/extension_types/__init__.py +++ b/src/orcapod/extension_types/__init__.py @@ -0,0 +1,50 @@ +"""Arrow/Polars extension type system for orcapod. + +This subpackage provides the registry and protocol for logical types that map +between Python objects and their Arrow/Polars extension type representation. + +Built-in registrations (``LogicalPath``, ``LogicalUPath``, ``LogicalUUID``) are +wired into ``DataContext`` via ``contexts/data/v0.1.json``. Use +``get_default_context().type_converter.register_python_class()`` to register new +types, ``register_logical_type_factory()`` to add factories, and +``apply_extension_types()`` to re-wrap Arrow tables with their registered extension types. + +``DataclassLogicalTypeFactory`` provides automatic registration for Python dataclasses: +register it with a ``LogicalTypeRegistry`` and any dataclass used in a ``FunctionPod`` +will be auto-registered on pod declaration. + +``PydanticLogicalTypeFactory`` provides automatic registration for pydantic v2 +``BaseModel`` subclasses. Requires the optional ``pydantic`` extra. +""" + +from __future__ import annotations + +from .protocols import LogicalTypeProtocol, LogicalTypeFactoryProtocol +from .registry import LogicalTypeRegistry, make_arrow_extension_type, make_polars_extension_type +from .schema_walker import ExtensionTypeInfo, walk_field, walk_schema +from .database_hooks import apply_extension_types, register_discovered_extensions +from .dataclass_logical_type_factory import DATACLASS_CATEGORY, DataclassLogicalType, DataclassLogicalTypeFactory +from .pydantic_logical_type_factory import PYDANTIC_CATEGORY, PydanticLogicalType, PydanticLogicalTypeFactory + +__all__ = [ + "LogicalTypeProtocol", + "LogicalTypeFactoryProtocol", + "LogicalTypeRegistry", + "make_arrow_extension_type", + "make_polars_extension_type", + # PLT-1654 + "ExtensionTypeInfo", + "walk_schema", + "walk_field", + # PLT-1655 + "register_discovered_extensions", + "apply_extension_types", + # PLT-1705 + "DATACLASS_CATEGORY", + "DataclassLogicalType", + "DataclassLogicalTypeFactory", + # PLT-1731 + "PYDANTIC_CATEGORY", + "PydanticLogicalType", + "PydanticLogicalTypeFactory", +] diff --git a/src/orcapod/extension_types/builtin_logical_types.py b/src/orcapod/extension_types/builtin_logical_types.py new file mode 100644 index 00000000..7bfc23e9 --- /dev/null +++ b/src/orcapod/extension_types/builtin_logical_types.py @@ -0,0 +1,263 @@ +"""Built-in LogicalType implementations for orcapod. + +Provides three built-in logical types registered into the default +``DataContext.logical_type_registry`` via ``contexts/data/v0.1.json``: + +- ``LogicalPath``: maps ``pathlib.Path`` ↔ Arrow large_string extension ``"orcapod.path"`` +- ``LogicalUPath``: maps ``upath.UPath`` ↔ Arrow large_string extension ``"orcapod.upath"`` +- ``LogicalUUID``: maps ``uuid.UUID`` ↔ Arrow large_binary extension ``"orcapod.uuid"`` + +All three types use the ``orcapod.*`` extension name namespace rather than the upstream +module-qualified names (``"pathlib.Path"``, etc.). This gives Orcapod stable ownership of +the on-disk extension identity: even if the upstream library is renamed or restructured, +data written with these extension names continues to be readable without modification. + +Note: + All imports from orcapod.extension_types use direct submodule paths + (e.g. ``from orcapod.extension_types.registry import ...``) rather than + the package ``__init__`` to avoid circular imports when the context system + loads this module at startup. +""" + +from __future__ import annotations + +import pathlib +import uuid as _uuid_module +from typing import TYPE_CHECKING, Any + +import polars as pl +import pyarrow as pa +from upath import UPath + +from orcapod.extension_types.registry import make_arrow_extension_type, make_polars_extension_type + +if TYPE_CHECKING: + from orcapod.extension_types.protocols import TypeConverterProtocol + + +class LogicalPath: + """Logical type for ``pathlib.Path``. + + Stores paths as Arrow large strings using the custom extension type + ``"orcapod.path"``. + + The extension name ``"orcapod.path"`` is Orcapod-owned and stable; it does not + depend on the upstream ``pathlib`` module path. Use ``orcapod.Path`` (a top-level + alias for ``pathlib.Path``) as the preferred way to reference this type in user code. + + Example: + >>> lt = LogicalPath() + >>> lt.python_to_storage(pathlib.Path("/tmp/foo")) + '/tmp/foo' + >>> lt.storage_to_python('/tmp/foo') + PosixPath('/tmp/foo') + """ + + _arrow_ext_class = make_arrow_extension_type("orcapod.path", pa.large_string()) + _arrow_ext: pa.ExtensionType | None = None + _polars_ext_class = make_polars_extension_type("orcapod.path", pa.large_string()) + _polars_ext: pl.BaseExtension | None = None + + logical_type_name: str = "orcapod.path" + python_type: type = pathlib.Path + + def get_arrow_extension_type(self) -> pa.ExtensionType: + """Return the Arrow extension type for ``pathlib.Path``. + + Returns: + A cached ``pa.ExtensionType`` instance with extension name + ``"orcapod.path"`` and storage type ``pa.large_string()``. + """ + if LogicalPath._arrow_ext is None: + LogicalPath._arrow_ext = LogicalPath._arrow_ext_class() + return LogicalPath._arrow_ext + + def get_polars_extension_type(self) -> pl.BaseExtension: + """Return the Polars extension type for ``pathlib.Path``. + + Returns: + A cached ``pl.BaseExtension`` instance registered under + ``"orcapod.path"``. + """ + if LogicalPath._polars_ext is None: + LogicalPath._polars_ext = LogicalPath._polars_ext_class() + return LogicalPath._polars_ext + + def python_to_storage(self, value: Any, converter: TypeConverterProtocol | None = None) -> str: + """Convert a ``pathlib.Path`` to its string representation. + + Args: + value: A ``pathlib.Path`` instance. + converter: Ignored. Present for protocol conformance. + + Returns: + The string form of the path (e.g. ``"/tmp/foo"``). + """ + return str(value) + + def storage_to_python(self, storage_value: Any, converter: TypeConverterProtocol | None = None) -> pathlib.Path: + """Reconstruct a ``pathlib.Path`` from its string representation. + + Args: + storage_value: A string path as stored in Arrow. + converter: Ignored. Present for protocol conformance. + + Returns: + A ``pathlib.Path`` instance. + """ + return pathlib.Path(storage_value) + + +class LogicalUPath: + """Logical type for ``upath.UPath``. + + Stores paths as Arrow large strings using the custom extension type + ``"orcapod.upath"``. + + The extension name ``"orcapod.upath"`` is Orcapod-owned and stable; it does not + depend on the upstream ``upath`` module path. Use ``orcapod.UPath`` (a top-level + alias for ``upath.UPath``) as the preferred way to reference this type in user code. + + Example: + >>> lt = LogicalUPath() + >>> lt.python_to_storage(UPath("s3://bucket/key")) + 's3://bucket/key' + >>> lt.storage_to_python("s3://bucket/key") + UPath('s3://bucket/key') + """ + + _arrow_ext_class = make_arrow_extension_type("orcapod.upath", pa.large_string()) + _arrow_ext: pa.ExtensionType | None = None + _polars_ext_class = make_polars_extension_type("orcapod.upath", pa.large_string()) + _polars_ext: pl.BaseExtension | None = None + + logical_type_name: str = "orcapod.upath" + python_type: type = UPath + + def get_arrow_extension_type(self) -> pa.ExtensionType: + """Return the Arrow extension type for ``upath.UPath``. + + Returns: + A cached ``pa.ExtensionType`` instance with extension name + ``"orcapod.upath"`` and storage type ``pa.large_string()``. + """ + if LogicalUPath._arrow_ext is None: + LogicalUPath._arrow_ext = LogicalUPath._arrow_ext_class() + return LogicalUPath._arrow_ext + + def get_polars_extension_type(self) -> pl.BaseExtension: + """Return the Polars extension type for ``upath.UPath``. + + Returns: + A cached ``pl.BaseExtension`` instance registered under + ``"orcapod.upath"``. + """ + if LogicalUPath._polars_ext is None: + LogicalUPath._polars_ext = LogicalUPath._polars_ext_class() + return LogicalUPath._polars_ext + + def python_to_storage(self, value: Any, converter: TypeConverterProtocol | None = None) -> str: + """Convert a ``upath.UPath`` to its string representation. + + Args: + value: A ``upath.UPath`` instance. + converter: Ignored. Present for protocol conformance. + + Returns: + The string form of the path (e.g. ``"s3://bucket/key"``). + """ + return str(value) + + def storage_to_python(self, storage_value: Any, converter: TypeConverterProtocol | None = None) -> UPath: + """Reconstruct a ``upath.UPath`` from its string representation. + + Args: + storage_value: A string path as stored in Arrow. + converter: Ignored. Present for protocol conformance. + + Returns: + A ``upath.UPath`` instance. + """ + return UPath(storage_value) + + +class LogicalUUID: + """Logical type for ``uuid.UUID``. + + Stores UUIDs as Arrow binary (16 bytes) using the custom extension type + ``"orcapod.uuid"``. Both the Arrow extension name and ``logical_type_name`` + are ``"orcapod.uuid"``, consistent with ``LogicalPath`` and ``LogicalUPath``. + + The extension name ``"orcapod.uuid"`` is Orcapod-owned and stable, replacing + the previous ``"uuid.UUID"`` name that mirrored PyArrow's ``"arrow.uuid"`` + territory. Use ``orcapod.UUID`` (a top-level alias for ``uuid.UUID``) as the + preferred way to reference this type in user code. + + The storage type is ``pa.large_binary()`` (variable-length binary), using + big-endian byte order as returned by ``uuid.UUID.bytes``. ``large_binary`` + is used rather than ``pa.binary(16)`` (fixed-size) because Polars maps + fixed-size binary to variable-length on the round-trip, which would + conflict with the deserializer's storage type check. + + Example: + >>> import uuid + >>> lt = LogicalUUID() + >>> u = uuid.uuid4() + >>> lt.storage_to_python(lt.python_to_storage(u)) == u + True + """ + + _arrow_ext_class = make_arrow_extension_type("orcapod.uuid", pa.large_binary()) + _arrow_ext: pa.ExtensionType | None = None + _polars_ext_class = make_polars_extension_type("orcapod.uuid", pa.large_binary()) + _polars_ext: pl.BaseExtension | None = None + + logical_type_name: str = "orcapod.uuid" + python_type: type = _uuid_module.UUID + + def get_arrow_extension_type(self) -> pa.ExtensionType: + """Return the Arrow extension type for ``uuid.UUID``. + + Returns: + A cached ``pa.ExtensionType`` instance with extension name + ``"orcapod.uuid"`` and storage type ``pa.large_binary()``. + """ + if LogicalUUID._arrow_ext is None: + LogicalUUID._arrow_ext = LogicalUUID._arrow_ext_class() + return LogicalUUID._arrow_ext + + def get_polars_extension_type(self) -> pl.BaseExtension: + """Return the Polars extension type for ``uuid.UUID``. + + Returns: + A cached ``pl.BaseExtension`` instance registered under + ``"orcapod.uuid"``. + """ + if LogicalUUID._polars_ext is None: + LogicalUUID._polars_ext = LogicalUUID._polars_ext_class() + return LogicalUUID._polars_ext + + def python_to_storage(self, value: Any, converter: TypeConverterProtocol | None = None) -> bytes: + """Convert a ``uuid.UUID`` to its 16-byte binary representation. + + Args: + value: A ``uuid.UUID`` instance. + converter: Ignored. Present for protocol conformance. + + Returns: + A 16-byte ``bytes`` object (big-endian byte order, as per + ``uuid.UUID.bytes``). + """ + return value.bytes + + def storage_to_python(self, storage_value: Any, converter: TypeConverterProtocol | None = None) -> _uuid_module.UUID: + """Reconstruct a ``uuid.UUID`` from its 16-byte binary representation. + + Args: + storage_value: A bytes-like object of length 16. + converter: Ignored. Present for protocol conformance. + + Returns: + A ``uuid.UUID`` instance. + """ + return _uuid_module.UUID(bytes=bytes(storage_value)) diff --git a/src/orcapod/extension_types/database_hooks.py b/src/orcapod/extension_types/database_hooks.py new file mode 100644 index 00000000..95abb65f --- /dev/null +++ b/src/orcapod/extension_types/database_hooks.py @@ -0,0 +1,244 @@ +"""Schema-walking utilities for extension type auto-registration and post-load casting. + +Two entry points: + +``register_discovered_extensions(converter, schema)`` + Walk an Arrow schema and register any extension types not yet known to + *converter*. No-op when *converter* is ``None`` or the schema has no + extension types. + +``apply_extension_types(table, registry)`` + Re-wrap columns of *table* that carry ``ARROW:extension:*`` field metadata + into their registered extension types. Operates per-chunk so no data is + copied — each chunk is wrapped with ``pa.ExtensionArray.from_storage()``. + Nested struct fields are reconstructed recursively. + +These two functions are typically called in sequence via ``UniversalTypeConverter``: + + register_discovered_extensions(converter, table.schema) + table = converter.apply_extension_types(table) +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from orcapod.extension_types.registry import LogicalTypeRegistry +from orcapod.extension_types.schema_walker import walk_schema + +if TYPE_CHECKING: + import pyarrow as pa + from orcapod.extension_types.protocols import TypeConverterProtocol + +logger = logging.getLogger(__name__) + + +def register_discovered_extensions( + converter: "TypeConverterProtocol | None", + schema: "pa.Schema", +) -> None: + """Register any extension types found in ``schema`` that are not yet known. + + Walks ``schema`` recursively via ``walk_schema`` to discover all Arrow extension + types at any nesting depth (both in-memory and field-metadata channels). + For each discovered type, delegates to ``converter.register_arrow_extension``. + + Already-registered types are detected and skipped inside the converter — + this function itself is stateless beyond the converter it operates on. + + Args: + converter: The ``TypeConverterProtocol`` to use for registration. + If ``None``, this call is a no-op. + schema: The Arrow schema to inspect. May contain no extension types, + in which case this call is a no-op. + + Raises: + ValueError: Propagated from the converter if an extension type's metadata + has no registered factory or is malformed. + """ + if converter is None: + logger.debug("register_discovered_extensions: no converter provided, skipping") + return + + found = walk_schema(schema) + if not found: + logger.debug("register_discovered_extensions: no extension types in schema") + return + logger.debug( + "register_discovered_extensions: found %d extension type(s) in schema: %s", + len(found), + [info.extension_name for info in found], + ) + for info in found: + # Bottom-up resolve the storage type first, then register the extension + resolved_storage = converter.register_storage_type(info.storage_type) + converter.register_arrow_extension( + info.extension_name, + info.extension_metadata, + resolved_storage, + ) + + +def apply_extension_types( + table: pa.Table, + registry: LogicalTypeRegistry, +) -> pa.Table: + """Re-wrap *table* columns into their registered Arrow extension types. + + Arrow preserves ``ARROW:extension:name`` / ``ARROW:extension:metadata`` + field metadata even when an extension type was not registered at read + time, in which case the column is stored as a plain storage type (e.g. + ``large_utf8``). Once the extension type has been registered (via + ``register_discovered_extensions``), this function reconstructs the + correct extension-typed columns using ``pa.ExtensionArray.from_storage``. + + The operation is zero-copy per chunk: each chunk in a ``ChunkedArray`` + is individually wrapped without rechunking or data movement. Struct + columns are handled recursively so nested extension type fields are also + reconstructed. + + Columns whose field has no ``ARROW:extension:name`` metadata (plain Arrow + types) are left untouched. + + Args: + table: Arrow table whose columns may contain extension type metadata + but were loaded as storage types. + registry: Registry that holds the registered ``LogicalTypeProtocol`` + instances. Must already contain every extension type referenced + by ``table.schema`` — call ``register_discovered_extensions`` + first. + + Returns: + A new ``pa.Table`` with extension-typed columns re-wrapped. Columns + with no extension type metadata are shared with *table* unchanged. + """ + import pyarrow as pa + + new_columns: list[pa.ChunkedArray] = [] + new_fields: list[pa.Field] = [] + changed = False + + for i, field in enumerate(table.schema): + col = table.column(i) + new_col, new_field = _apply_field(col, field, registry) + new_columns.append(new_col) + new_fields.append(new_field) + if new_field is not field: + changed = True + + if not changed: + return table + + # Preserve any schema-level metadata (e.g. pandas metadata) from the original. + new_schema = pa.schema(new_fields, metadata=table.schema.metadata) + return pa.table(dict(zip(new_schema.names, new_columns)), schema=new_schema) + + +def _apply_field( + col: pa.ChunkedArray, + field: pa.Field, + registry: LogicalTypeRegistry, +) -> tuple[pa.ChunkedArray, pa.Field]: + """Return *(new_col, new_field)* with extension type applied if needed. + + Handles three cases: + - Field already has an extension type → return as-is. + - Field has extension metadata and a registered type → wrap per-chunk. + - Field is a struct with extension-typed children → recurse. + """ + import pyarrow as pa + + field_meta = field.metadata or {} + ext_name_bytes = field_meta.get(b"ARROW:extension:name") + + # ── Case 1: field is already an extension type (registered at read time) ── + if hasattr(field.type, "extension_name"): + return col, field + + # ── Case 2: field has extension metadata and a matching registered type ─── + if ext_name_bytes is not None: + ext_name = ext_name_bytes.decode("utf-8") + lt = registry.get_by_arrow_extension_name(ext_name) + if lt is not None: + ext_type = lt.get_arrow_extension_type() + wrapped_chunks = [ + pa.ExtensionArray.from_storage(ext_type, chunk) + for chunk in col.chunks + ] + new_col = pa.chunked_array(wrapped_chunks, type=ext_type) + new_field = field.with_type(ext_type) + logger.debug("apply_extension_types: wrapped column %r as %r", field.name, ext_name) + return new_col, new_field + + # ── Case 3: struct — recurse only if children carry extension metadata ────── + if pa.types.is_struct(field.type): + if _has_nested_extension_fields(field.type): + return _apply_struct_field(col, field, registry) + + return col, field + + +def _has_nested_extension_fields(arrow_type: pa.DataType) -> bool: + """Return True if any child field at any nesting depth carries extension metadata. + + Used to guard struct recursion: structs whose children carry no + ``ARROW:extension:name`` metadata are returned as-is without rebuilding. + """ + import pyarrow as pa + + for i in range(arrow_type.num_fields): + child = arrow_type.field(i) + meta = child.metadata or {} + if b"ARROW:extension:name" in meta: + return True + if pa.types.is_struct(child.type) and _has_nested_extension_fields(child.type): + return True + return False + + +def _apply_struct_field( + col: pa.ChunkedArray, + field: pa.Field, + registry: LogicalTypeRegistry, +) -> tuple[pa.ChunkedArray, pa.Field]: + """Recursively apply extension types to children of a struct column.""" + import pyarrow as pa + + struct_type = field.type + child_fields = [struct_type.field(i) for i in range(struct_type.num_fields)] + + # Process each chunk: rebuild StructArray with re-wrapped children. + new_chunks: list[pa.StructArray] = [] + new_child_fields: list[pa.Field] | None = None + + for chunk in col.chunks: + new_child_arrays: list[pa.Array] = [] + resolved_fields: list[pa.Field] = [] + + for child_field in child_fields: + child_arr = chunk.field(child_field.name) + # Wrap child array into a single-chunk ChunkedArray for _apply_field. + child_chunked = pa.chunked_array([child_arr], type=child_arr.type) + new_child_chunked, new_child_field = _apply_field( + child_chunked, child_field, registry + ) + # from_storage produces a non-chunked Array; use combine_chunks for single chunk. + new_child_arrays.append(new_child_chunked.combine_chunks()) + resolved_fields.append(new_child_field) + + # Preserve the original null bitmap so struct-level nulls survive wrapping. + # StructArray.from_arrays() defaults to all-valid without an explicit mask. + null_mask = chunk.is_null() if chunk.null_count > 0 else None + new_struct = pa.StructArray.from_arrays( + new_child_arrays, fields=resolved_fields, mask=null_mask + ) + new_chunks.append(new_struct) + if new_child_fields is None: + new_child_fields = resolved_fields + + assert new_child_fields is not None # col.chunks is non-empty if we reach here + new_struct_type = pa.struct(new_child_fields) + new_field = field.with_type(new_struct_type) + new_col = pa.chunked_array(new_chunks, type=new_struct_type) + return new_col, new_field diff --git a/src/orcapod/extension_types/dataclass_logical_type_factory.py b/src/orcapod/extension_types/dataclass_logical_type_factory.py new file mode 100644 index 00000000..5633ffd7 --- /dev/null +++ b/src/orcapod/extension_types/dataclass_logical_type_factory.py @@ -0,0 +1,363 @@ +"""DataclassLogicalType and DataclassLogicalTypeFactory. + +Provides the ``DataclassLogicalType`` logical type implementation and the +``DataclassLogicalTypeFactory`` that synthesises and reconstructs ``DataclassLogicalType`` +instances for Python dataclasses. + +Write path (``create_for_python_type``): + Iterates dataclass fields, delegates field Arrow-type resolution to the converter + via ``register_python_class``, and returns a ``DataclassLogicalType`` backed by + a ``pa.struct`` extension type. + +Read path (``reconstruct_from_arrow``): + Imports the dataclass by fully-qualified class name, resolves field annotations + against the (already bottom-up resolved) storage type, and returns a + ``DataclassLogicalType``. + +Category tag: ``"orcapod.dataclass"`` +""" + +from __future__ import annotations + +import dataclasses +import json +import logging +from typing import TYPE_CHECKING, Any + +from orcapod.extension_types.registry import make_arrow_extension_type, make_polars_extension_type +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import polars as pl + import pyarrow as pa + from orcapod.extension_types.protocols import TypeConverterProtocol +else: + pa = LazyModule("pyarrow") + pl = LazyModule("polars") + +logger = logging.getLogger(__name__) + +#: Category tag embedded in Arrow extension metadata. Used as the factory dispatch key. +DATACLASS_CATEGORY = "orcapod.dataclass" + + +class DataclassLogicalType: + """Logical type binding a Python dataclass to its Arrow extension type representation. + + Stores the dataclass's fully-qualified class name as the Arrow extension name + and a ``pa.struct`` of the dataclass fields as the storage type. + + No Arrow-type reasoning lives here — all field-type resolution is owned by the + converter and completed before this object is constructed. + + Args: + logical_name: Fully-qualified class name (e.g. ``"mymodule.sub.MyData"``). + Used as both the logical type name and the Arrow extension name. + python_type: The Python dataclass ``type`` object. + storage_type: The Arrow ``pa.StructType`` for the dataclass fields. + field_annotations: Ordered list of ``(field_name, python_annotation)`` pairs + matching the fields in ``storage_type``. + + Example: + >>> lt = DataclassLogicalType( + ... "mymod.Point", Point, + ... pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]), + ... [("x", int), ("y", int)], + ... ) + >>> lt.python_to_storage(Point(1, 2), converter) + {"x": 1, "y": 2} + """ + + def __init__( + self, + logical_name: str, + python_type: type, + storage_type: pa.StructType, + field_annotations: list[tuple[str, Any]], + ) -> None: + self._logical_name = logical_name + self._python_type = python_type + self._storage_type = storage_type + self._field_annotations = field_annotations + + _metadata = json.dumps({"category": DATACLASS_CATEGORY}).encode("utf-8") + self._arrow_ext_class = make_arrow_extension_type( + logical_name, storage_type, metadata=_metadata + ) + self._arrow_ext: pa.ExtensionType | None = None + # ``storage_type`` must not contain nested extension types (ET1 in DESIGN_ISSUES.md). + # On the write path, ``DataclassLogicalTypeFactory.create_for_python_type`` strips any + # top-level extension type from each field's Arrow type before inserting it into the + # struct. On the read path, ``reconstruct_from_arrow`` receives a ``storage_type`` + # already guaranteed storage-safe by ``register_storage_type``. + self._polars_ext_class = make_polars_extension_type(logical_name, storage_type) + self._polars_ext: pl.BaseExtension | None = None + + @property + def logical_type_name(self) -> str: + """Fully-qualified class name used as the logical type identifier.""" + return self._logical_name + + @property + def python_type(self) -> type: + """The Python dataclass type this logical type represents.""" + return self._python_type + + def get_arrow_extension_type(self) -> pa.ExtensionType: + """Return the Arrow extension type for this dataclass. + + Returns: + A cached ``pa.ExtensionType`` instance with ``extension_name`` equal to + the fully-qualified class name and ``storage_type`` equal to the struct + of the dataclass fields. + """ + if self._arrow_ext is None: + self._arrow_ext = self._arrow_ext_class() + return self._arrow_ext + + def get_polars_extension_type(self) -> pl.BaseExtension: + """Return the Polars extension type for this dataclass. + + Returns: + A cached ``pl.BaseExtension`` instance. + """ + if self._polars_ext is None: + self._polars_ext = self._polars_ext_class() + return self._polars_ext + + def python_to_storage(self, value: Any, converter: TypeConverterProtocol | None) -> dict[str, Any]: + """Convert a dataclass instance to an Arrow-compatible struct dict. + + Iterates ``_field_annotations`` and delegates each field's conversion to + ``converter.python_to_storage``. + + Args: + value: A dataclass instance of type ``python_type``. + converter: The active converter for per-field delegation. Must not be ``None``. + + Returns: + A dict mapping field names to their Arrow storage values. + + Raises: + ValueError: If ``converter`` is ``None``. + """ + if converter is None: + raise ValueError( + "DataclassLogicalType.python_to_storage requires a converter — " + "pass a TypeConverterProtocol instance for field-level conversion." + ) + return { + name: converter.python_to_storage(getattr(value, name), annotation) + for name, annotation in self._field_annotations + } + + def storage_to_python(self, storage_value: Any, converter: TypeConverterProtocol | None) -> Any: + """Reconstruct a dataclass instance from an Arrow struct dict. + + Args: + storage_value: A dict mapping field names to Arrow storage values. + converter: The active converter for per-field delegation. Must not be ``None``. + + Returns: + A dataclass instance of type ``python_type``. + + Raises: + ValueError: If ``converter`` is ``None``. + """ + if converter is None: + raise ValueError( + "DataclassLogicalType.storage_to_python requires a converter — " + "pass a TypeConverterProtocol instance for field-level conversion." + ) + kwargs = { + name: converter.storage_to_python(storage_value[name], annotation) + for name, annotation in self._field_annotations + } + return self._python_type(**kwargs) + + +class DataclassLogicalTypeFactory: + """Stateless factory that synthesises and reconstructs ``DataclassLogicalType`` instances. + + **Write path** (``create_for_python_type``): derives Arrow struct type from the + dataclass fields by delegating to ``converter.register_python_class`` per field. + + **Read path** (``reconstruct_from_arrow``): imports the dataclass by FQCN, matches + fields against the already-resolved ``storage_type``, and returns a + ``DataclassLogicalType``. + + Category tag: ``"orcapod.dataclass"`` + + Register with:: + + converter.register_logical_type_factory( + DataclassLogicalTypeFactory(), + category="orcapod.dataclass", + python_bases=[object], + ) + + Example: + >>> factory = DataclassLogicalTypeFactory() + >>> factory.supports_class(MyDataclass) + True + >>> factory.supports_class(str) + False + """ + + def supports_class(self, python_type: type) -> bool: + """Return True if ``python_type`` is a dataclass. + + Args: + python_type: Any Python type. + + Returns: + True if ``dataclasses.is_dataclass(python_type)`` is True. + """ + return dataclasses.is_dataclass(python_type) and isinstance(python_type, type) + + def create_for_python_type( + self, + python_type: type, + converter: TypeConverterProtocol, + ) -> DataclassLogicalType: + """Synthesise a ``DataclassLogicalType`` for a Python dataclass (write path). + + Derives the FQCN, obtains type hints, and resolves each field's Arrow type + via ``converter.register_python_class``. Rejects local / unnamed classes. + + Args: + python_type: A Python dataclass type. + converter: The active converter for field-type resolution. + + Returns: + A ``DataclassLogicalType`` ready for registration. + + Raises: + ValueError: If ``python_type`` is a local class (``__qualname__`` contains + ``""``). + """ + import typing + + fqcn = f"{python_type.__module__}.{python_type.__qualname__}" + if "" in fqcn: + raise ValueError( + f"Cannot register local class {python_type!r} as a DataclassLogicalType — " + f"local classes have no stable fully-qualified class name and cannot be " + f"reconstructed on read. Define the dataclass at module level." + ) + + try: + hints = typing.get_type_hints(python_type) + except Exception as exc: + raise ValueError( + f"Cannot get type hints for {python_type!r}: {exc}" + ) from exc + + arrow_fields = [] + field_annotations = [] + for field in dataclasses.fields(python_type): + if not field.init: + continue + annotation = hints.get(field.name, Any) + arrow_type = converter.register_python_class(annotation) + # register_python_class returns a storage-safe type: may be extension at the + # top level, but struct fields are always plain. Strip the top-level extension + # type here before inserting into the struct (ET1; see DESIGN_ISSUES.md). + if isinstance(arrow_type, pa.ExtensionType): + arrow_type = arrow_type.storage_type + arrow_fields.append(pa.field(field.name, arrow_type)) + field_annotations.append((field.name, annotation)) + + storage_type = pa.struct(arrow_fields) + logger.debug("DataclassLogicalTypeFactory: synthesised %r for %r", fqcn, python_type) + return DataclassLogicalType(fqcn, python_type, storage_type, field_annotations) + + def reconstruct_from_arrow( + self, + arrow_extension_name: str, + storage_type: pa.DataType, + metadata: dict[str, Any], + converter: TypeConverterProtocol, + ) -> DataclassLogicalType: + """Reconstruct a ``DataclassLogicalType`` from Arrow schema metadata (read path). + + Imports the dataclass from its FQCN (``arrow_extension_name``), then matches + the dataclass field annotations against the fields in ``storage_type``. + ``storage_type`` is already bottom-up resolved by ``register_storage_type`` + before this method is called. + + Args: + arrow_extension_name: FQCN of the dataclass (Arrow extension name). + storage_type: Already-resolved ``pa.StructType`` for the dataclass fields. + metadata: Full parsed metadata JSON dict (always contains ``"category"``). + converter: The active converter (not needed here but required by protocol). + + Returns: + A ``DataclassLogicalType`` ready for registration. + + Raises: + ImportError: If the class cannot be imported from ``arrow_extension_name``. + ValueError: If ``storage_type`` is not a struct type. + """ + import typing + + if not pa.types.is_struct(storage_type): + raise ValueError( + f"DataclassLogicalTypeFactory.reconstruct_from_arrow: expected a struct " + f"storage type for {arrow_extension_name!r}, got {storage_type!r}." + ) + + # Import class from FQCN using longest-prefix module walk + cls = _import_from_fqcn(arrow_extension_name) + + try: + hints = typing.get_type_hints(cls) + except Exception as exc: + raise ValueError( + f"Cannot get type hints for {cls!r}: {exc}" + ) from exc + + field_annotations = [] + for field in dataclasses.fields(cls): + if not field.init: + continue + annotation = hints.get(field.name, Any) + # Register any logical type the field annotation maps to (registration + # completeness invariant: all nested logical types must be registered when + # the outer type is registered). The return value is discarded; only the + # side effect of registration matters here. + converter.register_python_class(annotation) + field_annotations.append((field.name, annotation)) + + logger.debug( + "DataclassLogicalTypeFactory: reconstructed %r from Arrow", arrow_extension_name + ) + return DataclassLogicalType( + arrow_extension_name, cls, storage_type, field_annotations + ) + + +def _import_from_fqcn(fqcn: str) -> type: + """Import a dataclass from its fully-qualified class name. + + Delegates the module-prefix walk to ``type_utils._walk_fqcn``, then + validates the resolved object is a dataclass type. + + Args: + fqcn: Fully-qualified class name, e.g. ``"mypackage.sub.MyClass"``. + + Returns: + The imported dataclass type. + + Raises: + ImportError: If no valid module+attribute split can be found, or if the + resolved object is not a dataclass type. + """ + from orcapod.extension_types.type_utils import _walk_fqcn + + obj: Any = _walk_fqcn(fqcn) + if not dataclasses.is_dataclass(obj) or not isinstance(obj, type): + raise ImportError( + f"{fqcn!r} does not resolve to a dataclass type." + ) + return obj diff --git a/src/orcapod/extension_types/protocols.py b/src/orcapod/extension_types/protocols.py index e3f6045c..95c8d83c 100644 --- a/src/orcapod/extension_types/protocols.py +++ b/src/orcapod/extension_types/protocols.py @@ -1,13 +1,10 @@ """Protocol definitions for the Arrow/Polars extension type system. -This module defines ``ExtensionTypeConverter`` — the contract for all -converters that map between Python objects and their Arrow extension type -storage representation. - -Note: - This module is part of the parallel-build phase. The old - ``SemanticStructConverterProtocol`` in ``protocols/semantic_types_protocols.py`` - is untouched; it is removed in PLT-1660. +This module defines ``TypeConverterProtocol``, ``LogicalTypeProtocol``, and +``LogicalTypeFactoryProtocol`` — the contracts for the converter, for logical +type implementations that bind a Python class to its Arrow and Polars extension +type representation, and for factories that auto-construct such implementations +from Arrow schema metadata. """ from __future__ import annotations @@ -15,71 +12,198 @@ from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: + import polars as pl import pyarrow as pa @runtime_checkable -class ExtensionTypeConverter(Protocol): - """Protocol for Arrow/Polars extension-type-backed converters. - - Declares the full contract for a converter that maps between Python - objects and their Arrow extension type storage representation. This - protocol is Arrow I/O only — hashing is not a converter responsibility. - - Attributes: - extension_name: Fully-qualified Python class name used as the - ``ARROW:extension:name`` metadata value (e.g. ``"pathlib.Path"``). - Must be unique across all registered converters. By convention - equals the FQCN, but any unique string is valid. - extension_metadata: Category tag encoded as ``ARROW:extension:metadata`` - (e.g. ``b"orcapod.dataclass"``). Used by the registry to locate - the right category handler at read time. May be ``None``. - storage_type: The underlying Arrow ``pa.DataType`` used for physical - storage (e.g. ``pa.large_string()``, ``pa.binary(16)``, - ``pa.struct(...)``). Not used as an identity signal — identity - is determined solely by ``extension_name``. - python_type: The Python class this converter handles. +class TypeConverterProtocol(Protocol): + """Minimal protocol exposing what factories and logical types need from the converter. + + Placed in ``extension_types/protocols.py`` to avoid circular imports. + ``UniversalTypeConverter`` is the canonical implementation. """ - @property - def extension_name(self) -> str: - """Fully-qualified Python class name; stored as ``ARROW:extension:name``.""" + def register_python_class(self, annotation: Any) -> "pa.DataType": + """Traverse a Python annotation, register any logical types found, and return + the storage-safe Arrow type. + + The returned type may be a ``pa.ExtensionType`` at the top level for registered + classes (e.g. ``UUID`` → ``orcapod.uuid`` extension type), but struct fields and + list value types at any depth are always plain (non-extension) Arrow types. + + Args: + annotation: A Python type or generic alias (e.g. ``list[str]``, + ``Optional[uuid.UUID]``, a dataclass type). + + Returns: + A storage-safe ``pa.DataType``. May be ``pa.ExtensionType`` at the top level; + never contains nested extension types in struct/list fields. + """ ... - @property - def extension_metadata(self) -> bytes | None: - """Category tag; stored as ``ARROW:extension:metadata``. May be ``None``.""" + def register_storage_type(self, arrow_type: "pa.DataType") -> "pa.DataType": + """Traverse an Arrow type bottom-up, registering extension types, and return a + storage-safe type. + + The returned type may be a ``pa.ExtensionType`` at the top level, but struct fields + and list value types at any depth are always plain (non-extension) Arrow types. + This invariant makes the return value safe to use as a struct field or list element + type without further stripping. + + Args: + arrow_type: An Arrow type to traverse and register. + + Returns: + A storage-safe ``pa.DataType``. + """ + ... + + def python_to_storage(self, value: Any, annotation: Any) -> Any: + """Convert a Python value to its Arrow storage representation.""" ... + def storage_to_python(self, storage_value: Any, annotation: Any) -> Any: + """Convert an Arrow storage value back to a Python object.""" + ... + + def apply_extension_types(self, table: "pa.Table") -> "pa.Table": + """Re-wrap table columns into their registered Arrow extension types.""" + ... + + def register_discovered_extensions(self, schema: "pa.Schema") -> None: + """Register any extension types found in ``schema`` that are not yet known.""" + ... + + def load_extension_types(self, table: "pa.Table") -> "pa.Table": + """Register and apply extension types for *table* in one step.""" + ... + + def register_arrow_extension( + self, + arrow_extension_name: str, + extension_metadata: "bytes | None", + storage_type: "pa.DataType", + ) -> "pa.DataType": + """Register an extension type from (name, metadata, storage_type) and return the Arrow type.""" + ... + + +@runtime_checkable +class LogicalTypeProtocol(Protocol): + """Protocol for Arrow/Polars extension-type-backed logical types. + + A ``LogicalTypeProtocol`` is a three-way binding between a unique logical type name + (orcapod's identifier), a Python class, and Arrow/Polars extension types. + Each implementation *owns* its Arrow and Polars extension types by providing + them directly via ``get_arrow_extension_type`` and ``get_polars_extension_type``. + + This protocol is Arrow I/O only — hashing is not a logical type responsibility. + """ + @property - def storage_type(self) -> pa.DataType: - """Underlying Arrow storage type. Any ``pa.DataType`` is valid.""" + def logical_type_name(self) -> str: + """Unique orcapod identifier for this logical type (e.g. ``"orcapod.uuid"``).""" ... @property def python_type(self) -> type: - """The Python class this converter handles.""" + """The Python class this logical type represents.""" ... - def python_to_storage(self, value: Any) -> Any: + def get_arrow_extension_type(self) -> "pa.ExtensionType": + """Return the Arrow extension type for this logical type.""" + ... + + def get_polars_extension_type(self) -> "pl.BaseExtension": + """Return an instance of the Polars extension type for this logical type.""" + ... + + def python_to_storage(self, value: Any, converter: "TypeConverterProtocol | None") -> Any: """Convert a Python value to its Arrow storage representation. Args: value: A Python object of type ``python_type``. + converter: The active ``TypeConverterProtocol`` for recursive delegation. Returns: - A value suitable for use as an Arrow scalar or array element - of type ``storage_type``. + A value suitable for Arrow storage. """ ... - def storage_to_python(self, storage_value: Any) -> Any: + def storage_to_python(self, storage_value: Any, converter: "TypeConverterProtocol | None") -> Any: """Convert an Arrow storage value back to a Python object. Args: - storage_value: A scalar or array element of type ``storage_type``. + storage_value: A scalar or array element from the Arrow storage array. + converter: The active ``TypeConverterProtocol`` for recursive delegation. Returns: A Python object of type ``python_type``. """ ... + + +@runtime_checkable +class LogicalTypeFactoryProtocol(Protocol): + """Protocol for factories that synthesize or reconstruct ``LogicalTypeProtocol`` instances. + + Bridges two directions: the write path (``create_for_python_type``) and the read + path (``reconstruct_from_arrow``). Both methods receive ``converter`` instead of + ``registry`` so all traversal flows through the converter. + """ + + def supports_class(self, python_type: type) -> bool: + """Return True if this factory can synthesize a LogicalType for ``python_type``. + + Used as a probe during write-side MRO dispatch in ``register_python_class``. + + Args: + python_type: The Python class to probe. + + Returns: + True if this factory handles ``python_type``. + """ + ... + + def create_for_python_type( + self, + python_type: type, + converter: "TypeConverterProtocol", + ) -> LogicalTypeProtocol: + """Synthesize a LogicalType for the given Python class (write path). + + Args: + python_type: The concrete Python class to synthesize a LogicalType for. + converter: The active converter for recursive field-type resolution. + + Returns: + A fully constructed ``LogicalTypeProtocol`` ready for registration. + + Raises: + ValueError: If this factory cannot construct a type for the given class. + """ + ... + + def reconstruct_from_arrow( + self, + arrow_extension_name: str, + storage_type: "pa.DataType", + metadata: dict[str, Any], + converter: "TypeConverterProtocol", + ) -> LogicalTypeProtocol: + """Reconstruct a LogicalType from Arrow schema metadata (read path). + + Args: + arrow_extension_name: The Arrow extension type name from the schema. + storage_type: The underlying Arrow storage type (already resolved bottom-up). + metadata: Full parsed metadata JSON dict. Always contains ``"category"``. + converter: The active converter for recursive field-type resolution. + + Returns: + A fully constructed ``LogicalTypeProtocol`` ready for registration. + + Raises: + ValueError: If this factory cannot reconstruct a type for the given name. + """ + ... diff --git a/src/orcapod/extension_types/pydantic_logical_type_factory.py b/src/orcapod/extension_types/pydantic_logical_type_factory.py new file mode 100644 index 00000000..e40baf4f --- /dev/null +++ b/src/orcapod/extension_types/pydantic_logical_type_factory.py @@ -0,0 +1,365 @@ +"""PydanticLogicalType and PydanticLogicalTypeFactory. + +Provides the ``PydanticLogicalType`` logical type implementation and the +``PydanticLogicalTypeFactory`` that synthesises and reconstructs +``PydanticLogicalType`` instances for pydantic v2 ``BaseModel`` subclasses. + +Write path (``create_for_python_type``): + Iterates model fields via ``model_fields`` (pydantic v2 API), delegates + field Arrow-type resolution to the converter via ``register_python_class``, + and returns a ``PydanticLogicalType`` backed by a ``pa.struct`` extension + type. + +Read path (``reconstruct_from_arrow``): + Imports the model by fully-qualified class name, resolves field annotations + against the (already bottom-up resolved) storage type, and returns a + ``PydanticLogicalType``. + +Category tag: ``"orcapod.pydantic"`` +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING, Any + +from orcapod.extension_types.registry import make_arrow_extension_type, make_polars_extension_type +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import polars as pl + import pyarrow as pa + from orcapod.extension_types.protocols import TypeConverterProtocol +else: + pa = LazyModule("pyarrow") + pl = LazyModule("polars") + +logger = logging.getLogger(__name__) + +#: Category tag embedded in Arrow extension metadata. Used as the factory dispatch key. +PYDANTIC_CATEGORY = "orcapod.pydantic" + + +class PydanticLogicalType: + """Logical type binding a pydantic ``BaseModel`` subclass to its Arrow extension type. + + Stores the model's fully-qualified class name as the Arrow extension name + and a ``pa.struct`` of the model fields as the storage type. + + No Arrow-type reasoning lives here — all field-type resolution is owned by + the converter and completed before this object is constructed. + + Args: + logical_name: Fully-qualified class name (e.g. ``"mymodule.sub.MyModel"``). + Used as both the logical type name and the Arrow extension name. + python_type: The pydantic ``BaseModel`` subclass. + storage_type: The Arrow ``pa.StructType`` for the model fields. + field_annotations: Ordered list of ``(field_name, python_annotation)`` + pairs matching the fields in ``storage_type``. + + Example: + >>> lt = PydanticLogicalType( + ... "mymod.Point", Point, + ... pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]), + ... [("x", int), ("y", int)], + ... ) + >>> lt.python_to_storage(Point(x=1, y=2), converter) + {"x": 1, "y": 2} + """ + + def __init__( + self, + logical_name: str, + python_type: type, + storage_type: pa.StructType, + field_annotations: list[tuple[str, Any]], + ) -> None: + self._logical_name = logical_name + self._python_type = python_type + self._storage_type = storage_type + self._field_annotations = field_annotations + + _metadata = json.dumps({"category": PYDANTIC_CATEGORY}).encode("utf-8") + self._arrow_ext_class = make_arrow_extension_type( + logical_name, storage_type, metadata=_metadata + ) + self._arrow_ext: pa.ExtensionType | None = None + # ``storage_type`` must not contain nested extension types (ET1 in DESIGN_ISSUES.md). + # On the write path, ``PydanticLogicalTypeFactory.create_for_python_type`` strips any + # top-level extension type from each field's Arrow type before inserting it into the + # struct. On the read path, ``reconstruct_from_arrow`` receives a ``storage_type`` + # already guaranteed storage-safe by ``register_storage_type``. + self._polars_ext_class = make_polars_extension_type(logical_name, storage_type) + self._polars_ext: pl.BaseExtension | None = None + + @property + def logical_type_name(self) -> str: + """Fully-qualified class name used as the logical type identifier.""" + return self._logical_name + + @property + def python_type(self) -> type: + """The pydantic ``BaseModel`` subclass this logical type represents.""" + return self._python_type + + def get_arrow_extension_type(self) -> pa.ExtensionType: + """Return the Arrow extension type for this model. + + Returns: + A cached ``pa.ExtensionType`` instance with ``extension_name`` equal to + the fully-qualified class name and ``storage_type`` equal to the struct + of the model fields. + """ + if self._arrow_ext is None: + self._arrow_ext = self._arrow_ext_class() + return self._arrow_ext + + def get_polars_extension_type(self) -> pl.BaseExtension: + """Return the Polars extension type for this model. + + Returns: + A cached ``pl.BaseExtension`` instance. + """ + if self._polars_ext is None: + self._polars_ext = self._polars_ext_class() + return self._polars_ext + + def python_to_storage(self, value: Any, converter: TypeConverterProtocol | None) -> dict[str, Any]: + """Convert a pydantic model instance to an Arrow-compatible struct dict. + + Iterates ``_field_annotations`` and delegates each field's conversion to + ``converter.python_to_storage``. + + Args: + value: A pydantic model instance of type ``python_type``. + converter: The active converter for per-field delegation. Must not be ``None``. + + Returns: + A dict mapping field names to their Arrow storage values. + + Raises: + ValueError: If ``converter`` is ``None``. + """ + if converter is None: + raise ValueError( + "PydanticLogicalType.python_to_storage requires a converter — " + "pass a TypeConverterProtocol instance for field-level conversion." + ) + return { + name: converter.python_to_storage(getattr(value, name), annotation) + for name, annotation in self._field_annotations + } + + def storage_to_python(self, storage_value: Any, converter: TypeConverterProtocol | None) -> Any: + """Reconstruct a pydantic model instance from an Arrow struct dict. + + Args: + storage_value: A dict mapping field names to Arrow storage values. + converter: The active converter for per-field delegation. Must not be ``None``. + + Returns: + A pydantic model instance of type ``python_type``. Pydantic validation + runs on construction, ensuring the model is always in a valid state. + + Raises: + ValueError: If ``converter`` is ``None``. + """ + if converter is None: + raise ValueError( + "PydanticLogicalType.storage_to_python requires a converter — " + "pass a TypeConverterProtocol instance for field-level conversion." + ) + kwargs = { + name: converter.storage_to_python(storage_value[name], annotation) + for name, annotation in self._field_annotations + } + return self._python_type(**kwargs) + + +class PydanticLogicalTypeFactory: + """Stateless factory that synthesises and reconstructs ``PydanticLogicalType`` instances. + + **Write path** (``create_for_python_type``): derives Arrow struct type from the + model fields by delegating to ``converter.register_python_class`` per field. + Only fields in ``model_fields`` are stored — computed fields and private + attributes are excluded. + + **Read path** (``reconstruct_from_arrow``): imports the model by FQCN, matches + fields against the already-resolved ``storage_type``, and returns a + ``PydanticLogicalType``. + + Category tag: ``"orcapod.pydantic"`` + + Register with:: + + from pydantic import BaseModel + converter.register_logical_type_factory( + PydanticLogicalTypeFactory(), + category="orcapod.pydantic", + python_bases=[BaseModel], + ) + + Example: + >>> factory = PydanticLogicalTypeFactory() + >>> factory.supports_class(MyModel) + True + >>> factory.supports_class(str) + False + """ + + def supports_class(self, python_type: type) -> bool: + """Return True if ``python_type`` is a pydantic ``BaseModel`` subclass. + + Args: + python_type: Any Python type. + + Returns: + True if ``python_type`` is a ``BaseModel`` subclass. + """ + from pydantic import BaseModel + return isinstance(python_type, type) and issubclass(python_type, BaseModel) + + def create_for_python_type( + self, + python_type: type, + converter: TypeConverterProtocol, + ) -> PydanticLogicalType: + """Synthesise a ``PydanticLogicalType`` for a pydantic model (write path). + + Derives the FQCN, obtains type hints, and resolves each field's Arrow type + via ``converter.register_python_class``. Only fields present in + ``model_fields`` are stored — computed fields and private attributes are + excluded. Rejects local / unnamed classes. + + Args: + python_type: A pydantic ``BaseModel`` subclass. + converter: The active converter for field-type resolution. + + Returns: + A ``PydanticLogicalType`` ready for registration. + + Raises: + ValueError: If ``python_type`` is a local class (``__qualname__`` contains + ``""``). + """ + import typing + + fqcn = f"{python_type.__module__}.{python_type.__qualname__}" + if "" in fqcn: + raise ValueError( + f"Cannot register local class {python_type!r} as a PydanticLogicalType — " + f"local classes have no stable fully-qualified class name and cannot be " + f"reconstructed on read. Define the model at module level." + ) + + try: + hints = typing.get_type_hints(python_type) + except Exception as exc: + raise ValueError( + f"Cannot get type hints for {python_type!r}: {exc}" + ) from exc + + arrow_fields = [] + field_annotations = [] + for field_name in python_type.model_fields: + annotation = hints.get(field_name, Any) + arrow_type = converter.register_python_class(annotation) + # Strip top-level extension type before inserting into the struct (ET1; + # see DESIGN_ISSUES.md): Arrow cannot represent extension types inside + # struct field types. + if isinstance(arrow_type, pa.ExtensionType): + arrow_type = arrow_type.storage_type + arrow_fields.append(pa.field(field_name, arrow_type)) + field_annotations.append((field_name, annotation)) + + storage_type = pa.struct(arrow_fields) + logger.debug("PydanticLogicalTypeFactory: synthesised %r for %r", fqcn, python_type) + return PydanticLogicalType(fqcn, python_type, storage_type, field_annotations) + + def reconstruct_from_arrow( + self, + arrow_extension_name: str, + storage_type: pa.DataType, + metadata: dict[str, Any], + converter: TypeConverterProtocol, + ) -> PydanticLogicalType: + """Reconstruct a ``PydanticLogicalType`` from Arrow schema metadata (read path). + + Imports the model from its FQCN (``arrow_extension_name``), then matches + the model field annotations against the fields in ``storage_type``. + ``storage_type`` is already bottom-up resolved by ``register_storage_type`` + before this method is called. + + Args: + arrow_extension_name: FQCN of the pydantic model (Arrow extension name). + storage_type: Already-resolved ``pa.StructType`` for the model fields. + metadata: Full parsed metadata JSON dict (always contains ``"category"``). + converter: The active converter (used for registration completeness invariant). + + Returns: + A ``PydanticLogicalType`` ready for registration. + + Raises: + ImportError: If the class cannot be imported from ``arrow_extension_name``. + ValueError: If ``storage_type`` is not a struct type. + """ + import typing + + if not pa.types.is_struct(storage_type): + raise ValueError( + f"PydanticLogicalTypeFactory.reconstruct_from_arrow: expected a struct " + f"storage type for {arrow_extension_name!r}, got {storage_type!r}." + ) + + cls = _import_pydantic_model_from_fqcn(arrow_extension_name) + + try: + hints = typing.get_type_hints(cls) + except Exception as exc: + raise ValueError( + f"Cannot get type hints for {cls!r}: {exc}" + ) from exc + + field_annotations = [] + for field_name in cls.model_fields: + annotation = hints.get(field_name, Any) + # Register any logical type the field annotation maps to (registration + # completeness invariant: all nested logical types must be registered when + # the outer type is registered). The return value is discarded. + converter.register_python_class(annotation) + field_annotations.append((field_name, annotation)) + + logger.debug( + "PydanticLogicalTypeFactory: reconstructed %r from Arrow", arrow_extension_name + ) + return PydanticLogicalType( + arrow_extension_name, cls, storage_type, field_annotations + ) + + +def _import_pydantic_model_from_fqcn(fqcn: str) -> type: + """Import a pydantic ``BaseModel`` subclass from its fully-qualified class name. + + Delegates the module-prefix walk to ``type_utils._walk_fqcn``, then + validates the resolved object is a ``BaseModel`` subclass. + + Args: + fqcn: Fully-qualified class name, e.g. ``"mypackage.sub.MyModel"``. + + Returns: + The imported ``BaseModel`` subclass. + + Raises: + ImportError: If no valid module+attribute split can be found, or if the + resolved object is not a ``BaseModel`` subclass. + """ + from pydantic import BaseModel + from orcapod.extension_types.type_utils import _walk_fqcn + + obj: Any = _walk_fqcn(fqcn) + if not (isinstance(obj, type) and issubclass(obj, BaseModel)): + raise ImportError( + f"{fqcn!r} does not resolve to a pydantic BaseModel subclass." + ) + return obj diff --git a/src/orcapod/extension_types/registry.py b/src/orcapod/extension_types/registry.py new file mode 100644 index 00000000..32090242 --- /dev/null +++ b/src/orcapod/extension_types/registry.py @@ -0,0 +1,381 @@ +"""Registry for LogicalType instances. + +Registering a logical type automatically registers the corresponding +extension type in both PyArrow's and Polars' global registries. +""" + +from __future__ import annotations + +import json +import logging +import re +from typing import TYPE_CHECKING, Iterable + +from orcapod.extension_types.protocols import LogicalTypeProtocol, LogicalTypeFactoryProtocol +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import polars as pl + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + pl = LazyModule("polars") + +logger = logging.getLogger(__name__) + + +def _sanitize(name: str) -> str: + """Replace non-alphanumeric characters with underscores. + + Used to produce a valid Python identifier for the dynamically created + ``pa.ExtensionType`` subclass name. + """ + return re.sub(r"[^A-Za-z0-9]", "_", name) + + +def make_arrow_extension_type( + extension_name: str, + storage_type: pa.DataType, + metadata: bytes | None = None, +) -> type[pa.ExtensionType]: + """Synthesise and return a ``pa.ExtensionType`` subclass. + + Returns the *class*, not an instance — callers instantiate it inside their + ``get_arrow_extension_type()`` implementation. Returning the class preserves + the option to create multiple instances or future parameterised variants from + the same class. + + This is a low-level building block. Each ``LogicalType`` implementation acts + as a factory: it creates and owns the ``pa.ExtensionType`` instance it requires + and exposes it via ``get_arrow_extension_type()``. See PLT-1656 for the + built-in implementations (``Path``, ``UPath``, ``UUID``). + + Args: + extension_name: The Arrow extension name (``ARROW:extension:name``). + storage_type: The underlying Arrow storage type. + metadata: Optional bytes stored as ``ARROW:extension:metadata``. + Defaults to ``None`` (serialised as empty bytes). + + ``metadata`` can optionally encode a **LogicalTypeProtocol category** as a + UTF-8 JSON object with at least a ``"category"`` key + (e.g. ``b'{"category": "Dataclass"}'``, + ``b'{"category": "Pydantic", "pydantic_version": 2}'``). + A ``LogicalTypeFactoryProtocol`` (see + ``LogicalTypeFactoryProtocol.reconstruct_from_arrow``) dispatches on the + ``"category"`` value when reading schemas from IPC or Parquet files and + uses it to auto-generate the correct ``LogicalTypeProtocol`` implementation + for the specific Python class within that category, without requiring + explicit prior registration. + + Returns: + A ``pa.ExtensionType`` subclass. Call it with no arguments to obtain + an instance suitable for passing to ``pa.register_extension_type`` or + returning from ``get_arrow_extension_type()``. + """ + _name, _storage, _metadata = extension_name, storage_type, metadata or b"" + + def _deserialize(cls, storage_type: pa.DataType, serialized: bytes) -> pa.ExtensionType: + # __arrow_ext_deserialize__ reconstructs the type descriptor from schema + # metadata (called once per IPC/Parquet read, not per value). Validate the + # incoming storage_type and serialized bytes against the expected values so + # that reading a file where the same extension name was written with different + # parameters raises immediately rather than silently producing wrong data. + if storage_type != _storage: + raise ValueError( + f"Arrow extension type '{_name}': expected storage_type " + f"{_storage!r} but got {storage_type!r}." + ) + if serialized != _metadata: + raise ValueError( + f"Arrow extension type '{_name}': expected metadata " + f"{_metadata!r} but got {serialized!r}." + ) + return cls() + + return type( + f"_ArrowExt_{_sanitize(extension_name)}", + (pa.ExtensionType,), + { + "__init__": lambda self: pa.ExtensionType.__init__(self, _storage, _name), + "__arrow_ext_serialize__": lambda self: _metadata, + "__arrow_ext_deserialize__": classmethod(_deserialize), + }, + ) + + +def make_polars_extension_type( + extension_name: str, + arrow_storage_type: pa.DataType, + metadata: str | None = None, +) -> type[pl.BaseExtension]: + """Synthesise and return a ``pl.BaseExtension`` subclass. + + Derives the Polars storage dtype from *arrow_storage_type* via + ``pl.from_arrow``. Returns the *class*; callers instantiate it inside + ``get_polars_extension_type()``. + + The returned class uses the Arrow extension name as its registration name + (the same name passed to ``pl.register_extension_type``), so that Polars + correctly maps Arrow extension columns on read. + + **Limitation — nested extension types not supported:** ``arrow_storage_type`` + must not contain any ``pa.ExtensionType`` nodes (e.g. as struct fields or + list element types). Polars's Arrow IPC bridge can handle a top-level + extension type via ``pl.BaseExtension``, but raises + ``ArrowNotImplementedError: extension`` when it encounters an extension type + nested inside a struct or list during dtype inference. Callers must ensure + ``arrow_storage_type`` is storage-safe (no nested extension type nodes) before + passing it here. Types produced by ``register_python_class`` and + ``register_storage_type`` satisfy this invariant, but arbitrary + ``pa.DataType`` values do not. This is tracked as design issue ET1 in + ``DESIGN_ISSUES.md``. + + Args: + extension_name: The extension type name used for Polars registration. + Must match the Arrow extension name so Polars can round-trip the + type through Arrow IPC. + arrow_storage_type: The Arrow storage type. Must not contain nested + ``pa.ExtensionType`` nodes; see limitation note above. Converted + once to the corresponding Polars dtype via ``pl.from_arrow``. + metadata: Optional metadata string stored as ``metadata_str`` in the + Polars extension. Defaults to ``None``. + + Returns: + A ``pl.BaseExtension`` subclass. Call it with no arguments to obtain + an instance suitable for passing to ``pl.register_extension_type`` or + returning from ``get_polars_extension_type()``. + """ + _name = extension_name + _polars_dtype = pl.from_arrow(pa.array([], type=arrow_storage_type)).dtype + _metadata = metadata + + def __init__(self: pl.BaseExtension) -> None: + pl.BaseExtension.__init__(self, _name, _polars_dtype, _metadata) + + @classmethod # type: ignore[misc] + def ext_from_params( + cls: type[pl.BaseExtension], + ext_name: str, + storage_dtype: pl.PolarsDataType, + metadata_str: str | None, + ) -> pl.BaseExtension: + return cls() + + return type( + f"_PolarsExt_{_sanitize(extension_name)}", + (pl.BaseExtension,), + { + "__init__": __init__, + "ext_from_params": ext_from_params, + }, + ) + + +class LogicalTypeRegistry: + """Registry for ``LogicalType`` instances. + + Maintains a three-way binding: ``(logical_type_name, arrow_extension_name, + python_type)`` → ``LogicalType``. Each key participates in at most one + binding within a registry instance. + + Registering a logical type side-effect-registers the corresponding extension + type in PyArrow's and Polars' global registries. Pre-existing types (those + already registered externally in the global Arrow or Polars registries) are + accepted silently — the binding is stored without error. + + The standard access path for the default registry is + ``get_default_context().logical_type_registry`` or the convenience function + ``get_default_logical_type_registry()`` from ``orcapod.contexts``. + Thread-safety is deferred. + + An optional ``logical_types`` list can be passed at construction time to + pre-register one or more ``LogicalTypeProtocol`` instances immediately, following + the same pattern as the ``logical_types`` constructor argument used by + other registries in this package. + + An optional ``factories`` list can also be passed to pre-register + ``LogicalTypeFactoryProtocol`` instances at construction time. Each entry is a + dict with keys ``factory`` (the factory instance), ``category`` (optional str), + and ``python_bases`` (optional list of types). + + Example: + >>> registry = LogicalTypeRegistry() + >>> registry.register_logical_type(my_logical_type) + >>> lt = registry.get_by_logical_name("orcapod.uuid") + + >>> # Pre-register types at construction: + >>> registry = LogicalTypeRegistry(logical_types=[path_lt, uuid_lt]) + """ + + def __init__( + self, + logical_types: list[LogicalTypeProtocol] | None = None, + factories: list[dict] | None = None, + ) -> None: + self._by_logical_name: dict[str, LogicalTypeProtocol] = {} + self._by_arrow_name: dict[str, LogicalTypeProtocol] = {} + self._by_python_type: dict[type, LogicalTypeProtocol] = {} + self._category_factories: dict[str, LogicalTypeFactoryProtocol] = {} + self._python_class_factories: dict[type, LogicalTypeFactoryProtocol] = {} + for lt in (logical_types or []): + self.register_logical_type(lt) + for entry in (factories or []): + self.register_logical_type_factory( + entry["factory"], + category=entry.get("category"), + python_bases=entry.get("python_bases", []), + ) + + def register_logical_type(self, logical_type: LogicalTypeProtocol) -> None: + """Register *logical_type* and its PyArrow/Polars extension types. + + Args: + logical_type: A ``LogicalTypeProtocol`` instance to register. + + Raises: + ValueError: If any of the three keys (``logical_type_name``, + Arrow extension name, ``python_type``) is already bound to a + *different* ``LogicalTypeProtocol`` in this registry. + """ + arrow_ext = logical_type.get_arrow_extension_type() + arrow_ext_name = arrow_ext.extension_name + py_type = logical_type.python_type + logical_name = logical_type.logical_type_name + + existing_by_logical = self._by_logical_name.get(logical_name) + existing_by_arrow = self._by_arrow_name.get(arrow_ext_name) + existing_by_python = self._by_python_type.get(py_type) + + # Triplet conflict check: raise if any key is bound to a different instance. + for existing, label, key in [ + (existing_by_logical, "logical_type_name", logical_name), + (existing_by_arrow, "arrow_extension_name", arrow_ext_name), + (existing_by_python, "python_type", py_type.__qualname__), + ]: + if existing is not None and existing is not logical_type: + raise ValueError( + f"Cannot register logical type '{logical_name}': " + f"{label} {key!r} is already bound to " + f"'{existing.logical_type_name}'." + ) + + # Idempotent check: all three keys already bound to this same instance. + if ( + existing_by_logical is logical_type + and existing_by_arrow is logical_type + and existing_by_python is logical_type + ): + return + + # Register Arrow extension type. ArrowKeyError means the name is already + # in PyArrow's global registry (pre-existing type or another registry + # instance). Accept silently — PLT-1669 adds post-error validation. + try: + pa.register_extension_type(arrow_ext) + except pa.lib.ArrowKeyError: + pass + + # Register Polars extension type. ValueError or ComputeError means already registered. + # Polars raises ValueError via its Python-level guard (_REGISTRY dict check), but + # raises polars.exceptions.ComputeError when the lower-level Rust registry detects + # the duplicate (e.g. when the Polars Python dict was already cleared or bypassed). + # Both errors mean "already registered" — accept silently. + polars_ext = logical_type.get_polars_extension_type() + polars_ext_class = type(polars_ext) + try: + pl.register_extension_type(arrow_ext_name, polars_ext_class) + except (ValueError, pl.exceptions.ComputeError): + pass + + # Store three-way binding. + self._by_logical_name[logical_name] = logical_type + self._by_arrow_name[arrow_ext_name] = logical_type + self._by_python_type[py_type] = logical_type + + def get_by_logical_name(self, name: str) -> LogicalTypeProtocol | None: + """Return the logical type registered under *name*, or ``None``.""" + return self._by_logical_name.get(name) + + def get_by_python_type(self, python_type: type) -> LogicalTypeProtocol | None: + """Return the logical type for *python_type*, or ``None``. + + Checks exact match first, then falls back to an ``issubclass`` scan. + When multiple registered types are superclasses of *python_type*, the + one registered first wins (insertion-order dict, Python 3.7+). + """ + result = self._by_python_type.get(python_type) + if result is not None: + return result + for registered_type, registered_lt in self._by_python_type.items(): + try: + if issubclass(python_type, registered_type): + return registered_lt + except TypeError: + continue + return None + + def get_by_arrow_extension_name(self, arrow_name: str) -> LogicalTypeProtocol | None: + """Return the logical type registered under *arrow_name*, or ``None``.""" + return self._by_arrow_name.get(arrow_name) + + def register_logical_type_factory( + self, + factory: LogicalTypeFactoryProtocol, + *, + category: str | None = None, + python_bases: Iterable[type] = (), + ) -> None: + """Register a factory on one or both dispatch axes. + + A single factory instance can be registered for multiple ``python_bases`` + at once — pass a list with all the base classes it should handle. + + Args: + factory: The factory to register. + category: If given, registers factory as the read-side handler for Arrow + extension types whose metadata contains this category string. Raises + ``ValueError`` if a different factory is already registered for this + category. + python_bases: Zero or more Python base classes. Registers factory as the + write-side handler for each base. A single factory may cover any + number of bases. Raises ``ValueError`` if a *different* factory is + already registered for a given base. + + Raises: + ValueError: If neither ``category`` nor ``python_bases`` is provided. + ValueError: If a different factory is already registered for a given key. + """ + python_bases_list = list(python_bases) + if category is None and not python_bases_list: + raise ValueError( + "At least one of 'category' or 'python_bases' must be provided." + ) + if category is not None: + existing = self._category_factories.get(category) + if existing is not None and existing is not factory: + raise ValueError( + f"Cannot register factory for category {category!r}: " + f"a different factory is already registered for this category." + ) + # Skip registration if this exact factory object is already bound to the category. + if existing is not factory: + self._category_factories[category] = factory + logger.debug( + "registered LogicalTypeFactory for category %r: %r", category, factory + ) + # Validate all bases before writing any (prevents partial mutation on error). + for base in python_bases_list: + existing = self._python_class_factories.get(base) + if existing is not None and existing is not factory: + raise ValueError( + f"Cannot register factory for python base {base!r}: " + f"a different factory is already registered for this base." + ) + for base in python_bases_list: + # Skip if this exact factory object is already bound to the base class + # (idempotent re-registration of the same factory is always a no-op). + if self._python_class_factories.get(base) is not factory: + self._python_class_factories[base] = factory + logger.debug( + "registered LogicalTypeFactory for python base %r: %r", base, factory + ) diff --git a/src/orcapod/extension_types/schema_walker.py b/src/orcapod/extension_types/schema_walker.py new file mode 100644 index 00000000..78b1c151 --- /dev/null +++ b/src/orcapod/extension_types/schema_walker.py @@ -0,0 +1,184 @@ +"""Recursive Arrow schema walker for extension type discovery. + +Given a ``pa.Schema`` or a single ``pa.Field``, walks the Arrow type tree +recursively and returns all extension-typed fields found at any depth of +nesting (struct, list, map, etc.). + +This is a pure discovery utility — it never triggers any registration. +""" + +from __future__ import annotations + +import dataclasses +import logging + +import pyarrow as pa + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class ExtensionTypeInfo: + """Metadata for a single Arrow extension type found in a schema. + + Attributes: + extension_name: The extension type's unique name stored as + ``ARROW:extension:name`` (e.g. ``"orcapod.path"``). + extension_metadata: The category tag stored as + ``ARROW:extension:metadata`` (e.g. ``b"orcapod.dataclass"``). + ``None`` when absent or serialised as empty bytes. + storage_type: The underlying Arrow storage type + (e.g. ``pa.large_string()``). + """ + + extension_name: str + extension_metadata: bytes | None + storage_type: pa.DataType + + +def walk_schema(schema: pa.Schema) -> list[ExtensionTypeInfo]: + """Walk *schema* and return all extension types found, deduplicated. + + Iterates every top-level field and descends recursively into struct, + list, and map container types. The result is deduplicated by + ``(extension_name, extension_metadata)``; the first occurrence of each + pair is kept. + + Args: + schema: A PyArrow schema to inspect. + + Returns: + Deduplicated list of ``ExtensionTypeInfo`` in depth-first, + first-seen order. Extension type storage types are not descended + into — only the logical schema type tree is walked. + """ + seen: set[tuple[str, bytes | None]] = set() + results: list[ExtensionTypeInfo] = [] + for i in range(len(schema)): + _collect(schema.field(i), seen, results) + return results + + +def walk_field(field: pa.Field) -> list[ExtensionTypeInfo]: + """Walk *field*'s type tree and return all extension types found, deduplicated. + + Args: + field: A PyArrow field to inspect. + + Returns: + Deduplicated list of ``ExtensionTypeInfo`` in depth-first, + first-seen order. Extension type storage types are not descended + into — only the logical schema type tree is walked. + """ + seen: set[tuple[str, bytes | None]] = set() + results: list[ExtensionTypeInfo] = [] + _collect(field, seen, results) + return results + + +def _collect( + field: pa.Field, + seen: set[tuple[str, bytes | None]], + results: list[ExtensionTypeInfo], +) -> None: + """Recursively walk *field* and accumulate ``ExtensionTypeInfo`` into *results*. + + Mutates *seen* and *results* in place. Stops descending once a field is + identified as extension-typed — the storage type of an extension type is + not descended into. + + Args: + field: The field to inspect. + seen: Deduplication set of ``(extension_name, extension_metadata)`` + pairs already appended to *results*. + results: Accumulator list. + """ + info = _detect_extension(field) + if info is not None: + key = (info.extension_name, info.extension_metadata) + if key not in seen: + logger.debug( + "schema_walker: found extension type %r (metadata=%r) in field %r", + info.extension_name, + info.extension_metadata, + field.name, + ) + seen.add(key) + results.append(info) + else: + logger.debug( + "schema_walker: skipping duplicate extension type %r in field %r", + info.extension_name, + field.name, + ) + return + + t = field.type + if pa.types.is_struct(t): + logger.debug( + "schema_walker: descending into struct field %r (%d sub-fields)", + field.name, + t.num_fields, + ) + for i in range(t.num_fields): + _collect(t.field(i), seen, results) + elif ( + pa.types.is_list(t) + or pa.types.is_large_list(t) + or pa.types.is_fixed_size_list(t) + or pa.types.is_list_view(t) + or pa.types.is_large_list_view(t) + ): + logger.debug("schema_walker: descending into list field %r", field.name) + # .value_field is guaranteed by Arrow's list type contract. + _collect(t.value_field, seen, results) + elif pa.types.is_map(t): + logger.debug("schema_walker: descending into map field %r", field.name) + # key_field and item_field are stable on pa.MapType since PyArrow 14; + # this project requires >= 20, so direct attribute access is safe. + _collect(t.key_field, seen, results) + _collect(t.item_field, seen, results) + + +def _detect_extension(field: pa.Field) -> ExtensionTypeInfo | None: + """Extract ``ExtensionTypeInfo`` from *field*, or ``None`` if not extension-typed. + + Checks two channels in order: + + 1. **In-memory ExtensionType channel** — ``isinstance(field.type, + pa.ExtensionType)`` is true. This fires whenever a ``pa.ExtensionType`` + instance is attached to the field, regardless of whether the type is + registered in PyArrow's process-global registry. The type object + carries the name, serialised metadata, and storage type. + 2. **Field-metadata channel** — ``field.metadata`` contains + ``b"ARROW:extension:name"``. The type survived a Parquet/IPC + round-trip as raw Arrow field metadata without a corresponding + in-memory ``pa.ExtensionType`` instance in this process. + + In both cases empty bytes metadata (``b""``) is normalised to ``None``. + + Args: + field: The field to inspect. + + Returns: + ``ExtensionTypeInfo`` if the field is extension-typed, else ``None``. + """ + if isinstance(field.type, pa.ExtensionType): + ext_type = field.type + raw_meta = ext_type.__arrow_ext_serialize__() + return ExtensionTypeInfo( + extension_name=ext_type.extension_name, + extension_metadata=raw_meta or None, + storage_type=ext_type.storage_type, + ) + + if field.metadata and b"ARROW:extension:name" in field.metadata: + name = field.metadata[b"ARROW:extension:name"].decode("utf-8") + raw_meta = field.metadata.get(b"ARROW:extension:metadata") + return ExtensionTypeInfo( + extension_name=name, + extension_metadata=raw_meta or None, + storage_type=field.type, + ) + + return None diff --git a/src/orcapod/extension_types/type_utils.py b/src/orcapod/extension_types/type_utils.py new file mode 100644 index 00000000..21487057 --- /dev/null +++ b/src/orcapod/extension_types/type_utils.py @@ -0,0 +1,117 @@ +"""Utility helpers for Python type annotation inspection and FQCN import. + +Used by the write-side registration trigger to extract leaf Python classes from +complex generic annotations like ``list[dict[A, list[B]]]``, and by logical type +factories to import classes from fully-qualified class names. +""" + +from __future__ import annotations + +import importlib +import typing +from typing import Any, Iterator + + +def _extract_leaf_classes(annotation: Any) -> Iterator[type]: + """Recursively yield all concrete leaf Python classes from a type annotation. + + Unwraps generic aliases (``list[T]``, ``dict[K, V]``, ``Optional[T]``, + ``Union[A, B]``, ``A | B``, etc.) using ``typing.get_origin`` and + ``typing.get_args`` and yields every non-generic leaf found. ``NoneType`` + that appears as a generic argument (from ``Optional`` and + ``Union[..., None]`` / ``T | None``) is skipped — callers see only the + concrete types. When ``type(None)`` is passed directly as the annotation, + it is yielded as-is. + + Non-type, non-generic values (e.g. unresolved string annotations) are + silently skipped. + + Args: + annotation: A Python type or generic alias to inspect. + + Yields: + Concrete Python ``type`` objects found at leaf positions. + + Examples: + >>> list(_extract_leaf_classes(list[int])) + [] + >>> set(_extract_leaf_classes(dict[str, list[MyClass]])) + {, } + """ + origin = typing.get_origin(annotation) + + if origin is None: + # Not a generic alias. Yield only if it is a plain type. + if isinstance(annotation, type): + yield annotation + return + + # Generic alias — recurse into every type argument, skipping NoneType. + for arg in typing.get_args(annotation): + if arg is type(None): + continue + yield from _extract_leaf_classes(arg) + + +def _walk_fqcn(fqcn: str) -> Any: + """Walk a fully-qualified class name and return the resolved object. + + Tries module prefixes from longest to shortest, then walks the remaining + parts as attribute accesses. For example: + + - ``"mypackage.sub.MyClass"`` → import ``mypackage.sub``, then + ``getattr(module, "MyClass")``. + - ``"mypackage.sub.Outer.Inner"`` → import ``mypackage.sub``, then + ``getattr(module, "Outer")``, then ``getattr(Outer, "Inner")``. + + Does **not** validate the type of the resolved object — callers are + responsible for checking that the result is the expected kind of object + (e.g. a dataclass, a ``BaseModel`` subclass). + + Args: + fqcn: Fully-qualified name, e.g. ``"mypackage.sub.MyClass"``. + + Returns: + The resolved Python object. + + Raises: + ImportError: If no valid module+attribute split can be found, or if a + candidate module prefix exists on disk but raises an ``ImportError`` + at import time (e.g. a missing optional dependency). In the latter + case the original exception is re-raised unchanged so callers see + the true root cause. + """ + parts = fqcn.split(".") + if len(parts) < 2: + raise ImportError(f"Cannot import from FQCN {fqcn!r}: no module separator found.") + + for i in range(len(parts) - 1, 0, -1): + module_path = ".".join(parts[:i]) + attr_parts = parts[i:] + try: + module = importlib.import_module(module_path) + except ModuleNotFoundError as exc: + # Only continue when the module we tried to import (or a direct + # ancestor of it) simply does not exist. Use an exact-match or + # dotted-prefix check so that a dep whose name is a bare prefix of + # module_path (e.g. dep "path" vs module "pathlib") is not + # accidentally treated as a missing ancestor. + # + # Re-raise in all other cases so callers see the true root cause + # instead of a misleading "no valid module+attribute" error. + if exc.name is None or not ( + exc.name == module_path or module_path.startswith(exc.name + ".") + ): + raise + continue + obj: Any = module + try: + for attr in attr_parts: + obj = getattr(obj, attr) + except AttributeError: + continue + return obj + + raise ImportError( + f"Cannot import from FQCN {fqcn!r}: no valid module+attribute path found." + ) diff --git a/src/orcapod/hashing/__init__.py b/src/orcapod/hashing/__init__.py index 8055509b..5c4ddc1f 100644 --- a/src/orcapod/hashing/__init__.py +++ b/src/orcapod/hashing/__init__.py @@ -3,23 +3,20 @@ Public API ---------- - BaseSemanticHasher -- content-based recursive object hasher (concrete) - SemanticHasherProtocol -- protocol for semantic hashers - TypeHandlerRegistry -- registry mapping types to TypeHandlerProtocol instances - get_default_semantic_hasher -- global default SemanticHasherProtocol factory - get_default_type_handler_registry -- global default TypeHandlerRegistry factory - ContentIdentifiableMixin -- convenience mixin for content-identifiable objects + SemanticAwarePythonHasher -- content-based recursive object hasher + SemanticHasherProtocol -- protocol for semantic hashers + PythonTypeHandlerRegistry -- registry mapping types to PythonTypeHandlerProtocol instances + get_default_semantic_hasher -- global default SemanticHasherProtocol factory + get_default_python_type_handler_registry -- global default registry factory + ContentIdentifiableMixin -- convenience mixin for content-identifiable objects -Built-in handlers (importable for custom registry setup): - PathContentHandler +Built-in hashers (importable for custom registry setup): + PathHandler UUIDHandler BytesHandler FunctionHandler TypeObjectHandler - register_builtin_handlers - -Legacy names (kept for backward compatibility): - HashableMixin -- legacy mixin from legacy_core (deprecated) + register_builtin_python_type_handlers Utility: FileContentHasherProtocol @@ -28,41 +25,40 @@ ArrowHasherProtocol """ -# --------------------------------------------------------------------------- -# New API -- SemanticHasherProtocol, registry, mixin -# --------------------------------------------------------------------------- - -# --------------------------------------------------------------------------- -# Default hasher factories -# --------------------------------------------------------------------------- from orcapod.hashing.defaults import ( get_default_arrow_hasher, + get_default_python_type_handler_registry, get_default_semantic_hasher, - get_default_type_handler_registry, ) - -# --------------------------------------------------------------------------- -# File hashing utilities -# --------------------------------------------------------------------------- from orcapod.hashing.file_hashers import BasicFileHasher, CachedFileHasher from orcapod.hashing.hash_utils import hash_file from orcapod.hashing.semantic_hashing.builtin_handlers import ( BytesHandler, FunctionHandler, - PathContentHandler, + PathHandler, TypeObjectHandler, UUIDHandler, - register_builtin_handlers, + register_builtin_python_type_handlers, ) from orcapod.hashing.semantic_hashing.content_identifiable_mixin import ( ContentIdentifiableMixin, ) +from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher +from orcapod.hashing.semantic_hashing.type_handler_registry import ( + BuiltinPythonTypeHandlerRegistry, + PythonTypeHandlerRegistry, +) +from orcapod.protocols.hashing_protocols import ( + ArrowHasherProtocol, + ContentIdentifiableProtocol, + FileContentHasherProtocol, + FunctionInfoExtractorProtocol, + PythonTypeHandlerProtocol, + SemanticHasherProtocol, + SemanticTypeHasherProtocol, + StringCacherProtocol, +) -# --------------------------------------------------------------------------- -# Legacy API (deprecated -- kept for backward compatibility) -# These imports are guarded because legacy_core.py has pre-existing import -# issues (e.g. references to removed types) that should not block the new API. -# --------------------------------------------------------------------------- try: from orcapod.hashing.legacy_core import ( HashableMixin, @@ -85,60 +81,31 @@ hash_to_hex = None # type: ignore[assignment] hash_to_int = None # type: ignore[assignment] hash_to_uuid = None # type: ignore[assignment] -from orcapod.hashing.semantic_hashing.semantic_hasher import BaseSemanticHasher -from orcapod.hashing.semantic_hashing.type_handler_registry import ( - BuiltinTypeHandlerRegistry, - TypeHandlerRegistry, -) - -# --------------------------------------------------------------------------- -# Protocols (re-exported for convenience) -# --------------------------------------------------------------------------- -from orcapod.protocols.hashing_protocols import ( - ArrowHasherProtocol, - ContentIdentifiableProtocol, - FileContentHasherProtocol, - FunctionInfoExtractorProtocol, - SemanticHasherProtocol, - SemanticTypeHasherProtocol, - StringCacherProtocol, - TypeHandlerProtocol, -) - -# --------------------------------------------------------------------------- -# __all__ -- defines the public surface of this package -# --------------------------------------------------------------------------- __all__ = [ - # ---- New API: concrete implementation ---- - "BaseSemanticHasher", - "TypeHandlerRegistry", - "BuiltinTypeHandlerRegistry", - "get_default_type_handler_registry", + "SemanticAwarePythonHasher", + "PythonTypeHandlerRegistry", + "BuiltinPythonTypeHandlerRegistry", + "get_default_python_type_handler_registry", "get_default_semantic_hasher", "ContentIdentifiableMixin", - # Built-in handlers - "PathContentHandler", + "PathHandler", "UUIDHandler", "BytesHandler", "FunctionHandler", "TypeObjectHandler", - "register_builtin_handlers", - # ---- Protocols ---- + "register_builtin_python_type_handlers", "SemanticHasherProtocol", "ContentIdentifiableProtocol", - "TypeHandlerProtocol", + "PythonTypeHandlerProtocol", "FileContentHasherProtocol", "ArrowHasherProtocol", "StringCacherProtocol", "FunctionInfoExtractorProtocol", "SemanticTypeHasherProtocol", - # ---- File hashing ---- "BasicFileHasher", "CachedFileHasher", "hash_file", - # ---- Legacy / backward-compatible ---- - # TODO: remove legacy section "get_default_arrow_hasher", "HashableMixin", "hash_to_hex", diff --git a/src/orcapod/hashing/arrow_hashers.py b/src/orcapod/hashing/arrow_hashers.py index f0931cdf..d5ce6a7c 100644 --- a/src/orcapod/hashing/arrow_hashers.py +++ b/src/orcapod/hashing/arrow_hashers.py @@ -1,297 +1,73 @@ -import hashlib -import json -from collections.abc import Callable -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import pyarrow as pa from starfix import ArrowDigester -from orcapod.hashing import arrow_serialization from orcapod.hashing.schema_cleaner import clean_schema_for_hashing, has_extension_metadata from orcapod.hashing.visitors import SemanticHashingVisitor -from orcapod.semantic_types import SemanticTypeRegistry from orcapod.types import ContentHash -from orcapod.utils import arrow_utils - -SERIALIZATION_METHOD_LUT: dict[str, Callable[[pa.Table], bytes]] = { - "logical": arrow_serialization.serialize_table_logical, -} - - -def json_pyarrow_table_serialization(table: pa.Table) -> str: - """ - Serialize a PyArrow table to a stable JSON string by converting to dictionary of lists. - - Args: - table: PyArrow table to serialize - - Returns: - JSON string representation with sorted keys and no whitespace - """ - # Convert table to dictionary of lists using to_pylist() - data_dict = {} - - for column_name in table.column_names: - # Convert Arrow column to Python list, which visits all elements - data_dict[column_name] = table.column(column_name).to_pylist() - - # Serialize to JSON with sorted keys and no whitespace - return json.dumps( - data_dict, - separators=(",", ":"), - sort_keys=True, - ) - - -class SemanticArrowHasher: - """ - Stable hasher for Arrow tables with semantic type support. - - This hasher: - 1. Uses visitor pattern to recursively process nested data structures - 2. Replaces semantic types with their hash strings using registered converters - 3. Sorts columns by name for deterministic ordering - 4. Uses Arrow serialization for stable binary representation - 5. Computes final hash of the processed table - """ - - def __init__( - self, - semantic_registry: SemanticTypeRegistry, - hasher_id: str | None = None, - hash_algorithm: str = "sha256", - chunk_size: int = 8192, - handle_missing: str = "error", - serialization_method: str = "logical", - # TODO: consider passing options for serialization method - ): - """ - Initialize SemanticArrowHasher. - Args: - semantic_registry: Registry containing semantic type converters with hashing - hash_algorithm: Hash algorithm to use for final table hash - chunk_size: Size of chunks to read files in bytes (legacy, may be removed) - hasher_id: Unique identifier for this hasher instance - handle_missing: How to handle missing files ('error', 'skip', 'null_hash') - serialization_method: Method for serializing Arrow table - """ - if hasher_id is None: - hasher_id = f"semantic_arrow_hasher:{hash_algorithm}:{serialization_method}" - - self._hasher_id = hasher_id - self.semantic_registry = semantic_registry - self.chunk_size = chunk_size - self.handle_missing = handle_missing - self.hash_algorithm = hash_algorithm - - if serialization_method not in SERIALIZATION_METHOD_LUT: - raise ValueError( - f"Invalid serialization method '{serialization_method}'. " - f"Supported methods: {list(SERIALIZATION_METHOD_LUT.keys())}" - ) - self.serialization_method = serialization_method - - @property - def hasher_id(self) -> str: - return self._hasher_id - - def _process_table_columns(self, table: pa.Table | pa.RecordBatch) -> pa.Table: - """ - Process table columns using visitor pattern to handle nested semantic types. - - This replaces the old column-by-column processing with a visitor-based approach - that can handle semantic types nested inside complex data structures. - """ - # TODO: Process in batchwise/chunk-wise fashion for memory efficiency - # Currently using to_pylist() for simplicity but this loads entire table into memory - - new_columns = [] - new_fields = [] - - # Import here to avoid circular dependencies - for i, field in enumerate(table.schema): - # Convert column to struct dicts for processing - column_data = table.column(i).to_pylist() - - # TODO: verify the functioning of the visitor pattern - # Create fresh visitor for each column (stateless approach) - visitor = SemanticHashingVisitor(self.semantic_registry) - - try: - # Use visitor to transform both type and data - new_type = None - processed_data = [] - for c in column_data: - processed_type, processed_value = visitor.visit(field.type, c) - if new_type is None: - new_type = processed_type - processed_data.append(processed_value) - - # Create new Arrow column from processed data - assert new_type is not None, "Failed to infer new column type" - # TODO: revisit this logic - new_column = pa.array(processed_data, type=new_type) - new_field = pa.field(field.name, new_type) - - new_columns.append(new_column) - new_fields.append(new_field) - - except Exception as e: - # Add context about which column failed - raise RuntimeError( - f"Failed to process column '{field.name}': {str(e)}" - ) from e - - # Return new table with processed columns - return pa.table(new_columns, schema=pa.schema(new_fields)) - - def _sort_table_columns(self, table: pa.Table) -> pa.Table: - """Sort table columns by field name for deterministic ordering.""" - # Get sorted column names - sorted_column_names = sorted(table.column_names) - - # Use select to reorder columns - much cleaner! - return table.select(sorted_column_names) - - def serialize_arrow_table(self, table: pa.Table) -> bytes: - """ - Serialize Arrow table using the configured serialization method. - - Args: - table: Arrow table to serialize - - Returns: - Serialized bytes of the table - """ - serialization_method_function = SERIALIZATION_METHOD_LUT[ - self.serialization_method - ] - return serialization_method_function(table) - - def hash_table(self, table: pa.Table | pa.RecordBatch) -> ContentHash: - """ - Compute stable hash of Arrow table with semantic type processing. - - Args: - table: Arrow table to hash - prefix_hasher_id: Whether to prefix hash with hasher ID - - Returns: - Hex string of the computed hash - """ - - # Step 1: Process columns with semantic types using visitor pattern - processed_table = self._process_table_columns(table) - - # Step 2: Sort columns by name for deterministic ordering - sorted_table = self._sort_table_columns(processed_table) - - # normalize all string to large strings (for compatibility with Polars) - normalized_table = arrow_utils.normalize_table_to_large_types(sorted_table) - - # Step 3: Serialize using configured serialization method - serialized_bytes = self.serialize_arrow_table(normalized_table) - - # Step 4: Compute final hash - hasher = hashlib.new(self.hash_algorithm) - hasher.update(serialized_bytes) - - return ContentHash(method=self.hasher_id, digest=hasher.digest()) - - def hash_table_with_metadata(self, table: pa.Table) -> dict[str, Any]: # noqa: C901 - """ - Compute hash with additional metadata about the process. - - Returns: - Dictionary containing hash, metadata, and processing info - """ - # Process table to see what transformations were made - processed_table = self._process_table_columns(table) - - # Track processing steps - processed_columns = [] - for i, (original_field, processed_field) in enumerate( - zip(table.schema, processed_table.schema) - ): - column_info = { - "name": original_field.name, - "original_type": str(original_field.type), - "processed_type": str(processed_field.type), - "was_processed": str(original_field.type) != str(processed_field.type), - } - processed_columns.append(column_info) - - # Compute hash - table_hash = self.hash_table(table) - - return { - "hash": table_hash, - "hasher_id": self.hasher_id, - "serialization_method": self.serialization_method, - "hash_algorithm": self.hash_algorithm, - "num_rows": len(table), - "num_columns": len(table.schema), - "processed_columns": processed_columns, - "column_order": [field.name for field in table.schema], - } +if TYPE_CHECKING: + from orcapod.semantic_types.universal_converter import UniversalTypeConverter + from orcapod.protocols.hashing_protocols import SemanticHasherProtocol class StarfixArrowHasher: - """ - Arrow table hasher backed by the starfix-python ``ArrowDigester``. - - This hasher produces cross-language-compatible, deterministic content - addresses for Arrow tables and schemas by delegating to the canonical - StarFix specification (``starfix-python``). + """Arrow table hasher backed by the starfix-python ``ArrowDigester``. Pipeline -------- 1. **Semantic pre-processing** — the ``SemanticHashingVisitor`` traverses - every column and replaces recognised semantic types (e.g. ``Path`` - structs) with their content-addressed hash strings. This step runs - before the Arrow bytes are ever touched by starfix, so the final hash - captures *file content* for path-typed columns rather than the raw - path string. - 2. **Starfix hashing** — ``ArrowDigester.hash_table`` (or - ``ArrowDigester.hash_schema``) is called on the pre-processed table / - schema. The digester is column-order-independent and normalises - ``Utf8`` → ``LargeUtf8``, ``Binary`` → ``LargeBinary``, etc., - producing a 35-byte versioned SHA-256 digest that is byte-for-byte - identical to the Rust ``starfix`` crate output. + every column. Extension-typed columns whose Python type has a registered + semantic hasher are replaced with ``pa.large_binary()`` hash tokens + (e.g. ``Path`` columns are replaced by their file-content hash). + Extension-typed columns without a registered hasher pass through with + their full extension metadata intact. + 2. **Starfix hashing** — ``ArrowDigester.hash_table`` produces a 35-byte + versioned SHA-256 digest that is byte-for-byte identical to the Rust + ``starfix`` crate output. Parameters ---------- - semantic_registry: - Registry of semantic type converters used during pre-processing. + type_converter: + ``UniversalTypeConverter`` used to resolve extension types to Python + types and convert storage values back to Python objects. + semantic_hasher: + ``SemanticHasherProtocol`` used to hash Python objects extracted + from extension-typed columns. hasher_id: - String identifier embedded in every ``ContentHash`` produced by - this hasher. Bump this value whenever the hash algorithm changes - so that stored hashes remain distinguishable. + String identifier embedded in every ``ContentHash`` produced by this + hasher. """ def __init__( self, - semantic_registry: SemanticTypeRegistry, + type_converter: "UniversalTypeConverter", + semantic_hasher: "SemanticHasherProtocol", hasher_id: str, ) -> None: + self._type_converter = type_converter + self._semantic_hasher = semantic_hasher self._hasher_id = hasher_id - self.semantic_registry = semantic_registry @property def hasher_id(self) -> str: return self._hasher_id - def _process_table_columns(self, table: pa.Table | pa.RecordBatch) -> pa.Table: - """Replace semantic-typed columns with their content-hash strings.""" + def _process_table_columns(self, table: "pa.Table | pa.RecordBatch") -> "pa.Table": + """Replace semantic-typed columns with their content-hash bytes.""" new_columns: list[pa.Array] = [] new_fields: list[pa.Field] = [] for i, field in enumerate(table.schema): - # Short-circuit: primitive columns cannot contain semantic types, so skip - # the costly Python round-trip and reuse the original Arrow array directly. + # Short-circuit: columns that cannot contain semantic types skip + # the costly Python round-trip. Extension types must pass through + # so visit_extension can process them. if not ( - pa.types.is_struct(field.type) + isinstance(field.type, pa.ExtensionType) + or pa.types.is_struct(field.type) or pa.types.is_list(field.type) or pa.types.is_large_list(field.type) or pa.types.is_fixed_size_list(field.type) @@ -302,28 +78,20 @@ def _process_table_columns(self, table: pa.Table | pa.RecordBatch) -> pa.Table: continue column_data = table.column(i).to_pylist() - visitor = SemanticHashingVisitor(self.semantic_registry) + visitor = SemanticHashingVisitor(self._type_converter, self._semantic_hasher) try: new_type: pa.DataType | None = None processed_data: list[Any] = [] for value in column_data: processed_type, processed_value = visitor.visit(field.type, value) - # Infer the output type from the first non-null processed value. - # When the first row is null, visit_struct returns the original - # struct type rather than the converted type (e.g. large_string), - # which would cause pa.array() to fail for subsequent non-null rows. if new_type is None and processed_value is not None: new_type = processed_type processed_data.append(processed_value) - # For empty or all-null columns there are no non-null values to infer - # the type from; fall back to the field's declared type. if new_type is None: new_type = field.type new_columns.append(pa.array(processed_data, type=new_type)) - # Preserve original field attributes (nullable, metadata) while - # updating only the type, so the schema fed to starfix remains faithful. new_fields.append(field.with_type(new_type)) except Exception as exc: @@ -331,61 +99,21 @@ def _process_table_columns(self, table: pa.Table | pa.RecordBatch) -> pa.Table: f"Failed to process column '{field.name}': {exc}" ) from exc - # Preserve the original schema-level metadata while using updated fields. - return pa.table(new_columns, schema=pa.schema(new_fields, metadata=table.schema.metadata)) - - def hash_schema(self, schema: pa.Schema) -> ContentHash: - """Hash an Arrow schema using the starfix canonical algorithm. - - ``has_extension_metadata`` is checked first on the raw schema. When - no extension metadata is found, ``include_metadata=False`` is passed - to ``ArrowDigester`` directly without rebuilding the schema (starfix - ignores metadata when ``include_metadata=False``, so the hash is - identical). When extension metadata is present, ``clean_schema_for_hashing`` - strips non-``ARROW:extension:*`` keys before hashing with - ``include_metadata=True``, preserving byte-for-byte hash stability - with pre-v0.3.0 output for extension-free schemas. + return pa.table( + new_columns, + schema=pa.schema(new_fields, metadata=table.schema.metadata), + ) - Parameters - ---------- - schema: - The ``pa.Schema`` to hash. - - Returns - ------- - ContentHash - A ``ContentHash`` whose ``digest`` is the 35-byte versioned - SHA-256 produced by ``ArrowDigester.hash_schema``. - """ + def hash_schema(self, schema: "pa.Schema") -> ContentHash: + """Hash an Arrow schema using the starfix canonical algorithm.""" include_meta = has_extension_metadata(schema) if include_meta: schema = clean_schema_for_hashing(schema) digest = ArrowDigester.hash_schema(schema, include_metadata=include_meta) return ContentHash(method=self._hasher_id, digest=digest) - def hash_table(self, table: pa.Table | pa.RecordBatch) -> ContentHash: - """Hash an Arrow table (or ``RecordBatch``) using starfix. - - Semantic types are resolved to their content-hash strings first. - ``has_extension_metadata`` is then checked on the processed table's - schema. When no extension metadata is found, the processed table is - passed to ``ArrowDigester.hash_table`` directly with - ``include_metadata=False``, avoiding a schema rebuild and new table - allocation. When extension metadata is present, - ``clean_schema_for_hashing`` strips non-``ARROW:extension:*`` keys - before hashing with ``include_metadata=True``. - - Parameters - ---------- - table: - The ``pa.Table`` or ``pa.RecordBatch`` to hash. - - Returns - ------- - ContentHash - A ``ContentHash`` whose ``digest`` is the 35-byte versioned - SHA-256 produced by ``ArrowDigester.hash_table``. - """ + def hash_table(self, table: "pa.Table | pa.RecordBatch") -> ContentHash: + """Hash an Arrow table (or ``RecordBatch``) using starfix.""" if isinstance(table, pa.RecordBatch): table = pa.Table.from_batches([table]) @@ -393,8 +121,6 @@ def hash_table(self, table: pa.Table | pa.RecordBatch) -> ContentHash: include_meta = has_extension_metadata(processed_table.schema) if include_meta: clean_schema = clean_schema_for_hashing(processed_table.schema) - # clean_schema_for_hashing only strips metadata; physical types and - # column order are unchanged, so from_arrays is safe without a cast. clean_table = pa.Table.from_arrays( processed_table.columns, schema=clean_schema ) diff --git a/src/orcapod/hashing/defaults.py b/src/orcapod/hashing/defaults.py index 5dd68ea7..26a6ac44 100644 --- a/src/orcapod/hashing/defaults.py +++ b/src/orcapod/hashing/defaults.py @@ -10,20 +10,24 @@ # from its JSON spec. Constructing them here would bypass versioning and # produce hashers that are decoupled from the active data context. -from orcapod.hashing.semantic_hashing.type_handler_registry import TypeHandlerRegistry +from orcapod.hashing.semantic_hashing.type_handler_registry import PythonTypeHandlerRegistry from orcapod.protocols import hashing_protocols as hp -def get_default_type_handler_registry() -> TypeHandlerRegistry: +def get_default_python_type_handler_registry() -> PythonTypeHandlerRegistry: """ - Return the TypeHandlerRegistry from the default data context. + Return the ``PythonTypeHandlerRegistry`` from the default data context's + semantic hasher. + + The registry is owned by the active ``SemanticHasherProtocol``, which is itself + versioned inside the active ``DataContext``. Returns: - TypeHandlerRegistry: The type handler registry from the default data context. + PythonTypeHandlerRegistry: The type handler registry from the + default data context. """ from orcapod.contexts import get_default_context - - return get_default_context().type_handler_registry + return get_default_context().semantic_hasher.type_handler_registry def get_default_semantic_hasher() -> hp.SemanticHasherProtocol: @@ -45,46 +49,15 @@ def get_default_semantic_hasher() -> hp.SemanticHasherProtocol: return get_default_context().semantic_hasher -def get_default_arrow_hasher( - cache_file_hash: bool | hp.StringCacherProtocol = True, -) -> hp.ArrowHasherProtocol: - """ - Return the ArrowHasherProtocol from the default data context. - - If ``cache_file_hash`` is True an in-memory StringCacherProtocol is attached to - the hasher so that repeated hashes of the same file path are served from - cache. Pass a ``StringCacherProtocol`` instance to use a custom caching backend - (e.g. SQLite-backed). - - Note: caching is applied on top of the context's arrow hasher each time - this function is called. If you need a single shared cached instance, - obtain it once and store it yourself. +def get_default_arrow_hasher() -> hp.ArrowHasherProtocol: + """Return the ArrowHasherProtocol from the default data context. - Args: - cache_file_hash: True to use an ephemeral in-memory cache, a - StringCacherProtocol instance to use a custom cache, or False/None to - disable caching. + Note: file-hash caching (formerly via ``set_cacher``) has been removed. + ``StarfixArrowHasher`` does not support per-path caching. Use + ``CachedFileHasher`` when constructing a custom context if caching is needed. Returns: - ArrowHasherProtocol: The arrow hasher from the default data context, - optionally with file-hash caching attached. + ArrowHasherProtocol: The arrow hasher from the default data context. """ - from typing import Any - from orcapod.contexts import get_default_context - - arrow_hasher: Any = get_default_context().arrow_hasher - - if cache_file_hash: - from orcapod.hashing.string_cachers import InMemoryCacher - - if cache_file_hash is True: - string_cacher: hp.StringCacherProtocol = InMemoryCacher(max_size=None) - else: - string_cacher = cache_file_hash - - # set_cacher is present on SemanticArrowHasher but not on the - # ArrowHasherProtocol protocol, so we call it via Any to avoid a type error. - arrow_hasher.set_cacher("path", string_cacher) - - return arrow_hasher + return get_default_context().arrow_hasher diff --git a/src/orcapod/hashing/semantic_hashing/__init__.py b/src/orcapod/hashing/semantic_hashing/__init__.py index bc120c18..c8d139b3 100644 --- a/src/orcapod/hashing/semantic_hashing/__init__.py +++ b/src/orcapod/hashing/semantic_hashing/__init__.py @@ -1,20 +1,18 @@ """ orcapod.hashing.semantic_hashing ================================= -Sub-package containing all components of the semantic hashing system: + SemanticAwarePythonHasher -- content-based recursive object hasher + PythonTypeHandlerRegistry -- MRO-aware registry mapping types → PythonTypeHandlerProtocol + BuiltinPythonTypeHandlerRegistry -- pre-populated registry with built-in hashers + ContentIdentifiableMixin -- convenience mixin for content-identifiable objects - BaseSemanticHasher -- content-based recursive object hasher - TypeHandlerRegistry -- MRO-aware registry mapping types → TypeHandlerProtocol - BuiltinTypeHandlerRegistry -- pre-populated registry with built-in handlers - ContentIdentifiableMixin -- convenience mixin for content-identifiable objects - -Built-in TypeHandlerProtocol implementations: - PathContentHandler -- pathlib.Path → file-content hash - UUIDHandler -- uuid.UUID → canonical string - BytesHandler -- bytes/bytearray → hex string - FunctionHandler -- callable → via FunctionInfoExtractorProtocol - TypeObjectHandler -- type objects → "type:." - register_builtin_handlers -- populate a registry with all of the above +Built-in PythonTypeHandlerProtocol implementations: + PathHandler -- pathlib.Path → file-content hash + UUIDHandler -- uuid.UUID → canonical bytes + BytesHandler -- bytes/bytearray → hex string + FunctionHandler -- callable → via FunctionInfoExtractorProtocol + TypeObjectHandler -- type objects → "type:." + register_builtin_python_type_handlers -- populate a registry with all of the above Function info extractors (used by FunctionHandler): FunctionNameExtractor @@ -25,10 +23,10 @@ from orcapod.hashing.semantic_hashing.builtin_handlers import ( BytesHandler, FunctionHandler, - PathContentHandler, + PathHandler, TypeObjectHandler, UUIDHandler, - register_builtin_handlers, + register_builtin_python_type_handlers, ) from orcapod.hashing.semantic_hashing.content_identifiable_mixin import ( ContentIdentifiableMixin, @@ -38,28 +36,23 @@ FunctionNameExtractor, FunctionSignatureExtractor, ) -from orcapod.hashing.semantic_hashing.semantic_hasher import BaseSemanticHasher +from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher from orcapod.hashing.semantic_hashing.type_handler_registry import ( - BuiltinTypeHandlerRegistry, - TypeHandlerRegistry, + BuiltinPythonTypeHandlerRegistry, + PythonTypeHandlerRegistry, ) __all__ = [ - # Core hasher - "BaseSemanticHasher", - # Registry - "TypeHandlerRegistry", - "BuiltinTypeHandlerRegistry", - # Mixin + "SemanticAwarePythonHasher", + "PythonTypeHandlerRegistry", + "BuiltinPythonTypeHandlerRegistry", "ContentIdentifiableMixin", - # Built-in handlers - "PathContentHandler", + "PathHandler", "UUIDHandler", "BytesHandler", "FunctionHandler", "TypeObjectHandler", - "register_builtin_handlers", - # Function info extractors + "register_builtin_python_type_handlers", "FunctionNameExtractor", "FunctionSignatureExtractor", "FunctionInfoExtractorFactory", diff --git a/src/orcapod/hashing/semantic_hashing/builtin_handlers.py b/src/orcapod/hashing/semantic_hashing/builtin_handlers.py index 1b66d039..68bbb3ec 100644 --- a/src/orcapod/hashing/semantic_hashing/builtin_handlers.py +++ b/src/orcapod/hashing/semantic_hashing/builtin_handlers.py @@ -1,30 +1,20 @@ """ -Built-in TypeHandlerProtocol implementations for the SemanticHasherProtocol system. - -This module provides handlers for all Python types that the SemanticHasherProtocol -knows how to process out of the box: - - - PathContentHandler -- pathlib.Path: returns ContentHash of file content - - UPathContentHandler -- upath.UPath: returns ContentHash of file content (remote-aware) - - UUIDHandler -- uuid.UUID: raw 16-byte binary representation - - BytesHandler -- bytes / bytearray: hex string representation - - FunctionHandler -- callable with __code__: via FunctionInfoExtractorProtocol - - TypeObjectHandler -- type objects (classes): stable "type:" string - -Note: ContentHash requires no handler -- it is recognised as a terminal by -``hash_object`` and returned as-is. - -The module also exposes ``register_builtin_handlers(registry)`` which is -called automatically when the global default registry is first accessed. - -Extending the system --------------------- -To add a handler for a third-party type, create a class that implements the -TypeHandlerProtocol protocol (a single ``handle(obj, hasher)`` method) and register -it: - - from orcapod.hashing.semantic_hashing.type_handler_registry import get_default_type_handler_registry - get_default_type_handler_registry().register(MyType, MyTypeHandler()) +Built-in PythonTypeHandlerProtocol implementations. + + PathHandler -- pathlib.Path: file content hash + UPathHandler -- upath.UPath: file content hash (remote-aware) + UUIDHandler -- uuid.UUID: 16-byte binary representation + BytesHandler -- bytes/bytearray: hex string representation + FunctionHandler -- callable with __code__: via FunctionInfoExtractorProtocol + TypeObjectHandler -- type objects: stable "type:." string + SpecialFormHandler -- typing._SpecialForm + GenericAliasHandler -- generic alias type annotations + UnionTypeHandler -- types.UnionType (Python 3.10+ X | Y syntax) + ArrowTableHandler -- pa.Table / pa.RecordBatch + SchemaHandler -- Schema objects + +``register_builtin_python_type_handlers(registry)`` populates a registry +with all of the above. """ from __future__ import annotations @@ -36,150 +26,96 @@ from upath import UPath -from orcapod.types import PathLike, Schema +from orcapod.types import ContentHash, PathLike, Schema if TYPE_CHECKING: - from orcapod.hashing.semantic_hashing.type_handler_registry import ( - TypeHandlerRegistry, - ) from orcapod.protocols.hashing_protocols import ( ArrowHasherProtocol, FileContentHasherProtocol, + HandlerRegistryProtocol, SemanticHasherProtocol, ) logger = logging.getLogger(__name__) -# --------------------------------------------------------------------------- -# Individual handlers -# --------------------------------------------------------------------------- - - -class PathContentHandler: - """ - Handler for pathlib.Path objects. - - Hashes the *content* of the file at the given path using the injected - FileContentHasherProtocol, producing a stable content-addressed identifier. - The resulting bytes are stored as a hex string embedded in the resolved - structure. - - The path must refer to an existing, readable file. Directories and - missing paths are not supported and will raise an error -- if you need - a path-as-string handler, register a separate handler for that use case - or return a ``str`` from ``identity_structure()`` instead of a ``Path``. +class PathHandler: + """Hasher for pathlib.Path objects — hashes file *content*. Args: - file_hasher: Any object with a ``hash_file(path) -> ContentHash`` - method (satisfies the FileContentHasherProtocol protocol). + file_hasher: Any object with a ``hash_file(path) -> ContentHash`` method. """ - def __init__(self, file_hasher: FileContentHasherProtocol) -> None: + def __init__(self, file_hasher: "FileContentHasherProtocol") -> None: self.file_hasher = file_hasher - def handle(self, obj: PathLike, hasher: "SemanticHasherProtocol") -> Any: + def handle(self, obj: PathLike, hasher: "SemanticHasherProtocol") -> ContentHash: path: Path = Path(obj) - if not path.exists(): raise FileNotFoundError( - f"PathContentHandler: path does not exist: {path!r}. " - "Paths must refer to existing files for content-based hashing. " - "If you intended to hash the path string, return str(path) from " - "identity_structure() instead of a Path object." + f"PathHandler: path does not exist: {path!r}. " + "Paths must refer to existing files for content-based hashing." ) - if path.is_dir(): raise IsADirectoryError( - f"PathContentHandler: path is a directory: {path!r}. " + f"PathHandler: path is a directory: {path!r}. " "Only regular files are supported for content-based hashing." ) - - logger.debug("PathContentHandler: hashing file content at %s", path) + logger.debug("PathHandler: hashing file content at %s", path) return self.file_hasher.hash_file(path) -class UPathContentHandler: - """ - Handler for universal_pathlib.UPath objects. - - Behaves identically to ``PathContentHandler`` but preserves the UPath - instance so that remote filesystem semantics (e.g. S3, GCS) are retained - during file content hashing. +class UPathHandler: + """Hasher for universal_pathlib.UPath objects — hashes file content. Args: - file_hasher: Any object with a ``hash_file(path) -> ContentHash`` - method (satisfies the FileContentHasherProtocol protocol). + file_hasher: Any object with a ``hash_file(path) -> ContentHash`` method. """ - def __init__(self, file_hasher: FileContentHasherProtocol) -> None: + def __init__(self, file_hasher: "FileContentHasherProtocol") -> None: self.file_hasher = file_hasher - def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: + def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> ContentHash: if not isinstance(obj, UPath): raise TypeError( - f"UPathContentHandler: expected a UPath, got {type(obj)!r}. " - "Use PathContentHandler for pathlib.Path objects." + f"UPathHandler: expected a UPath, got {type(obj)!r}." ) - if not obj.exists(): raise FileNotFoundError( - f"UPathContentHandler: path does not exist: {obj!r}. " - "Paths must refer to existing files for content-based hashing." + f"UPathHandler: path does not exist: {obj!r}." ) - if obj.is_dir(): raise IsADirectoryError( - f"UPathContentHandler: path is a directory: {obj!r}. " - "Only regular files are supported for content-based hashing." + f"UPathHandler: path is a directory: {obj!r}." ) - - logger.debug("UPathContentHandler: hashing file content at %s", obj) + logger.debug("UPathHandler: hashing file content at %s", obj) return self.file_hasher.hash_file(obj) class UUIDHandler: - """Handler for ``uuid.UUID`` objects. - - Returns the raw 16-byte binary representation of the UUID. - The binary form is compact, unambiguous, and independent of string - formatting conventions. UUID values in data columns are stored as - ``pa.binary(16)`` (fixed-size) within the struct type used by - ``UUIDStructConverter``; database record IDs use ``pa.large_binary()``. - """ + """Hasher for ``uuid.UUID`` objects — returns the raw 16-byte binary representation.""" def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: return obj.bytes class BytesHandler: - """ - Handler for bytes and bytearray objects. - - Converts binary data to its lowercase hex string representation. This - avoids JSON serialisation issues with raw bytes while preserving the - exact byte sequence in the hash input. - """ + """Hasher for bytes and bytearray objects — returns the lowercase hex string.""" def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: if isinstance(obj, (bytes, bytearray)): return obj.hex() - raise TypeError(f"BytesHandler: expected bytes or bytearray, got {type(obj)!r}") + raise TypeError( + f"BytesHandler: expected bytes or bytearray, got {type(obj)!r}" + ) class FunctionHandler: - """ - Handler for Python functions / callables that carry a ``__code__`` attribute. - - Delegates to a FunctionInfoExtractorProtocol to produce a stable, serialisable - dict representation of the function. The extractor is responsible for - deciding which parts of the function (name, signature, source body, etc.) - are included. + """Hasher for Python functions/callables with a ``__code__`` attribute. Args: function_info_extractor: Any object with an - ``extract_function_info(func) -> dict`` method (satisfies the - FunctionInfoExtractorProtocol protocol). + ``extract_function_info(func) -> dict`` method. """ def __init__(self, function_info_extractor: Any) -> None: @@ -197,12 +133,9 @@ def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: class TypeObjectHandler: - """ - Handler for type objects (i.e. classes passed as values). + """Hasher for type objects (classes passed as values). - Returns a stable string of the form ``"type:."`` so - that different classes always produce different hash inputs and the - result is human-readable. + Returns a stable string of the form ``"type:."``. """ def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: @@ -216,16 +149,7 @@ def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: class SpecialFormHandler: - """ - Handler for ``typing._SpecialForm`` objects such as ``typing.Union`` and - ``typing.ClassVar``. - - These appear as the ``__origin__`` of typing generics — for example, - ``Optional[int]`` is ``Union[int, None]``, whose ``__origin__`` is - ``typing.Union``. Returns a stable string of the form - ``"special_form:typing."`` so they can be safely embedded as the - origin component inside a ``GenericAliasHandler`` result. - """ + """Hasher for ``typing._SpecialForm`` objects such as ``typing.Union``.""" def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: name = getattr(obj, "_name", None) or repr(obj) @@ -233,19 +157,7 @@ def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: class GenericAliasHandler: - """ - Handler for generic alias type annotations such as ``dict[int, list[int]]`` - (``types.GenericAlias``) and ``typing`` generics (``typing._GenericAlias``). - - Produces a stable dict containing the origin type and a list of hashed - argument types so that structurally identical generic annotations always - yield the same hash, and structurally different ones yield different hashes. - - When the origin is ``typing.Union`` (i.e. ``typing.Optional[X]`` or - ``typing.Union[X, Y]``), the handler produces a canonical ``"union"`` - form with sorted args — identical to `UnionTypeHandler` — so that - ``typing.Optional[int]`` and ``int | None`` hash equivalently. - """ + """Hasher for generic alias type annotations (``dict[int, str]``, ``Optional[X]``, etc.).""" def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: import typing @@ -254,16 +166,9 @@ def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: args = getattr(obj, "__args__", None) or () if origin is None: return f"generic_alias:{obj!r}" - - # Normalize typing.Union / typing.Optional to the canonical union - # form so that typing.Optional[int] ≡ typing.Union[int, None] ≡ int | None. if origin is typing.Union: hashed_args = sorted(hasher.hash_object(arg).to_string() for arg in args) - return { - "__type__": "union", - "args": hashed_args, - } - + return {"__type__": "union", "args": hashed_args} return { "__type__": "generic_alias", "origin": hasher.hash_object(origin).to_string(), @@ -272,45 +177,34 @@ def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: class UnionTypeHandler: - """ - Handler for ``types.UnionType`` objects (Python 3.10+ ``X | Y`` syntax). - - ``str | None``, ``int | float``, etc. produce a ``types.UnionType`` at - runtime, which is distinct from ``typing.Union[str, None]`` - (a ``typing._GenericAlias``). This handler normalises union types into - a canonical ``"union"`` form with sorted args — identical to the union - branch in `GenericAliasHandler` — so that ``int | None``, - ``typing.Optional[int]``, and ``typing.Union[int, None]`` all hash - equivalently. - """ + """Hasher for ``types.UnionType`` objects (Python 3.10+ ``X | Y`` syntax).""" def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: args = getattr(obj, "__args__", None) or () hashed_args = sorted(hasher.hash_object(arg).to_string() for arg in args) - return { - "__type__": "union", - "args": hashed_args, - } + return {"__type__": "union", "args": hashed_args} class ArrowTableHandler: - """ - Handler for ``pa.Table`` and ``pa.RecordBatch`` objects. - - Delegates to the injected ``ArrowHasherProtocol`` to produce a stable, - content-addressed ``ContentHash`` of the Arrow table data. The returned - ``ContentHash`` is recognised as a terminal by ``hash_object`` and - returned as-is — no further recursion occurs. + """Hasher for ``pa.Table`` and ``pa.RecordBatch`` objects. Args: - arrow_hasher: Any object satisfying ArrowHasherProtocol (i.e. has a - ``hash_table(table) -> ContentHash`` method). + arrow_hasher: Any object satisfying ``ArrowHasherProtocol``. When + ``None``, the default data context's ``arrow_hasher`` is resolved + lazily at call time (breaking the circular dependency that would + arise if the registry were constructed before the arrow hasher). """ - def __init__(self, arrow_hasher: ArrowHasherProtocol) -> None: - self.arrow_hasher = arrow_hasher + def __init__(self, arrow_hasher: "ArrowHasherProtocol | None" = None) -> None: + self._arrow_hasher = arrow_hasher - def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: + def _get_arrow_hasher(self) -> "ArrowHasherProtocol": + if self._arrow_hasher is not None: + return self._arrow_hasher + from orcapod.contexts import get_default_context + return get_default_context().arrow_hasher # type: ignore[return-value] + + def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> ContentHash: import pyarrow as _pa if isinstance(obj, _pa.RecordBatch): @@ -319,159 +213,91 @@ def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: raise TypeError( f"ArrowTableHandler: expected pa.Table or pa.RecordBatch, got {type(obj)!r}" ) - return self.arrow_hasher.hash_table(obj) + return self._get_arrow_hasher().hash_table(obj) class SchemaHandler: - """ - Handler for `Schema` objects. - - Produces a stable dict containing both the field-type mapping and the - sorted list of optional field names, so that two schemas differing only - in which fields are optional produce different hashes. - """ + """Hasher for ``Schema`` objects.""" def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: if not isinstance(obj, Schema): - raise TypeError(f"SchemaHandler: expected a Schema, got {type(obj)!r}") - # schema handler is not implemented yet - raise NotImplementedError() - # visited: frozenset[int] = frozenset() - - # return { - # "fields": {k: hasher._expand_element(v, visited) for k, v in obj.items()}, - # "optional_fields": sorted(obj.optional_fields), - # } - - -# --------------------------------------------------------------------------- -# Registration helper -# --------------------------------------------------------------------------- + raise TypeError( + f"SchemaHandler: expected a Schema, got {type(obj)!r}" + ) + raise NotImplementedError("SchemaHandler is not yet implemented.") -def register_builtin_handlers( - registry: "TypeHandlerRegistry", +def register_builtin_python_type_handlers( + registry: "HandlerRegistryProtocol", file_hasher: Any = None, function_info_extractor: Any = None, arrow_hasher: "ArrowHasherProtocol | None" = None, ) -> None: - """ - Register all built-in TypeHandlers into *registry*. - - This function is called automatically when the global default registry is - first accessed via ``get_default_type_handler_registry()``. It can also - be called manually to populate a custom registry. + """Register all built-in semantic hashers into *registry*. - Path, function, and Arrow table handling require auxiliary objects. - When these are not supplied, sensible defaults are constructed: - - - ``BasicFileHasher`` (SHA-256, 64 KiB buffer) for Path handling. - - ``FunctionSignatureExtractor`` for function handling. - - ``SemanticArrowHasher`` (SHA-256, logical serialisation) for Arrow table handling. + ``pa.Table`` and ``pa.RecordBatch`` are always registered via + ``ArrowTableHandler``. When ``arrow_hasher`` is provided it is + passed through for immediate use; when ``None``, ``ArrowTableHandler`` + resolves the active arrow hasher lazily via ``get_default_context()`` at + hash time, breaking the construction-time circular dependency. Args: - registry: - The TypeHandlerRegistry to populate. - file_hasher: - Optional object satisfying FileContentHasherProtocol (i.e. has a - ``hash_file(path) -> ContentHash`` method). Defaults to a - ``BasicFileHasher`` configured with SHA-256. - function_info_extractor: - Optional object satisfying FunctionInfoExtractorProtocol (i.e. has an - ``extract_function_info(func) -> dict`` method). Defaults to - ``FunctionSignatureExtractor``. - arrow_hasher: - Optional object satisfying ArrowHasherProtocol (i.e. has a - ``hash_table(table) -> ContentHash`` method). Defaults to a - ``SemanticArrowHasher`` configured with SHA-256 and logical serialisation. - Should be the data context's arrow hasher when called from a versioned - context so that hashing is consistent across all components. + registry: The ``HandlerRegistryProtocol`` instance to populate. + file_hasher: Optional ``FileContentHasherProtocol`` for path hashing. + Defaults to ``BasicFileHasher(sha256)``. + function_info_extractor: Optional ``FunctionInfoExtractorProtocol``. + Defaults to ``FunctionSignatureExtractor``. + arrow_hasher: Optional ``ArrowHasherProtocol`` for nested table hashing. + When ``None``, lazy resolution via the default context is used. """ - # Resolve defaults for auxiliary objects ---------------------------- if file_hasher is None: - from orcapod.hashing.file_hashers import BasicFileHasher # stays in hashing/ - + from orcapod.hashing.file_hashers import BasicFileHasher file_hasher = BasicFileHasher(algorithm="sha256") if function_info_extractor is None: from orcapod.hashing.semantic_hashing.function_info_extractors import ( FunctionSignatureExtractor, ) - function_info_extractor = FunctionSignatureExtractor( include_module=True, include_defaults=True, ) - if arrow_hasher is None: - from orcapod.hashing.arrow_hashers import SemanticArrowHasher - from orcapod.semantic_types.semantic_registry import SemanticTypeRegistry + bytes_hasher = BytesHandler() + registry.register(bytes, bytes_hasher) + registry.register(bytearray, bytes_hasher) - arrow_hasher = SemanticArrowHasher( - semantic_registry=SemanticTypeRegistry(), - hasher_id="arrow_v0.1", - hash_algorithm="sha256", - serialization_method="logical", - ) - - # Register handlers ------------------------------------------------- - - # bytes / bytearray - bytes_handler = BytesHandler() - registry.register(bytes, bytes_handler) - registry.register(bytearray, bytes_handler) - - # pathlib.Path (and subclasses such as PosixPath / WindowsPath) - registry.register(Path, PathContentHandler(file_hasher)) - - # uuid.UUID + registry.register(Path, PathHandler(file_hasher)) + registry.register(UPath, UPathHandler(file_hasher)) registry.register(UUID, UUIDHandler()) - # Note: ContentHash needs no handler -- SemanticHasherProtocol treats it as - # a terminal in hash_object() and returns it as-is. - - # Functions -- register types.FunctionType so MRO lookup works for - # plain ``def`` functions, plus built-in functions and bound methods. import types as _types - function_handler = FunctionHandler(function_info_extractor) - registry.register(_types.FunctionType, function_handler) - registry.register(_types.BuiltinFunctionType, function_handler) - registry.register(_types.MethodType, function_handler) + function_hasher = FunctionHandler(function_info_extractor) + registry.register(_types.FunctionType, function_hasher) + registry.register(_types.BuiltinFunctionType, function_hasher) + registry.register(_types.MethodType, function_hasher) - # type objects (classes used as values, e.g. passed in a dict) registry.register(type, TypeObjectHandler()) - - # types.UnionType (Python 3.10+ X | Y syntax, e.g. str | None) registry.register(_types.UnionType, UnionTypeHandler()) - # generic alias type annotations: dict[int, str], list[str], etc. - generic_alias_handler = GenericAliasHandler() - registry.register(_types.GenericAlias, generic_alias_handler) - # typing._GenericAlias covers Optional[X], Union[X, Y], Dict[K, V], etc. - # typing._SpecialForm covers typing.Union, typing.ClassVar, etc. which - # appear as __origin__ on those generics (e.g. Optional[int].__origin__ - # is typing.Union, a _SpecialForm). + generic_alias_hasher = GenericAliasHandler() + registry.register(_types.GenericAlias, generic_alias_hasher) try: import typing as _typing - - registry.register(_typing._GenericAlias, generic_alias_handler) # type: ignore[attr-defined] + registry.register(_typing._GenericAlias, generic_alias_hasher) # type: ignore[attr-defined] registry.register(_typing._SpecialForm, SpecialFormHandler()) # type: ignore[attr-defined] except AttributeError: pass - # Schema objects -- must come after type handler so Schema is matched - # specifically rather than falling through to the Mapping expansion path registry.register(Schema, SchemaHandler()) - # Arrow tables and record batches -- delegate to the injected arrow hasher import pyarrow as _pa - - arrow_table_handler = ArrowTableHandler(arrow_hasher) - registry.register(_pa.Table, arrow_table_handler) - registry.register(_pa.RecordBatch, arrow_table_handler) + arrow_table_hasher = ArrowTableHandler(arrow_hasher) + registry.register(_pa.Table, arrow_table_hasher) + registry.register(_pa.RecordBatch, arrow_table_hasher) logger.debug( - "register_builtin_handlers: registered %d built-in handlers", + "register_builtin_python_type_handlers: registered %d hashers", len(registry), ) diff --git a/src/orcapod/hashing/semantic_hashing/content_identifiable_mixin.py b/src/orcapod/hashing/semantic_hashing/content_identifiable_mixin.py index f4bd04ce..4543ff01 100644 --- a/src/orcapod/hashing/semantic_hashing/content_identifiable_mixin.py +++ b/src/orcapod/hashing/semantic_hashing/content_identifiable_mixin.py @@ -3,14 +3,14 @@ Any class that implements ``identity_structure()`` can inherit from this mixin to gain a full suite of content-based identity helpers without having to wire -up a BaseSemanticHasher manually: +up a ``SemanticHasherProtocol`` manually: - ``content_hash()`` -- returns a stable ContentHash for the object - ``__hash__()`` -- Python hash based on content (int) - ``__eq__()`` -- equality via content_hash comparison -The mixin uses the global default BaseSemanticHasher by default, but accepts an -injected hasher for testing or custom configurations. +The mixin uses the global default ``SemanticHasherProtocol`` by default, but +accepts an injected hasher for testing or custom configurations. Usage ----- @@ -32,7 +32,8 @@ def identity_structure(self): With an injected hasher (e.g. in tests):: - hasher = BaseSemanticHasher(hasher_id="test", strict=True) + from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher + hasher = SemanticAwarePythonHasher(hasher_id="test", strict=True) record = MyRecord("foo", 42) record._semantic_hasher = hasher print(record.content_hash()) @@ -65,7 +66,7 @@ def identity_structure(self): import logging from typing import Any -from orcapod.hashing.semantic_hashing.semantic_hasher import BaseSemanticHasher +from orcapod.protocols.hashing_protocols import SemanticHasherProtocol from orcapod.types import ContentHash logger = logging.getLogger(__name__) @@ -82,19 +83,19 @@ def identity_structure(self) -> Any: ... The returned structure is recursively resolved and hashed by the - BaseSemanticHasher to produce a stable ContentHash. + ``SemanticHasherProtocol`` to produce a stable ContentHash. Parameters (passed as keyword arguments to ``__init__``) --------------------------------------------------------- semantic_hasher: - Optional BaseSemanticHasher instance to use. When omitted, the hasher - is obtained from the default data context via + Optional ``SemanticHasherProtocol`` instance to use. When omitted, + the hasher is obtained from the default data context via ``orcapod.contexts.get_default_context().semantic_hasher``, which is the single source of truth for versioned component configuration. """ def __init__( - self, *, semantic_hasher: BaseSemanticHasher | None = None, **kwargs: Any + self, *, semantic_hasher: SemanticHasherProtocol | None = None, **kwargs: Any ) -> None: # Cooperative MRO-friendly init -- forward remaining kwargs up the chain. super().__init__(**kwargs) @@ -215,9 +216,8 @@ def _invalidate_content_hash_cache(self) -> None: # Hasher resolution # ------------------------------------------------------------------ - def _get_hasher(self) -> BaseSemanticHasher: - """ - Return the BaseSemanticHasher to use for this object. + def _get_hasher(self) -> SemanticHasherProtocol: + """Return the ``SemanticHasherProtocol`` to use for this object. Resolution order: 1. The instance-level ``_semantic_hasher`` attribute (set at @@ -230,7 +230,7 @@ def _get_hasher(self) -> BaseSemanticHasher: type converter, etc.) that belong to the same context. Returns: - BaseSemanticHasher: The hasher to use. + SemanticHasherProtocol: The hasher to use. """ if self._semantic_hasher is not None: return self._semantic_hasher diff --git a/src/orcapod/hashing/semantic_hashing/semantic_hasher.py b/src/orcapod/hashing/semantic_hashing/semantic_hasher.py index ceb13315..2235037c 100644 --- a/src/orcapod/hashing/semantic_hashing/semantic_hasher.py +++ b/src/orcapod/hashing/semantic_hashing/semantic_hasher.py @@ -1,5 +1,5 @@ """ -BaseSemanticHasher -- content-based recursive object hasher. +SemanticAwarePythonHasher -- content-based recursive object hasher. Algorithm --------- @@ -13,7 +13,9 @@ - Primitive → JSON-serialise + SHA-256 - Structure → delegate to ``_expand_structure``, then JSON-serialise the resulting tagged tree + SHA-256 - - Handler match → call handler.handle(obj), recurse via hash_object + - Semantic hasher match → handler.handle(obj, self) returns a representative + Python structure (or ContentHash as terminal); the result + is fed back into hash_object for final hashing - ContentIdentifiableProtocol→ call identity_structure(), recurse via hash_object - Fallback → strict error or best-effort string, then hash @@ -69,7 +71,6 @@ from collections.abc import Callable, Mapping from typing import Any -from orcapod.hashing.semantic_hashing.type_handler_registry import TypeHandlerRegistry from orcapod.protocols import hashing_protocols as hp from orcapod.types import ContentHash @@ -79,7 +80,7 @@ _MEMADDR_RE = re.compile(r" at 0x[0-9a-fA-F]+") -class BaseSemanticHasher: +class SemanticAwarePythonHasher: """ Content-based recursive hasher. @@ -89,8 +90,9 @@ class BaseSemanticHasher: A short string identifying this hasher version/configuration. Embedded in every ContentHash produced. type_handler_registry: - TypeHandlerRegistry for MRO-aware lookup of TypeHandlerProtocol instances. - If None, the default registry from the active DataContext is used. + ``HandlerRegistryProtocol`` for MRO-aware lookup of + ``PythonTypeHandlerProtocol`` instances. + If None, the default registry is used. strict: When True (default) raises TypeError for unhandled types. When False falls back to a best-effort string representation. @@ -99,16 +101,15 @@ class BaseSemanticHasher: def __init__( self, hasher_id: str, - type_handler_registry: TypeHandlerRegistry | None = None, + type_handler_registry: "hp.HandlerRegistryProtocol | None" = None, strict: bool = True, ) -> None: self._hasher_id = hasher_id self._strict = strict if type_handler_registry is None: - from orcapod.hashing.defaults import get_default_type_handler_registry - - self._registry = get_default_type_handler_registry() # stays in hashing/ + from orcapod.hashing.defaults import get_default_python_type_handler_registry + self._registry = get_default_python_type_handler_registry() else: self._registry = type_handler_registry @@ -124,6 +125,11 @@ def hasher_id(self) -> str: def strict(self) -> bool: return self._strict + @property + def type_handler_registry(self) -> "hp.HandlerRegistryProtocol": + """Return the ``HandlerRegistryProtocol`` used by this hasher.""" + return self._registry + def hash_object( self, obj: Any, @@ -138,7 +144,8 @@ def hash_object( - ContentHash → terminal; returned as-is - Primitive → JSON-serialised and hashed directly - Structure → structurally expanded then hashed - - Handler match → handler produces a value, recurse + - Semantic hasher match → handler.handle(obj, self) returns a representative Python + structure (or ContentHash); result is fed back into hash_object for final hashing - ContentIdentifiableProtocol→ resolver(obj) if resolver provided, else obj.content_hash() - Unknown type → TypeError in strict mode; best-effort otherwise @@ -169,7 +176,9 @@ def hash_object( ) return self._hash_to_content_hash(expanded) - # Handler dispatch: the handler produces a new value; recurse. + # Semantic hasher dispatch: handler returns a representative Python structure + # (or a ContentHash as terminal); feed the result back into hash_object so + # that returning a plain structure is equivalent to calling hash_object on it. handler = self._registry.get_handler(obj) if handler is not None: logger.debug( @@ -177,7 +186,8 @@ def hash_object( type(obj).__name__, type(handler).__name__, ) - return self.hash_object(handler.handle(obj, self), resolver=resolver) + result = handler.handle(obj, self) + return self.hash_object(result, resolver=resolver) # ContentIdentifiableProtocol: use resolver if provided, else content_hash(). if isinstance(obj, hp.ContentIdentifiableProtocol): @@ -354,9 +364,9 @@ def _hash_to_content_hash(self, obj: Any) -> ContentHash: ).encode("utf-8") except (TypeError, ValueError) as exc: raise TypeError( - f"BaseSemanticHasher: failed to JSON-serialise object of type " - f"{type(obj).__name__!r}. Ensure all TypeHandlers and " - "identity_structure() implementations return JSON-serialisable " + f"SemanticAwarePythonHasher: failed to JSON-serialise object of type " + f"{type(obj).__name__!r}. Ensure all PythonTypeHandlerProtocol " + "implementations and identity_structure() return JSON-serialisable " "primitives or structures." ) from exc @@ -378,14 +388,16 @@ def _handle_unknown(self, obj: Any) -> str: if self._strict: raise TypeError( - f"BaseSemanticHasher (strict): no TypeHandlerProtocol registered for type " - f"'{qualified}' and it does not implement ContentIdentifiableProtocol. " - "Register a TypeHandlerProtocol via the TypeHandlerRegistry or implement " - "identity_structure() on the class." + f"SemanticAwarePythonHasher (strict): no implementation of " + f"PythonTypeHandlerProtocol registered for type '{qualified}' and it " + "does not implement ContentIdentifiableProtocol. Register an " + "implementation of PythonTypeHandlerProtocol via the " + "HandlerRegistryProtocol or implement identity_structure() on the class." ) logger.warning( - "SemanticHasherProtocol (non-strict): no handler for type '%s'. " + "SemanticAwarePythonHasher (non-strict): no implementation of " + "PythonTypeHandlerProtocol registered for type '%s'. " "Falling back to best-effort string representation.", qualified, ) diff --git a/src/orcapod/hashing/semantic_hashing/type_handler_registry.py b/src/orcapod/hashing/semantic_hashing/type_handler_registry.py index 690ec024..6389b501 100644 --- a/src/orcapod/hashing/semantic_hashing/type_handler_registry.py +++ b/src/orcapod/hashing/semantic_hashing/type_handler_registry.py @@ -1,23 +1,8 @@ """ -Type Handler Registry for the SemanticHasherProtocol system. +PythonTypeHandlerRegistry — MRO-aware registry for PythonTypeHandlerProtocol instances. -Provides a registry through which TypeHandlerProtocol implementations can be -registered for specific Python types. Lookup is MRO-aware: if no handler -is registered for an exact type, the registry walks the MRO of the object's -class to find the nearest ancestor for which a handler has been registered. - -Usage ------ -# Register a handler for a specific type: -registry = TypeHandlerRegistry() -registry.register(Path, PathContentHandler()) - -# Or use the global default registry: -from orcapod.hashing.semantic_hashing.type_handler_registry import get_default_type_handler_registry -get_default_type_handler_registry().register(MyType, MyTypeHandler()) - -# Look up a handler (returns None if not found): -handler = registry.get_handler(some_object) +``PythonTypeHandlerProtocol`` is the protocol for type-specific handlers; this registry +provides MRO-aware lookup so subclasses inherit their parent's handler. """ from __future__ import annotations @@ -29,21 +14,18 @@ class to find the nearest ancestor for which a handler has been registered. if TYPE_CHECKING: from orcapod.protocols.hashing_protocols import ( ArrowHasherProtocol, - TypeHandlerProtocol, + PythonTypeHandlerProtocol, ) logger = logging.getLogger(__name__) -class TypeHandlerRegistry: - """ - Registry mapping Python types to TypeHandlerProtocol instances. +class PythonTypeHandlerRegistry: + """Registry mapping Python types to PythonTypeHandlerProtocol instances. - Lookup is MRO-aware: when no handler is registered for the exact type of + Lookup is MRO-aware: when no hasher is registered for the exact type of an object, the registry walks the object's MRO (most-derived first) until - it finds a match. This means a handler registered for a base class is - automatically inherited by all subclasses, unless a more specific handler - has been registered for the subclass. + it finds a match. Thread safety ------------- @@ -52,42 +34,28 @@ class TypeHandlerRegistry: """ def __init__( - self, handlers: list[tuple[type, TypeHandlerProtocol]] | None = None + self, handlers: list[tuple[type, "PythonTypeHandlerProtocol"]] | None = None ) -> None: """ Args: - handlers: Optional list of ``(target_type, handler)`` pairs to - register at construction time. Designed for use with - ``parse_objectspec``: the JSON spec provides a list of - two-element arrays where the first element uses ``_type`` - to resolve a Python type and the second uses ``_class`` to - instantiate the handler. + handlers: Optional list of ``(target_type, hasher)`` pairs to + register at construction time. """ - # Maps type -> handler; insertion order is preserved but lookup uses MRO. - self._handlers: dict[type, TypeHandlerProtocol] = {} + self._handlers: dict[type, "PythonTypeHandlerProtocol"] = {} self._lock = threading.RLock() if handlers: for target_type, handler in handlers: self.register(target_type, handler) - # ------------------------------------------------------------------ - # Registration - # ------------------------------------------------------------------ + def register(self, target_type: type, handler: "PythonTypeHandlerProtocol") -> None: + """Register a hasher for a specific Python type. - def register(self, target_type: type, handler: TypeHandlerProtocol) -> None: - """ - Register a handler for a specific Python type. - - If a handler is already registered for *target_type*, it is silently - replaced by the new handler. + If a hasher is already registered for *target_type*, it is silently + replaced by the new hasher. Args: - target_type: The Python type (or class) for which the handler - should be used. Must be a ``type`` object. - handler: A TypeHandlerProtocol instance whose ``handle()`` method will - be called when an object of ``target_type`` (or a - subclass with no more specific handler) is encountered - during structure resolution. + target_type: The Python type (or class) for which the hasher should be used. + handler: A ``PythonTypeHandlerProtocol`` instance. Raises: TypeError: If ``target_type`` is not a ``type``. @@ -100,7 +68,7 @@ def register(self, target_type: type, handler: TypeHandlerProtocol) -> None: existing = self._handlers.get(target_type) if existing is not None and existing is not handler: logger.debug( - "TypeHandlerRegistry: replacing existing handler for %s (%s -> %s)", + "PythonTypeHandlerRegistry: replacing existing hasher for %s (%s -> %s)", target_type.__name__, type(existing).__name__, type(handler).__name__, @@ -108,14 +76,13 @@ def register(self, target_type: type, handler: TypeHandlerProtocol) -> None: self._handlers[target_type] = handler def unregister(self, target_type: type) -> bool: - """ - Remove the handler registered for *target_type*, if any. + """Remove the hasher registered for *target_type*, if any. Args: - target_type: The type whose handler should be removed. + target_type: The type whose hasher should be removed. Returns: - True if a handler was removed, False if none was registered. + True if a hasher was removed, False if none was registered. """ with self._lock: if target_type in self._handlers: @@ -123,59 +90,41 @@ def unregister(self, target_type: type) -> bool: return True return False - # ------------------------------------------------------------------ - # Lookup - # ------------------------------------------------------------------ - - def get_handler(self, obj: Any) -> "TypeHandlerProtocol | None": - """ - Look up the handler for *obj* using MRO-aware resolution. - - The MRO of ``type(obj)`` is walked from most-derived to least-derived - (i.e. the object's own class first, then its bases). The first - match found in the registry is returned. + def get_handler(self, obj: Any) -> "PythonTypeHandlerProtocol | None": + """Look up the handler for *obj* using MRO-aware resolution. Args: obj: The object for which a handler is needed. Returns: - The registered TypeHandlerProtocol, or None if no handler is registered - for the object's type or any of its base classes. + The registered ``PythonTypeHandlerProtocol``, or None. """ obj_type = type(obj) with self._lock: - # Fast path: exact type match. handler = self._handlers.get(obj_type) if handler is not None: return handler - - # Slow path: walk the MRO, skipping the type itself (already - # checked above) and skipping ``object`` as a last resort -- a - # handler registered for ``object`` would match everything. for base in obj_type.__mro__[1:]: handler = self._handlers.get(base) if handler is not None: logger.debug( - "TypeHandlerRegistry: resolved handler for %s via base %s", + "PythonTypeHandlerRegistry: resolved hasher for %s via base %s", obj_type.__name__, base.__name__, ) return handler - return None - def get_handler_for_type(self, target_type: type) -> "TypeHandlerProtocol | None": - """ - Look up the handler for a *type object* (rather than an instance). - - Useful when the caller already has the type and wants to check - registration without constructing a dummy instance. + def get_handler_for_type( + self, target_type: type + ) -> "PythonTypeHandlerProtocol | None": + """Look up the handler for a *type object* (rather than an instance). Args: target_type: The type to look up. Returns: - The registered TypeHandlerProtocol, or None. + The registered ``PythonTypeHandlerProtocol``, or None. """ with self._lock: handler = self._handlers.get(target_type) @@ -188,9 +137,7 @@ def get_handler_for_type(self, target_type: type) -> "TypeHandlerProtocol | None return None def has_handler(self, target_type: type) -> bool: - """ - Return True if a handler is registered for *target_type* or any of - its MRO ancestors. + """Return True if a handler is registered for *target_type* or any MRO ancestor. Args: target_type: The type to check. @@ -198,63 +145,43 @@ def has_handler(self, target_type: type) -> bool: return self.get_handler_for_type(target_type) is not None def registered_types(self) -> list[type]: - """ - Return a list of all directly-registered types (no MRO expansion). - - Returns: - A snapshot list of types that have explicit handler registrations. - """ + """Return a list of all directly-registered types (no MRO expansion).""" with self._lock: return list(self._handlers.keys()) - # ------------------------------------------------------------------ - # Dunder helpers - # ------------------------------------------------------------------ - def __repr__(self) -> str: with self._lock: names = [t.__name__ for t in self._handlers] - return f"TypeHandlerRegistry(registered={names!r})" + return f"PythonTypeHandlerRegistry(registered={names!r})" def __len__(self) -> int: with self._lock: return len(self._handlers) -# --------------------------------------------------------------------------- -# Pre-populated registry -# --------------------------------------------------------------------------- - - -def get_default_type_handler_registry() -> "TypeHandlerRegistry": - """ - Return the TypeHandlerRegistry from the default data context. +def get_default_python_type_handler_registry() -> "PythonTypeHandlerRegistry": + """Return the PythonTypeHandlerRegistry from the default data context. This is a convenience wrapper; the registry is owned and versioned by the - active DataContext. Importing this function from + active ``DataContext``. Importing this function from ``orcapod.hashing.defaults`` or ``orcapod.hashing`` is equivalent. """ from orcapod.hashing.defaults import ( - get_default_type_handler_registry as _get, - ) # stays in hashing/ - + get_default_python_type_handler_registry as _get, + ) return _get() -class BuiltinTypeHandlerRegistry(TypeHandlerRegistry): - """ - A TypeHandlerRegistry pre-populated with all built-in handlers. +class BuiltinPythonTypeHandlerRegistry(PythonTypeHandlerRegistry): + """A PythonTypeHandlerRegistry pre-populated with all built-in hashers. Constructed via the data context JSON spec so that the default registry - is versioned alongside the rest of the context components. The built-in - handlers are registered in ``__init__`` so that no separate population - step is required after construction. + is versioned alongside the rest of the context components. """ def __init__(self, arrow_hasher: "ArrowHasherProtocol | None" = None) -> None: super().__init__() from orcapod.hashing.semantic_hashing.builtin_handlers import ( - register_builtin_handlers, + register_builtin_python_type_handlers, ) - - register_builtin_handlers(self, arrow_hasher=arrow_hasher) + register_builtin_python_type_handlers(self, arrow_hasher=arrow_hasher) diff --git a/src/orcapod/hashing/versioned_hashers.py b/src/orcapod/hashing/versioned_hashers.py index 1e7b7255..c968bbca 100644 --- a/src/orcapod/hashing/versioned_hashers.py +++ b/src/orcapod/hashing/versioned_hashers.py @@ -14,14 +14,13 @@ recursive hasher that replaces BasicObjectHasher). get_versioned_semantic_arrow_hasher() - Return the current-version SemanticArrowHasher (Arrow table hasher - with semantic-type support). + Return the current-version StarfixArrowHasher (Arrow table hasher + with extension-type semantic support). """ from __future__ import annotations import logging -from typing import Any from orcapod.protocols import hashing_protocols as hp @@ -49,50 +48,36 @@ def get_versioned_semantic_hasher( hasher_id: str = _CURRENT_SEMANTIC_HASHER_ID, strict: bool = True, - type_handler_registry: "hp.TypeHandlerRegistry | None" = None, # type: ignore[name-defined] + type_handler_registry: "hp.HandlerRegistryProtocol | None" = None, ) -> hp.SemanticHasherProtocol: - """ - Return a SemanticHasherProtocol configured for the current version. - - The returned hasher uses the global default TypeHandlerRegistry (which - is pre-populated with all built-in handlers) unless an explicit registry - is supplied. + """Return a SemanticHasherProtocol configured for the current version. Parameters ---------- hasher_id: Identifier embedded in every ContentHash produced by this hasher. - Defaults to the current version constant. Override only when - producing hashes that must be tagged with a specific version string. strict: - When True (the default) the hasher raises TypeError on encountering - an object of an unhandled type. When False it falls back to a - best-effort string representation with a logged warning. + When True raises TypeError for unhandled types. When False falls back + to a best-effort string representation. type_handler_registry: - Optional TypeHandlerRegistry to inject. When None the global - default registry is used (recommended for production code). - - Returns - ------- - SemanticHasherProtocol - A fully configured SemanticHasherProtocol instance. + Optional ``HandlerRegistryProtocol`` to inject. When None the + global default registry is used. """ - from orcapod.hashing.semantic_hashing.semantic_hasher import BaseSemanticHasher + from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher if type_handler_registry is None: from orcapod.hashing.semantic_hashing.type_handler_registry import ( - get_default_type_handler_registry, + get_default_python_type_handler_registry, ) - - type_handler_registry = get_default_type_handler_registry() + type_handler_registry = get_default_python_type_handler_registry() logger.debug( - "get_versioned_semantic_hasher: creating BaseSemanticHasher " + "get_versioned_semantic_hasher: creating SemanticAwarePythonHasher " "(hasher_id=%r, strict=%r)", hasher_id, strict, ) - return BaseSemanticHasher( + return SemanticAwarePythonHasher( hasher_id=hasher_id, type_handler_registry=type_handler_registry, strict=strict, @@ -100,55 +85,30 @@ def get_versioned_semantic_hasher( # --------------------------------------------------------------------------- -# SemanticArrowHasher factory +# StarfixArrowHasher factory # --------------------------------------------------------------------------- def get_versioned_semantic_arrow_hasher( hasher_id: str = _CURRENT_ARROW_HASHER_ID, ) -> hp.ArrowHasherProtocol: - """ - Return a SemanticArrowHasher configured for the current version. - - The arrow hasher handles Arrow table / RecordBatch hashing with - semantic-type awareness (e.g. Path columns are hashed by file content). + """Return a StarfixArrowHasher configured for the current version. - Parameters - ---------- - hasher_id: - Identifier embedded in every ContentHash produced by this hasher. - - Returns - ------- - ArrowHasherProtocol - A fully configured SemanticArrowHasher instance. + Sources ``type_converter`` and ``semantic_hasher`` from the default + ``DataContext`` so that the arrow hasher is consistent with all other + versioned components. """ from orcapod.hashing.arrow_hashers import StarfixArrowHasher - from orcapod.hashing.file_hashers import BasicFileHasher - from orcapod.semantic_types.semantic_registry import SemanticTypeRegistry - from orcapod.semantic_types.semantic_struct_converters import ( - PythonPathStructConverter, - UUIDStructConverter, - ) - - # Build a default semantic registry populated with the standard converters. - # We use Any-typed locals here to side-step type-checker false positives - # that arise from the protocol definition of SemanticStructConverterProtocol having - # a slightly different hash_struct_dict signature than the concrete class. - registry: Any = SemanticTypeRegistry() - file_hasher = BasicFileHasher(algorithm="sha256") - path_converter: Any = PythonPathStructConverter(file_hasher=file_hasher) - registry.register_converter("path", path_converter) - uuid_converter: Any = UUIDStructConverter() - registry.register_converter("uuid", uuid_converter) + from orcapod.contexts import resolve_context + ctx = resolve_context(None) # default context logger.debug( "get_versioned_semantic_arrow_hasher: creating StarfixArrowHasher " "(hasher_id=%r)", hasher_id, ) - hasher: Any = StarfixArrowHasher( + return StarfixArrowHasher( hasher_id=hasher_id, - semantic_registry=registry, + type_converter=ctx.type_converter, + semantic_hasher=ctx.semantic_hasher, ) - return hasher diff --git a/src/orcapod/hashing/visitors.py b/src/orcapod/hashing/visitors.py index f3a6fe50..ec0382ac 100644 --- a/src/orcapod/hashing/visitors.py +++ b/src/orcapod/hashing/visitors.py @@ -1,79 +1,105 @@ """ -SUGGESTED FILE: src/orcapod/hashing/visitors.py - Generic visitor pattern for traversing Arrow types and data simultaneously. - -This provides a base visitor class that can be extended for various processing needs -like semantic hashing, validation, data cleaning, etc. """ from __future__ import annotations from abc import ABC, abstractmethod +import typing from typing import TYPE_CHECKING, Any -from orcapod.semantic_types.semantic_registry import SemanticTypeRegistry from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: import pyarrow as pa + from orcapod.semantic_types.universal_converter import UniversalTypeConverter + from orcapod.protocols.hashing_protocols import SemanticHasherProtocol else: pa = LazyModule("pyarrow") class ArrowTypeDataVisitor(ABC): - """ - Base visitor for traversing Arrow types and data simultaneously. - - This enables processing that needs to transform both the Arrow schema - and the corresponding data in a single pass. - """ + """Base visitor for traversing Arrow types and data simultaneously.""" @abstractmethod def visit_struct( self, struct_type: "pa.StructType", data: dict | None ) -> tuple["pa.DataType", Any]: - """Visit a struct type with its data""" + """Visit a struct type with its data.""" pass @abstractmethod def visit_list( self, list_type: "pa.ListType", data: list | None ) -> tuple["pa.DataType", Any]: - """Visit a list type with its data""" + """Visit a list type with its data.""" pass @abstractmethod def visit_map( self, map_type: "pa.MapType", data: dict | None ) -> tuple["pa.DataType", Any]: - """Visit a map type with its data""" + """Visit a map type with its data.""" pass @abstractmethod def visit_primitive( self, primitive_type: "pa.DataType", data: Any ) -> tuple["pa.DataType", Any]: - """Visit a primitive type with its data""" + """Visit a primitive type with its data.""" pass - def visit(self, arrow_type: "pa.DataType", data: Any) -> tuple["pa.DataType", Any]: + def visit_extension( + self, + extension_type: "pa.ExtensionType", + storage_value: Any, + ) -> tuple["pa.DataType", Any]: + """Handle an Arrow extension type. + + Default implementation: passthrough — preserves the extension type and its + storage value unchanged so that the downstream ``StarfixArrowHasher`` / + ``ArrowDigester`` sees the full extension metadata when it receives the + pre-processed table. + + Subclasses may override to convert recognised extension types to a hashed + ``pa.large_binary()`` value. + + Args: + extension_type: The Arrow extension type. + storage_value: The storage-level value (result of ``to_pylist()`` on the column). + + Returns: + Tuple of ``(new_arrow_type, new_data)``. """ - Main dispatch method that routes to appropriate visit method. + return extension_type, storage_value + + def visit(self, arrow_type: "pa.DataType", data: Any) -> tuple["pa.DataType", Any]: + """Main dispatch method that routes to the appropriate visit method. + + Extension types are checked **first** — before the struct check — because + extension types with struct storage would otherwise be incorrectly routed + into ``visit_struct``. After ``visit_extension``, the result is re-visited + only if the type changed AND is no longer an extension type (enables + composability, avoids infinite recursion). Args: - arrow_type: Arrow data type to process - data: Corresponding data value + arrow_type: Arrow data type to process. + data: Corresponding data value. Returns: - Tuple of (new_arrow_type, new_data) + Tuple of ``(new_arrow_type, new_data)``. """ + if isinstance(arrow_type, pa.ExtensionType): + new_type, new_data = self.visit_extension(arrow_type, data) + if new_type is not arrow_type and not isinstance(new_type, pa.ExtensionType): + return self.visit(new_type, new_data) + return new_type, new_data + if pa.types.is_struct(arrow_type): return self.visit_struct(arrow_type, data) elif pa.types.is_list(arrow_type) or pa.types.is_large_list(arrow_type): return self.visit_list(arrow_type, data) elif pa.types.is_fixed_size_list(arrow_type): - # Treat fixed-size lists like regular lists for processing return self.visit_list(arrow_type, data) elif pa.types.is_map(arrow_type): return self.visit_map(arrow_type, data) @@ -82,12 +108,8 @@ def visit(self, arrow_type: "pa.DataType", data: Any) -> tuple["pa.DataType", An def _visit_struct_fields( self, struct_type: "pa.StructType", data: dict | None - ) -> tuple["pa.StructType", dict]: - """ - Helper method to recursively process struct fields. - - This is the default behavior for regular (non-semantic) structs. - """ + ) -> tuple["pa.StructType", dict | None]: + """Recursively process struct fields. Default behavior for regular structs.""" if data is None: return struct_type, None @@ -97,7 +119,6 @@ def _visit_struct_fields( for field in struct_type: field_data = data.get(field.name) new_field_type, new_field_data = self.visit(field.type, field_data) - new_fields.append(pa.field(field.name, new_field_type)) new_data[field.name] = new_field_data @@ -105,12 +126,8 @@ def _visit_struct_fields( def _visit_list_elements( self, list_type: "pa.ListType", data: list | None - ) -> tuple["pa.DataType", list]: - """ - Helper method to recursively process list elements. - - This is the default behavior for lists. - """ + ) -> tuple["pa.DataType", list | None]: + """Recursively process list elements.""" if data is None: return list_type, None @@ -121,16 +138,12 @@ def _visit_list_elements( for item in data: current_element_type, processed_item = self.visit(element_type, item) processed_elements.append(processed_item) - - # Use the first non-None element to determine new element type - if new_element_type is None: + if new_element_type is None and processed_item is not None: new_element_type = current_element_type - # If list was empty or all None, keep original element type if new_element_type is None: new_element_type = element_type - # Create appropriate list type based on original type if pa.types.is_large_list(list_type): return pa.large_list(new_element_type), processed_elements elif pa.types.is_fixed_size_list(list_type): @@ -140,77 +153,97 @@ def _visit_list_elements( class SemanticHashingError(Exception): - """Exception raised when semantic hashing fails""" - + """Exception raised when semantic hashing fails.""" pass class SemanticHashingVisitor(ArrowTypeDataVisitor): + """Visitor that replaces extension-typed columns with their content hashes. + + For each Arrow column whose type is a ``pa.ExtensionType``: + + 1. Look up the corresponding Python type via ``type_converter``. + 2. If the Python type has a semantic hasher registered in ``python_hasher``, + convert the storage value to a Python object and hash it, replacing the + column with a ``pa.large_binary()`` value of the form:: + + + b"::" + content_hash.to_prefixed_digest() + + where ``type_name`` is the extension name with dots replaced by colons + (e.g. ``"orcapod.path"`` → ``"orcapod:path"``), and + ``to_prefixed_digest()`` = ``method_bytes + b":" + digest``. + 3. If no hasher is registered (or the converter doesn't know the type), + return the extension type and storage value unchanged. The downstream + ``StarfixArrowHasher`` / ``ArrowDigester`` will see the full extension + metadata intact and hash it in a type-aware way. + + Args: + type_converter: The active ``UniversalTypeConverter`` for resolving + extension type → Python type and storage → Python conversion. + python_hasher: The active ``SemanticHasherProtocol`` for hashing + Python objects. """ - Visitor that replaces semantic types with their hash strings. - This visitor traverses Arrow type structures and data simultaneously, - identifying semantic types by their struct signatures and replacing - them with hash strings computed by their respective converters. - """ - - def __init__(self, semantic_registry: SemanticTypeRegistry): - """ - Initialize the semantic hashing visitor. - - Args: - semantic_registry: Registry containing semantic type converters - """ - self.registry = semantic_registry + def __init__( + self, + type_converter: "UniversalTypeConverter", + python_hasher: "SemanticHasherProtocol", + ) -> None: + self._type_converter = type_converter + self._python_hasher = python_hasher self._current_field_path: list[str] = [] + def visit_extension( + self, + extension_type: "pa.ExtensionType", + storage_value: Any, + ) -> tuple["pa.DataType", Any]: + """Hash an extension type value to pa.large_binary(), or passthrough.""" + if storage_value is None: + return extension_type, None + + # Resolve extension type → Python type. + python_type = self._type_converter.arrow_type_to_python_type(extension_type) + + # If the converter couldn't resolve to a concrete class, passthrough. + if python_type is typing.Any or not isinstance(python_type, type): + return extension_type, storage_value + + # Only hash if a semantic hasher is registered for this Python type. + if not self._python_hasher.type_handler_registry.has_handler( + python_type + ): + return extension_type, storage_value + + # Convert storage value → Python object and hash it. + python_obj = self._type_converter.storage_to_python(storage_value, python_type) + content_hash = self._python_hasher.hash_object(python_obj) + + # Encode as binary: ":::" + # Dots in the extension name → colons (e.g. "orcapod.path" → "orcapod:path"). + # The "::" separator is unambiguous because to_prefixed_digest() uses only ":". + type_name = extension_type.extension_name.replace(".", ":") + hash_bytes = ( + type_name.encode("utf-8") + + b"::" + + content_hash.to_prefixed_digest() + ) + return pa.large_binary(), hash_bytes + def visit_struct( self, struct_type: "pa.StructType", data: dict | None ) -> tuple["pa.DataType", Any]: - """ - Visit a struct type, checking if it's a semantic type. - - If the struct is a semantic type (recognized by signature), replace it - with a hash string. Otherwise, recursively process its fields. - """ + """Regular struct (no extension identity) — recurse into fields.""" if data is None: return struct_type, None - - # Check if this struct IS a semantic type by signature recognition - converter = self.registry.get_converter_for_struct_signature(struct_type) - if converter: - # This is a semantic type - hash it - try: - hash_string = converter.hash_struct_dict(data) - return pa.large_string(), hash_string - except Exception as e: - field_path = ( - ".".join(self._current_field_path) - if self._current_field_path - else "" - ) - converter_name = getattr( - converter, "semantic_type_name", str(type(converter).__name__) - ) - raise SemanticHashingError( - f"Failed to hash semantic type '{converter_name}' at field path '{field_path}': {str(e)}" - ) from e - else: - # Regular struct - recursively process fields - return self._visit_struct_fields(struct_type, data) + return self._visit_struct_fields(struct_type, data) def visit_list( self, list_type: "pa.ListType", data: list | None ) -> tuple["pa.DataType", Any]: - """ - Visit a list type, recursively processing elements. - - Elements that are semantic types will be replaced with hash strings. - """ + """Recurse into list elements.""" if data is None: return list_type, None - - # Add list indicator to field path for error context self._current_field_path.append("[*]") try: return self._visit_list_elements(list_type, data) @@ -220,28 +253,19 @@ def visit_list( def visit_map( self, map_type: "pa.MapType", data: dict | None ) -> tuple["pa.DataType", Any]: - """ - Visit a map type. - - For now, we treat maps as pass-through since they're less common. - TODO: Implement proper map traversal if needed for semantic types in keys/values. - """ + """Pass map types through unchanged.""" return map_type, data def visit_primitive( self, primitive_type: "pa.DataType", data: Any ) -> tuple["pa.DataType", Any]: - """ - Visit a primitive type - pass through unchanged. - - Primitive types cannot be semantic types (which are always structs). - """ + """Pass primitive types through unchanged.""" return primitive_type, data def _visit_struct_fields( self, struct_type: "pa.StructType", data: dict | None - ) -> tuple["pa.StructType", dict]: - """Override to add field path tracking for better error messages""" + ) -> tuple["pa.StructType", dict | None]: + """Override to add field path tracking for better error messages.""" if data is None: return struct_type, None @@ -249,12 +273,10 @@ def _visit_struct_fields( new_data = {} for field in struct_type: - # Add field name to path for error context self._current_field_path.append(field.name) try: field_data = data.get(field.name) new_field_type, new_field_data = self.visit(field.type, field_data) - new_fields.append(pa.field(field.name, new_field_type)) new_data[field.name] = new_field_data finally: diff --git a/src/orcapod/protocols/hashing_protocols.py b/src/orcapod/protocols/hashing_protocols.py index 3ab2aace..a5e066d4 100644 --- a/src/orcapod/protocols/hashing_protocols.py +++ b/src/orcapod/protocols/hashing_protocols.py @@ -17,203 +17,119 @@ class DataContextAwareProtocol(Protocol): @property def data_context_key(self) -> str: - """ - Return the data context key associated with this object. - - Returns: - str: The data context key - """ + """Return the data context key associated with this object.""" ... @runtime_checkable class PipelineElementProtocol(Protocol): - """ - Protocol for objects that have a stable identity as an element in a - pipeline graph — determined by schema and upstream topology, not by - data content. - - This is a parallel identity chain to ContentIdentifiableProtocol. - Where content identity captures the precise, data-inclusive identity of - an object, pipeline identity captures only what is structurally meaningful - for pipeline database path scoping: the schemas and the recursive topology - of the upstream computation. - - The base case (RootSource) returns a hash of (tag_schema, data_schema). - Every other element recurses through the pipeline_hash() of its upstream - inputs, with the hash values themselves (ContentHash objects) used as - terminal leaves so no special hasher mode is required. - - Two sources with identical schemas processed through the same function pod - graph will produce the same pipeline_hash() at every downstream node, - enabling automatic multi-source table sharing in the pipeline database. - """ + """Protocol for objects that have a stable identity as an element in a pipeline graph.""" def pipeline_identity_structure(self) -> Any: - """ - Return a structure representing this element's pipeline identity. - - At source nodes (base case): return (tag_schema, data_schema). - At all other nodes: return a structure containing references to - upstream pipeline elements and/or data functions as raw objects. - The pipeline resolver threaded through pipeline_hash() ensures that - PipelineElementProtocol objects are resolved via pipeline_hash() and - other ContentIdentifiable objects via content_hash(), both using the - same hasher throughout the computation. - """ + """Return a structure representing this element's pipeline identity.""" ... def pipeline_hash(self, hasher=None) -> ContentHash: - """ - Return the pipeline-level hash of this element, computed from - pipeline_identity_structure() and cached by hasher_id. - - Args: - hasher: Optional semantic hasher to use. When omitted, resolved - from the element's data_context. - """ + """Return the pipeline-level hash of this element.""" ... @runtime_checkable class ContentIdentifiableProtocol(Protocol): - """ - Protocol for objects that can express their semantic identity as a plain - Python structure. - - This is the only method a class needs to implement to participate in the - content-based hashing system. The returned structure is recursively - resolved by the SemanticHasherProtocol -- any nested ContentIdentifiableProtocol objects - within the structure will themselves be expanded and hashed, producing a - Merkle-tree-like composition of hashes. - - The method should return a deterministic structure whose value depends - only on the semantic content of the object -- not on memory addresses, - object IDs, or other incidental runtime state. - """ + """Protocol for objects that can express their semantic identity as a plain Python structure.""" def identity_structure(self) -> Any: - """ - Return a structure that represents the semantic identity of this object. - - The returned value may be any Python object: - - Primitives (str, int, float, bool, None) are used as-is. - - Collections (list, dict, set, tuple) are recursively traversed. - - Nested ContentIdentifiableProtocol objects are recursively resolved by - the SemanticHasherProtocol: their identity structure is hashed to a - ContentHash hex token, which is then embedded in place of the - object in the parent structure. - - Any type that has a registered TypeHandlerProtocol in the - SemanticHasherProtocol's registry is handled by that handler. + """Return a structure that represents the semantic identity of this object.""" + ... - Returns: - Any: A structure representing this object's semantic content. - Should be deterministic and include all identity-relevant data. - """ + def content_hash(self, hasher: "SemanticHasherProtocol | None" = None) -> ContentHash: + """Returns the content hash.""" ... - def content_hash(self, hasher: SemanticHasherProtocol | None = None) -> ContentHash: - """ - Returns the content hash. - Args: - hasher: Optional semantic hasher to use for the entire recursive - computation. When omitted, resolved from the object's - data_context (or injected hasher for mixin-based objects). - The same hasher propagates to all nested ContentIdentifiable - objects, ensuring one consistent context per computation. - """ - ... +class PythonTypeHandlerProtocol(Protocol): + """Protocol for type-specific semantic hashers used by ``SemanticAwarePythonHasher``. + A ``PythonTypeHandlerProtocol`` converts a specific Python type into a + representative Python structure that ``SemanticHasherProtocol.hash_object()`` + can then hash. Implementations are registered with a + ``HandlerRegistryProtocol`` and looked up via MRO-aware resolution. -class TypeHandlerProtocol(Protocol): - """ - Protocol for type-specific serialization handlers used by SemanticHasherProtocol. - - A TypeHandlerProtocol converts a specific Python type into a value that - ``hash_object`` can process. Handlers are registered with a - TypeHandlerRegistry and looked up via MRO-aware resolution. - - The returned value is passed directly back to ``hash_object``, so it may - be anything that ``hash_object`` understands: - - - A primitive (None, bool, int, float, str) -- hashed directly. - - A structure (list, tuple, dict, set, frozenset) -- expanded and hashed. - - A ContentHash -- treated as a terminal; returned as-is without - re-hashing. Use this when the handler has already computed the - definitive hash of the object (e.g. hashing a file's content). - - A ContentIdentifiableProtocol -- its identity_structure() will be called. - - Another registered type -- dispatched through the registry. + Each implementation receives the full ``SemanticHasherProtocol`` so it can + delegate hashing of sub-values back to the outer hasher without coupling to a + specific hasher instance. """ def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: - """ - Convert *obj* into a value that ``hash_object`` can process. + """Return a representative Python structure for *obj*. + + The returned value is passed back into + ``SemanticHasherProtocol.hash_object()`` for final hashing. Returning + a ``ContentHash`` short-circuits the process: the caller returns it as-is + without re-hashing. This is useful for handlers that compute content-based + hashes from external data (e.g. file content, Arrow tables). Args: - obj: The object to handle. - hasher: The SemanticHasherProtocol, available if the handler needs to - hash sub-objects explicitly via ``hasher.hash_object()``. + obj: The object to hash. Always matches the registered type. + hasher: The active ``SemanticHasherProtocol``. Use + ``hasher.hash_object(sub_value)`` to hash sub-values that + require type-specific treatment. Returns: - Any value accepted by ``hash_object``: a primitive, structure, - ContentHash, ContentIdentifiableProtocol, or another registered type. + A representative Python structure (primitive, dict, list, bytes, etc.) + that will be passed into ``hash_object()`` for final hashing, or a + ``ContentHash`` to terminate hashing immediately. """ ... -class SemanticHasherProtocol(Protocol): +class HandlerRegistryProtocol(Protocol): + """Protocol for type handler registries used by ``SemanticHasherProtocol``. + + Abstracts over ``PythonTypeHandlerRegistry`` so that ``SemanticHasherProtocol`` + and its consumers do not depend on the concrete registry class. """ - Protocol for the semantic content-based hasher. - ``hash_object(obj)`` is the single recursive entry point. It produces a - ContentHash for any Python object using the following dispatch: + def register(self, target_type: type, handler: "PythonTypeHandlerProtocol") -> None: + """Register a handler for a specific Python type.""" + ... + + def get_handler(self, obj: Any) -> "PythonTypeHandlerProtocol | None": + """Look up the handler for *obj* using MRO-aware resolution.""" + ... + + def get_handler_for_type(self, target_type: type) -> "PythonTypeHandlerProtocol | None": + """Look up the handler for a type object (rather than an instance).""" + ... + + def has_handler(self, target_type: type) -> bool: + """Return True if a handler is registered for *target_type* or any MRO ancestor.""" + ... - - ContentHash → terminal; returned as-is - - Primitive → JSON-serialised and hashed directly - - Structure → structurally expanded (type-tagged), then hashed - - Handler match → handler.handle() returns a new value; recurse - - ContentIdentifiableProtocol→ identity_structure() returns a value; recurse - - Unknown → TypeError (strict) or best-effort string (lenient) + def __len__(self) -> int: + """Return the number of directly-registered types.""" + ... - Containers are type-tagged before hashing so that list, tuple, dict, set, - and namedtuple produce distinct hashes even when their elements are equal. - Unknown types raise TypeError by default (strict mode). Set - strict=False on construction to fall back to a best-effort string - representation with a warning instead. - """ +class SemanticHasherProtocol(Protocol): + """Protocol for the semantic content-based hasher.""" def hash_object( self, obj: Any, resolver: Callable[[Any], ContentHash] | None = None, ) -> ContentHash: - """ - Hash *obj* based on its semantic content. - - Args: - obj: The object to hash. - resolver: Optional callable invoked for any ContentIdentifiable - object encountered during hashing. When provided it overrides - the default obj.content_hash() call, allowing the caller to - control which identity chain is used and to propagate a - consistent hasher through the full recursive computation. - - Returns: - ContentHash: Stable, content-based hash of the object. - """ + """Hash *obj* based on its semantic content.""" ... @property def hasher_id(self) -> str: - """ - Returns a unique identifier/name for this hasher instance. + """Returns a unique identifier/name for this hasher instance.""" + ... - The hasher_id is embedded in every ContentHash produced by this - hasher, allowing hashes from different versions or configurations - to be distinguished. - """ + @property + def type_handler_registry(self) -> HandlerRegistryProtocol: + """Return the handler registry used by this hasher.""" ... @@ -263,11 +179,8 @@ def hasher_id(self) -> str: """Unique identifier for this semantic type hasher.""" ... - def hash_column( - self, - column: "pa.Array", - ) -> "pa.Array": - """Hash a column with this semantic type and return the hash bytes an an array""" + def hash_column(self, column: "pa.Array") -> "pa.Array": + """Hash a column with this semantic type and return the hash bytes as an array.""" ... def set_cacher(self, cacher: StringCacherProtocol) -> None: diff --git a/src/orcapod/protocols/semantic_types_protocols.py b/src/orcapod/protocols/semantic_types_protocols.py index 002e2686..f2303190 100644 --- a/src/orcapod/protocols/semantic_types_protocols.py +++ b/src/orcapod/protocols/semantic_types_protocols.py @@ -51,54 +51,6 @@ def get_arrow_to_python_converter( self, arrow_type: "pa.DataType" ) -> "Callable[[Any], Any]": ... + def ensure_types_registered_for_schemas(self, *schemas: Schema) -> None: ... -# Core protocols -class SemanticStructConverterProtocol(Protocol): - """Protocol for converting between Python objects and semantic structs.""" - @property - def python_type(self) -> DataType: - """The Python type this converter can handle.""" - ... - - @property - def arrow_struct_type(self) -> "pa.StructType": - """The Arrow struct type this converter produces.""" - ... - - def python_to_struct_dict(self, value: Any) -> dict[str, Any]: - """Convert Python value to struct dictionary.""" - ... - - def struct_dict_to_python(self, struct_dict: dict[str, Any]) -> Any: - """Convert struct dictionary back to Python value.""" - ... - - def can_handle_python_type(self, python_type: DataType) -> bool: - """Check if this converter can handle the given Python type.""" - ... - - def can_handle_struct_type(self, struct_type: "pa.StructType") -> bool: - """Check if this converter can handle the given struct type.""" - ... - - def hash_struct_dict(self, struct_dict: dict[str, Any]) -> str: - """ - Compute hash of the semantic type from its struct dictionary representation. - - Args: - struct_dict: Arrow struct dictionary representation - - Returns: - Hash string of the form ``"{type}:sha256:"``, - e.g. ``"path:sha256:abc123"`` - - Raises: - Exception: If hashing fails (e.g., file not found for path types) - """ - ... - - @property - def hasher_id(self) -> str: - """Identifier for this hasher (for debugging/versioning)""" - ... diff --git a/src/orcapod/semantic_types/__init__.py b/src/orcapod/semantic_types/__init__.py index 123777f5..f7948ee7 100644 --- a/src/orcapod/semantic_types/__init__.py +++ b/src/orcapod/semantic_types/__init__.py @@ -1,9 +1,7 @@ -from .semantic_registry import SemanticTypeRegistry from .universal_converter import UniversalTypeConverter from .type_inference import infer_python_schema_from_pylist_data __all__ = [ - "SemanticTypeRegistry", "UniversalTypeConverter", "infer_python_schema_from_pylist_data", ] diff --git a/src/orcapod/semantic_types/dataclass_encoding.py b/src/orcapod/semantic_types/dataclass_encoding.py deleted file mode 100644 index 13467a1a..00000000 --- a/src/orcapod/semantic_types/dataclass_encoding.py +++ /dev/null @@ -1,366 +0,0 @@ -# src/orcapod/semantic_types/dataclass_encoding.py -""" -Dataclass <-> Arrow struct encoding for Orcapod. - -Encodes Python dataclasses as Arrow structs with a ``__dataclass.`` sentinel -field carrying the fully-qualified class name. Decoding uses a three-tier -fallback: import -> registry -> synthesize. -""" - -from __future__ import annotations - -import dataclasses -import importlib -import logging -import os -import re -import sys -import typing -from typing import TYPE_CHECKING, Any - -from orcapod.utils.lazy_module import LazyModule - -if TYPE_CHECKING: - import pyarrow as pa -else: - pa = LazyModule("pyarrow") - -logger = logging.getLogger(__name__) - -DATACLASS_TYPE_FIELD = "__dataclass." -DATACLASS_TYPE_PREFIX = "dataclass:" - -# Validates fully-qualified class names like "my_module.sub.MyClass". -# Also accepts qualnames containing "" segments produced by local -# class definitions (e.g. "mod.func..MyClass"). Each dot-separated -# segment may be a normal identifier or the literal token "". -_FQCN_RE = re.compile(r"^[A-Za-z_]\w*(\.[A-Za-z_]\w*|\.)+$") - -# Matches all identifier tokens within a stringified annotation. -# Used by _get_type_hints_safe to handle compound forms like -# "Optional[_Inner]", "list[_Inner]", or "_Inner | None". -_IDENT_RE = re.compile(r"[A-Za-z_]\w*") - -# Process-global registry for tier-2 reconstruction. -# Populated via register_dataclass(); persists for the process lifetime. -_DATACLASS_REGISTRY: dict[str, type] = {} - -# Tier-1 import gate. -# Set ORCAPOD_DATACLASS_IMPORT=0 to disable importlib-based reconstruction, -# e.g. in environments where arbitrary module import from on-disk __type values -# is not acceptable. Tier-2 (registry) and tier-3 (synthesize) still work. -_TIER1_IMPORT_ENABLED: bool = os.environ.get("ORCAPOD_DATACLASS_IMPORT", "1") != "0" - - -def register_dataclass(cls: type) -> type: - """Register a dataclass for tier-2 reconstruction by fully-qualified name. - - Can be used as a class decorator or called directly. Returns ``cls`` - unchanged so it works transparently as a decorator. - - Args: - cls: A Python dataclass type to register. - - Returns: - The same ``cls`` that was passed in. - - Raises: - TypeError: If ``cls`` is not a dataclass type. - """ - if not dataclasses.is_dataclass(cls) or not isinstance(cls, type): - raise TypeError(f"{cls!r} is not a dataclass type") - key = f"{cls.__module__}.{cls.__qualname__}" - _DATACLASS_REGISTRY[key] = cls - return cls - - -def has_dataclass_type_sentinel(arrow_type: pa.DataType) -> bool: - """Return ``True`` if ``arrow_type`` is a struct with a ``__dataclass.`` string field. - - Accepts both ``pa.large_string()`` and ``pa.string()`` for compatibility - with data written by older Arrow versions. - - Args: - arrow_type: Any PyArrow data type. - - Returns: - True if ``arrow_type`` is a struct containing a - ``__dataclass.: (large_)string`` field. - """ - if not pa.types.is_struct(arrow_type): - return False - for i in range(arrow_type.num_fields): - field = arrow_type.field(i) - if field.name == DATACLASS_TYPE_FIELD: - return pa.types.is_large_string(field.type) or pa.types.is_string(field.type) - return False - - -def _get_type_hints_safe(cls: type) -> dict[str, Any]: - """Return type hints for a dataclass, tolerating unresolvable local annotations. - - Calls ``typing.get_type_hints(cls)`` first. If that raises ``NameError`` - (which happens for classes with annotations that reference locally-scoped - types when ``from __future__ import annotations`` is in effect), falls - back to searching call-stack frames for the identifier tokens referenced - in the annotations, then to module globals, and finally returns raw string - annotations as a last resort. - - The token scan (via ``_IDENT_RE``) extracts *all* identifiers from each - string annotation, so compound forms like ``"Optional[_Inner]"``, - ``"list[_Inner]"``, and ``"_Inner | None"`` are handled correctly — only - matching the whole annotation string would miss them. - - Frame traversal uses ``sys._getframe()``/``f_back`` rather than - ``inspect.stack()`` to avoid the overhead and strong-reference pitfalls - introduced by ``inspect.stack()``'s ``FrameInfo`` wrapper objects. - - Args: - cls: A Python dataclass type. - - Returns: - A dict mapping field names to resolved type hints. Values may be string - annotations for names that could not be resolved. - """ - try: - return typing.get_type_hints(cls) - except NameError: - pass - - localns: dict[str, Any] = {} - - # 1. Module globals for the class's module (cheap, no frame traversal needed). - module = sys.modules.get(cls.__module__) - if module is not None: - for name, obj in vars(module).items(): - if isinstance(obj, type): - localns[name] = obj - - # 2. Collect *all* identifier tokens from string annotations so that compound - # forms like "Optional[_Inner]" or "_Inner | None" are handled correctly. - raw_annotations = cls.__annotations__ - token_names: set[str] = set() - for v in raw_annotations.values(): - if isinstance(v, str): - token_names.update(_IDENT_RE.findall(v)) - - # 3. Walk the live frame chain via f_back — no FrameInfo objects, no extra - # strong references to frames. - if token_names: - frame = sys._getframe(0) - while frame is not None: - remaining = token_names - set(localns) - if not remaining: - break - for name in remaining: - obj = frame.f_locals.get(name) - if obj is not None and isinstance(obj, type): - localns[name] = obj - frame = frame.f_back - - try: - return typing.get_type_hints(cls, localns=localns) - except NameError: - pass - - # Last resort: return raw annotations (may contain strings for local types). - return dict(raw_annotations) - - -def dataclass_to_arrow_struct_type( - cls: type, - converter: Any, -) -> pa.StructType: - """Derive the Arrow struct type for a dataclass class. - - The resulting struct has ``__dataclass.: large_string`` as its first field, - followed by one field per dataclass field. Field types are resolved via - ``converter`` (a ``UniversalTypeConverter``), so nested dataclasses - produce nested structs automatically once the converter has the dataclass - branch wired in. - - Args: - cls: A Python dataclass type. - converter: A ``UniversalTypeConverter`` instance used for field type - resolution. - - Returns: - A ``pa.StructType`` with ``__type`` as the first field. - - Raises: - TypeError: If `cls` is not a dataclass type. - """ - if not dataclasses.is_dataclass(cls) or not isinstance(cls, type): - raise TypeError(f"{cls!r} is not a dataclass type") - - hints = _get_type_hints_safe(cls) - fields: list[pa.Field] = [pa.field(DATACLASS_TYPE_FIELD, pa.large_string())] - for f in dataclasses.fields(cls): - if not f.init: - # Fields excluded from __init__ are not part of the serialized - # representation — they are typically derived/computed post-init. - continue - arrow_type = converter.python_type_to_arrow_type(hints[f.name]) - fields.append(pa.field(f.name, arrow_type)) - return pa.struct(fields) - - -def dataclass_to_struct_dict( - obj: Any, - field_converters: dict[str, Any], -) -> dict[str, Any]: - """Encode a dataclass instance to an Arrow-compatible struct dict. - - Args: - obj: A dataclass instance to encode. - field_converters: Pre-built per-field converter callables keyed by - field name. Build these once per type at converter-creation time - and reuse per row to avoid repeated type dispatch. - - Returns: - A dict with ``__dataclass.`` as the first key followed by encoded field values. - - Raises: - TypeError: If ``obj`` is not a dataclass instance (e.g. a class itself - or a non-dataclass value). - """ - # dataclasses.is_dataclass() returns True for both classes and instances; - # isinstance(obj, type) distinguishes: True for classes, False for instances. - if not dataclasses.is_dataclass(obj) or isinstance(obj, type): - raise TypeError(f"{obj!r} is not a dataclass instance") - - cls = type(obj) - type_str = f"{DATACLASS_TYPE_PREFIX}{cls.__module__}.{cls.__qualname__}" - result: dict[str, Any] = {DATACLASS_TYPE_FIELD: type_str} - for f in dataclasses.fields(cls): - if not f.init: - # Fields excluded from __init__ are not part of the serialized - # representation — they are typically derived/computed post-init. - continue - value = getattr(obj, f.name) - converter_fn = field_converters.get(f.name, lambda v: v) - result[f.name] = converter_fn(value) - return result - - -def struct_dict_to_dataclass( - struct_dict: dict[str, Any], - field_converters: dict[str, Any], - lookup_cache: dict[str, type], -) -> Any: - """Decode an Arrow struct dict to a Python dataclass instance. - - Uses a three-tier fallback: - - 1. **Import** — ``importlib``-import the class from its fully-qualified name. - 2. **Registry** — look up the FQCN in the process-global ``_DATACLASS_REGISTRY``. - 3. **Synthesize** — create a throwaway dataclass with ``dataclasses.make_dataclass`` - matching the struct's field names (all fields typed as ``Any``). - - Tier 3 never raises. A ``lookup_cache`` (keyed by FQCN) amortises repeated - resolution across rows in the same read operation. - - Args: - struct_dict: Arrow struct row dict as produced by ``pa.Table.to_pylist()``. - field_converters: Per-field Arrow->Python converter callables (keyed by - field name, excluding ``__type``). - lookup_cache: Mutable dict used as a per-read cache. Pass the same dict - for all rows in a read operation; clear between operations if needed. - - Returns: - A dataclass instance (real or synthesized) with field values set. - """ - type_str = struct_dict.get(DATACLASS_TYPE_FIELD) - - fqcn: str | None = None - class_name = "SynthesizedDataclass" - - if type_str and isinstance(type_str, str) and type_str.startswith(DATACLASS_TYPE_PREFIX): - candidate = type_str[len(DATACLASS_TYPE_PREFIX):] - if _FQCN_RE.match(candidate): - fqcn = candidate - class_name = fqcn.rsplit(".", 1)[-1] - else: - logger.warning( - "struct_dict_to_dataclass: invalid __type value %r — falling back to tier 3", - type_str, - ) - - cls: type | None = None - - if fqcn is not None: - # Check lookup cache first (amortises tiers 1-3 across rows) - if fqcn in lookup_cache: - cls = lookup_cache[fqcn] - else: - # Tier 1: import (disabled when ORCAPOD_DATACLASS_IMPORT=0) - if _TIER1_IMPORT_ENABLED: - module_path, _, class_attr = fqcn.rpartition(".") - try: - module = importlib.import_module(module_path) - resolved = getattr(module, class_attr) - if not dataclasses.is_dataclass(resolved) or not isinstance(resolved, type): - raise AttributeError( - f"{class_attr!r} in {module_path!r} is not a dataclass type" - ) - cls = resolved - lookup_cache[fqcn] = cls - except (ImportError, AttributeError) as exc: - logger.debug( - "struct_dict_to_dataclass: tier 1 import failed for %r: %s", - fqcn, exc, - ) - else: - logger.debug( - "struct_dict_to_dataclass: tier 1 disabled (ORCAPOD_DATACLASS_IMPORT=0), " - "skipping import for %r", - fqcn, - ) - - # Tier 2: registry - if cls is None: - cls = _DATACLASS_REGISTRY.get(fqcn) - if cls is not None: - lookup_cache[fqcn] = cls - - # Tier 3: synthesize (fqcn valid but unresolvable) - if cls is None: - field_names = [k for k in struct_dict if k != DATACLASS_TYPE_FIELD] - cls = dataclasses.make_dataclass( - class_name, [(name, typing.Any) for name in field_names] - ) - lookup_cache[fqcn] = cls - else: - # No valid fqcn — tier 3 with no caching (no stable key) - field_names = [k for k in struct_dict if k != DATACLASS_TYPE_FIELD] - cls = dataclasses.make_dataclass( - class_name, [(name, typing.Any) for name in field_names] - ) - - # Instantiate: apply field converters, skip the __type sentinel, and only - # pass keys that correspond to init=True fields on the resolved class. - # Filtering to init fields tolerates superset-schema structs (extra keys - # are silently dropped) and avoids passing init=False fields to __init__. - # - # A non-null value for a dropped key is flagged as a warning: NULL is the - # expected state (column present in schema but not applicable to this row / - # this class); a real value being discarded is a sign of a schema mismatch - # or a bug in the encoding pipeline. - init_field_names = {f.name for f in dataclasses.fields(cls) if f.init} - data_kwargs: dict[str, Any] = {} - for key, value in struct_dict.items(): - if key == DATACLASS_TYPE_FIELD: - continue - if key not in init_field_names: - if value is not None: - logger.warning( - "struct_dict_to_dataclass: field %r has a non-null value (%r) " - "but is not accepted by %r.__init__ — the value will be discarded. " - "This may indicate a schema mismatch or a bug in the encoding pipeline.", - key, value, cls, - ) - continue - converter_fn = field_converters.get(key, lambda v: v) - data_kwargs[key] = converter_fn(value) if value is not None else None - - return cls(**data_kwargs) diff --git a/src/orcapod/semantic_types/semantic_registry.py b/src/orcapod/semantic_types/semantic_registry.py deleted file mode 100644 index ff8c1a49..00000000 --- a/src/orcapod/semantic_types/semantic_registry.py +++ /dev/null @@ -1,246 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import TYPE_CHECKING, Any - -from orcapod.protocols.semantic_types_protocols import SemanticStructConverterProtocol -from orcapod.semantic_types import pydata_utils - -# from orcapod.semantic_types.type_inference import infer_python_schema_from_pylist_data -from orcapod.types import DataType, Schema -from orcapod.utils.lazy_module import LazyModule - -if TYPE_CHECKING: - import pyarrow as pa -else: - pa = LazyModule("pyarrow") - - -class SemanticTypeRegistry: - """ - Registry that manages semantic type converters using struct signature recognition. - - This registry maps Python types to PyArrow struct signatures, enabling - automatic detection and conversion of semantic types based on their - struct schema alone. - """ - - @staticmethod - def infer_python_schema_from_pylist(data: list[dict[str, Any]]) -> Schema: - """ - Infer Python schema from a list of dictionaries (pylist) - """ - return pydata_utils.infer_python_schema_from_pylist_data(data) - - @staticmethod - def infer_python_schema_from_pydict(data: dict[str, list[Any]]) -> Schema: - # TODO: consider which data type is more efficient and use that pylist or pydict - return pydata_utils.infer_python_schema_from_pylist_data( - pydata_utils.pydict_to_pylist(data) - ) - - def __init__( - self, converters: Mapping[str, SemanticStructConverterProtocol] | None = None - ): - # Bidirectional mappings between Python types and struct signatures - self._python_to_struct: dict[DataType, "pa.StructType"] = {} - self._struct_to_python: dict["pa.StructType", DataType] = {} - self._struct_to_converter: dict[ - "pa.StructType", SemanticStructConverterProtocol - ] = {} - - # Name mapping for convenience - self._name_to_converter: dict[str, SemanticStructConverterProtocol] = {} - self._struct_to_name: dict["pa.StructType", str] = {} - - # If initialized with a list of converters, register them - if converters: - for semantic_type_name, converter in converters.items(): - self.register_converter(semantic_type_name, converter) - - def register_converter( - self, semantic_type_name: str, converter: SemanticStructConverterProtocol - ) -> None: - """ - Register a semantic type converter. - - This creates bidirectional mappings between: - - Python type ↔ Arrow struct signature - - Arrow struct signature ↔ converter instance - - Optionally, a semantic type name can be provided. - """ - python_type = converter.python_type - struct_signature = converter.arrow_struct_type - - # Check for conflicts - if python_type in self._python_to_struct: - existing_struct = self._python_to_struct[python_type] - if existing_struct != struct_signature: - raise ValueError( - f"Python type {python_type} already registered with different struct signature. " - f"Existing: {existing_struct}, New: {struct_signature}" - ) - - if struct_signature in self._struct_to_python: - existing_python = self._struct_to_python[struct_signature] - if existing_python != python_type: - raise ValueError( - f"Struct signature {struct_signature} already registered with different Python type. " - f"Existing: {existing_python}, New: {python_type}" - ) - - # catch case where a different converter is already registered with the semantic type name - if existing_converter := self.get_converter_for_semantic_type( - semantic_type_name - ): - if existing_converter != converter: - raise ValueError( - f"Semantic type name '{semantic_type_name}' is already registered to {existing_converter}" - ) - - # Register bidirectional mappings - self._python_to_struct[python_type] = struct_signature - self._struct_to_python[struct_signature] = python_type - self._struct_to_converter[struct_signature] = converter - - self._name_to_converter[semantic_type_name] = converter - self._struct_to_name[struct_signature] = semantic_type_name - - def get_converter_for_python_type( - self, python_type: DataType - ) -> SemanticStructConverterProtocol | None: - """Get converter registered to the Python type.""" - # Direct lookup first - struct_signature = self._python_to_struct.get(python_type) - if struct_signature: - return self._struct_to_converter[struct_signature] - - # Handle subclass relationships - add safety check - for registered_type, struct_signature in self._python_to_struct.items(): - try: - if ( - isinstance(registered_type, type) - and isinstance(python_type, type) - and issubclass(python_type, registered_type) - ): - return self._struct_to_converter[struct_signature] - except TypeError: - # Handle cases where issubclass fails (e.g., with generic types) - continue - - return None - - def get_converter_for_semantic_type( - self, semantic_type_name: str - ) -> SemanticStructConverterProtocol | None: - """Get converter registered to the semantic type name.""" - return self._name_to_converter.get(semantic_type_name) - - def get_converter_for_struct_signature( - self, struct_signature: "pa.StructType" - ) -> SemanticStructConverterProtocol | None: - """ - Get converter registered to the Arrow struct signature. - """ - return self._struct_to_converter.get(struct_signature) - - def get_python_type_for_semantic_struct_signature( - self, struct_signature: "pa.StructType" - ) -> DataType | None: - """ - Get Python type registered to the Arrow struct signature. - """ - return self._struct_to_python.get(struct_signature) - - def get_semantic_struct_signature_for_python_type( - self, python_type: type - ) -> "pa.StructType | None": - """Get Arrow struct signature registered to the Python type.""" - return self._python_to_struct.get(python_type) - - def has_semantic_type(self, semantic_type_name: str) -> bool: - """Check if the semantic type name is registered.""" - return semantic_type_name in self._name_to_converter - - def has_python_type(self, python_type: type) -> bool: - """Check if the Python type is registered.""" - return python_type in self._python_to_struct - - def has_semantic_struct_signature(self, struct_signature: "pa.StructType") -> bool: - """Check if the struct signature is registered.""" - return struct_signature in self._struct_to_python - - def list_semantic_types(self) -> list[str]: - """Get all registered semantic type names.""" - return list(self._name_to_converter.keys()) - - def list_python_types(self) -> list[DataType]: - """Get all registered Python types.""" - return list(self._python_to_struct.keys()) - - def list_struct_signatures(self) -> list["pa.StructType"]: - """Get all registered struct signatures.""" - return list(self._struct_to_python.keys()) - - def find_semantic_fields_in_schema(self, schema: "pa.Schema") -> dict[str, str]: - """ - Find all semantic type fields in a schema by struct signature recognition. - - Args: - schema: PyArrow schema to examine - - Returns: - Dictionary mapping field names to semantic type names - - Example: - schema with fields: - - name: string - - file_path: struct - - location: struct - - Returns: {"file_path": "path", "location": "geolocation"} - """ - semantic_fields = {} - for field in schema: - if pa.types.is_struct(field.type) and field.type in self._struct_to_name: - semantic_fields[field.name] = self._struct_to_name[field.type] - return semantic_fields - - def get_semantic_field_info(self, schema: "pa.Schema") -> dict[str, dict[str, Any]]: - """ - Get detailed information about semantic fields in a schema. - - Returns: - Dictionary with field names as keys and info dictionaries as values. - Each info dict contains: semantic_type, python_type, struct_signature - """ - semantic_info = {} - for field in schema: - if pa.types.is_struct(field.type): - converter = self.get_converter_for_struct_signature(field.type) - if converter: - semantic_info[field.name] = { - "python_type": converter.python_type, - "struct_signature": field.type, - "converter": converter, - } - return semantic_info - - def validate_struct_signature( - self, struct_signature: "pa.StructType", expected_python_type: type - ) -> bool: - """ - Validate that a struct signature matches the expected Python type. - - Args: - struct_signature: Arrow struct type to validate - expected_python_type: Expected Python type - - Returns: - True if the struct signature is registered for the Python type - """ - registered_type = self.get_python_type_for_semantic_struct_signature( - struct_signature - ) - return registered_type == expected_python_type diff --git a/src/orcapod/semantic_types/semantic_struct_converters.py b/src/orcapod/semantic_types/semantic_struct_converters.py deleted file mode 100644 index 54be49a2..00000000 --- a/src/orcapod/semantic_types/semantic_struct_converters.py +++ /dev/null @@ -1,333 +0,0 @@ -""" -Struct-based semantic type system for OrcaPod. - -This replaces the metadata-based approach with explicit struct fields, -making semantic types visible in schemas and preserved through operations. -""" - -from __future__ import annotations - -import uuid as _uuid_module -from abc import ABC, abstractmethod -from pathlib import Path -from typing import TYPE_CHECKING, Any - -from upath import UPath - -from orcapod.types import ContentHash -from orcapod.utils.lazy_module import LazyModule - -if TYPE_CHECKING: - import pyarrow as pa - - from orcapod.protocols.hashing_protocols import FileContentHasherProtocol -else: - pa = LazyModule("pyarrow") - - -class SemanticStructConverterBase: - """ - Base class providing common functionality for semantic struct converters. - - Subclasses only need to implement the abstract methods and can use - the common hashing infrastructure. - """ - - def __init__(self, semantic_type_name: str): - self._semantic_type_name = semantic_type_name - self._hasher_id = f"{self.semantic_type_name}_content_sha256" - - @property - def semantic_type_name(self) -> str: - """The name of the semantic type this converter handles.""" - return self._semantic_type_name - - @property - def hasher_id(self) -> str: - """Default hasher ID based on semantic type name""" - return self._hasher_id - - def _compute_content_hash(self, content: bytes) -> ContentHash: - """Compute SHA-256 hash of content bytes. - - Args: - content: Content to hash. - - Returns: - ``ContentHash`` with ``method="sha256"`` and the raw digest. - """ - import hashlib - - digest = hashlib.sha256(content).digest() - return ContentHash(method="sha256", digest=digest) - - def _format_semantic_hash(self, content_hash: ContentHash) -> str: - """Format a ``ContentHash`` into the standard semantic hash string. - - Always returns ``"{semantic_type_name}:{method}:{hex}"``, - e.g. ``"uuid:sha256:abc123"``. - - Args: - content_hash: Hash to format. - - Returns: - Formatted hash string with semantic type and algorithm prefix. - """ - return f"{self.semantic_type_name}:{content_hash.to_string(prefix_method=True)}" - - -class PathStructConverterBase(SemanticStructConverterBase, ABC): - """Base converter for file path types (Path and UPath). - - Extracts the shared conversion logic since Path and UPath have - identical APIs for the operations we need (str conversion, - construction from string, ``read_bytes``). - """ - - def __init__( - self, - name: str, - path_type: type, - file_hasher: "FileContentHasherProtocol", - ): - super().__init__(name) - self._python_type = path_type - self._field_name = name - self._file_hasher = file_hasher - self._arrow_struct_type = pa.struct([ - pa.field(name, pa.large_string()), - ]) - - @property - def python_type(self) -> type: - return self._python_type - - @property - def arrow_struct_type(self) -> "pa.StructType": - return self._arrow_struct_type - - @abstractmethod - def _make_path(self, path_str: str) -> Any: - """Construct the appropriate path object from a string.""" - ... - - def python_to_struct_dict(self, value: Any) -> dict[str, Any]: - """Convert path object to struct dictionary.""" - if not isinstance(value, self._python_type): - raise TypeError(f"Expected {self._python_type.__name__}, got {type(value)}") - return {self._field_name: str(value)} - - def struct_dict_to_python(self, struct_dict: dict[str, Any]) -> Any: - """Convert struct dictionary back to path object.""" - path_str = struct_dict.get(self._field_name) - if path_str is None: - raise ValueError(f"Missing '{self._field_name}' field in struct") - return self._make_path(path_str) - - def can_handle_python_type(self, python_type: type) -> bool: - """Check if this converter can handle the given Python type.""" - return issubclass(python_type, self._python_type) - - def can_handle_struct_type(self, struct_type: "pa.StructType") -> bool: - """Check if this converter can handle the given struct type.""" - for field in self._arrow_struct_type: - if ( - field.name not in struct_type.names - or struct_type[field.name].type != field.type - ): - return False - return True - - def is_semantic_struct(self, struct_dict: dict[str, Any]) -> bool: - """Check if a struct dictionary represents this semantic type.""" - return ( - set(struct_dict.keys()) == {self._field_name} - and isinstance(struct_dict[self._field_name], str) - ) - - def hash_struct_dict(self, struct_dict: dict[str, Any]) -> str: - """Compute hash of a path semantic type by hashing the file content. - - Returns a string of the form ``"{type}:{algorithm}:{hex}"``, - e.g. ``"path:sha256:abc123"``. - - Args: - struct_dict: Dict with the path field containing a file path string. - - Returns: - Hash string of the file content with semantic type and algorithm prefix. - - Raises: - FileNotFoundError: If the path does not exist. - IsADirectoryError: If the path is a directory. - """ - path_str = struct_dict.get(self._field_name) - if path_str is None: - raise ValueError(f"Missing '{self._field_name}' field in struct dict") - - path = self._make_path(path_str) - if not path.exists(): - raise FileNotFoundError(f"Path does not exist: {path}") - if path.is_dir(): - raise IsADirectoryError(f"Path is a directory: {path}") - - file_hash = self._file_hasher.hash_file(path) - return self._format_semantic_hash(file_hash) - - -class PythonPathStructConverter(PathStructConverterBase): - """Converter for pathlib.Path objects to/from semantic structs. - - Rejects ``UPath`` instances to avoid ambiguity with - ``UPathStructConverter``, since ``UPath`` is a ``Path`` subclass. - """ - - def __init__(self, file_hasher: "FileContentHasherProtocol"): - super().__init__("path", Path, file_hasher) - - def _make_path(self, path_str: str) -> Path: - return Path(path_str) - - def python_to_struct_dict(self, value: Any) -> dict[str, Any]: - """Convert Path to struct dictionary, rejecting UPath instances.""" - if isinstance(value, UPath): - raise TypeError( - f"Expected Path (not UPath), got {type(value)}. " - "Use UPathStructConverter for UPath instances." - ) - return super().python_to_struct_dict(value) - - def can_handle_python_type(self, python_type: type) -> bool: - """Check if this converter can handle the given Python type. - - Returns False for UPath (and its subclasses) to avoid ambiguity. - """ - if issubclass(python_type, UPath): - return False - return issubclass(python_type, Path) - - -class UPathStructConverter(PathStructConverterBase): - """Converter for universal_pathlib.UPath objects to/from semantic structs.""" - - def __init__(self, file_hasher: "FileContentHasherProtocol"): - super().__init__("upath", UPath, file_hasher) - - def _make_path(self, path_str: str) -> UPath: - return UPath(path_str) - - -class UUIDStructConverter(SemanticStructConverterBase): - """Converter for ``uuid.UUID`` objects to/from Arrow semantic structs. - - Stores UUIDs as fixed 16-byte binary values inside a single-field struct, - following the same pattern as ``PythonPathStructConverter`` and - ``UPathStructConverter``. - - Note: - ``uuid_utils.UUID`` objects (e.g. from ``uuid7()``) are accepted via - duck typing because they expose a ``.bytes`` attribute but do not - inherit from ``uuid.UUID``. - """ - - def __init__(self) -> None: - super().__init__("uuid") - self._python_type = _uuid_module.UUID - self._arrow_struct_type = pa.struct([pa.field("uuid", pa.binary(16))]) - - @property - def python_type(self) -> type: - """The Python type this converter handles (``uuid.UUID``).""" - return self._python_type - - @property - def arrow_struct_type(self) -> "pa.StructType": - """The Arrow struct type used for serialisation.""" - return self._arrow_struct_type - - def python_to_struct_dict(self, value: Any) -> dict[str, bytes]: - """Convert a UUID to a struct dictionary with a single ``uuid`` field. - - Accepts both ``uuid.UUID`` instances and duck-typed UUID-compatible - objects (e.g. ``uuid_utils.UUID``) that expose a ``.bytes`` attribute - returning 16 raw bytes. - - Args: - value: A ``uuid.UUID`` instance or compatible UUID-like object. - - Returns: - A dict with a single key ``"uuid"`` whose value is 16 raw bytes. - - Raises: - TypeError: If ``value`` is not a ``uuid.UUID`` instance or - compatible duck-typed UUID object. - """ - if isinstance(value, _uuid_module.UUID): - return {"uuid": value.bytes} - # Accept uuid_utils.UUID and other duck-typed UUID objects - raw = getattr(value, "bytes", None) - if isinstance(raw, bytes) and len(raw) == 16: - return {"uuid": raw} - raise TypeError( - f"Expected uuid.UUID or compatible UUID object, got {type(value)}" - ) - - def struct_dict_to_python(self, struct_dict: dict[str, Any]) -> _uuid_module.UUID: - """Convert a struct dictionary back to a ``uuid.UUID`` instance. - - Args: - struct_dict: Dict with a ``"uuid"`` key containing 16 raw bytes - (``bytes`` or ``bytearray``). - - Returns: - A ``uuid.UUID`` constructed from the raw bytes. - - Raises: - ValueError: If the ``"uuid"`` key is absent from ``struct_dict``. - """ - raw = struct_dict.get("uuid") - if raw is None: - raise ValueError("Missing 'uuid' field in struct dict") - return _uuid_module.UUID(bytes=bytes(raw)) - - def can_handle_python_type(self, python_type: type) -> bool: - """Check if this converter can handle the given Python type. - - Args: - python_type: The Python type to check. - - Returns: - ``True`` if ``python_type`` is ``uuid.UUID`` or a subclass of it. - """ - return issubclass(python_type, self._python_type) - - def can_handle_struct_type(self, struct_type: "pa.StructType") -> bool: - """Check if this converter can handle the given Arrow struct type. - - Args: - struct_type: The Arrow struct type to check. - - Returns: - ``True`` if ``struct_type`` equals the UUID Arrow struct type. - """ - return struct_type == self._arrow_struct_type - - def hash_struct_dict(self, struct_dict: dict[str, Any]) -> str: - """Compute a SHA-256 hash of the UUID from its struct dictionary representation. - - Hashes the raw 16 UUID bytes directly. - - Args: - struct_dict: Dict with a ``"uuid"`` key containing 16 raw bytes. - - Returns: - Hash string of the form ``"uuid:sha256:"``. - - Raises: - ValueError: If the ``"uuid"`` key is absent from ``struct_dict``. - """ - raw = struct_dict.get("uuid") - if raw is None: - raise ValueError("Missing 'uuid' field in struct dict") - content_hash = self._compute_content_hash(bytes(raw)) - return self._format_semantic_hash(content_hash) diff --git a/src/orcapod/semantic_types/universal_converter.py b/src/orcapod/semantic_types/universal_converter.py index 72e71a77..8150776c 100644 --- a/src/orcapod/semantic_types/universal_converter.py +++ b/src/orcapod/semantic_types/universal_converter.py @@ -11,38 +11,29 @@ from __future__ import annotations +import contextvars import hashlib import logging import types import typing -from collections.abc import Callable, Mapping +from collections.abc import Callable, Iterable, Mapping from datetime import datetime, timezone # Handle generic types from typing import TYPE_CHECKING, Any, TypedDict, get_args, get_origin from orcapod.contexts import DataContext, resolve_context -from orcapod.semantic_types.semantic_registry import SemanticTypeRegistry from orcapod.semantic_types.type_inference import infer_python_schema_from_pylist_data from orcapod.types import DataType, Schema, SchemaLike from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: import pyarrow as pa + from orcapod.extension_types.registry import LogicalTypeRegistry + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol, LogicalTypeProtocol else: pa = LazyModule("pyarrow") -import dataclasses - -from orcapod.semantic_types.dataclass_encoding import ( - DATACLASS_TYPE_FIELD, - _get_type_hints_safe, - dataclass_to_arrow_struct_type, - dataclass_to_struct_dict, - has_dataclass_type_sentinel, - struct_dict_to_dataclass, -) - logger = logging.getLogger(__name__) @@ -52,6 +43,14 @@ # referencing _PYTHON_TO_ARROW_MAP directly. _PYTHON_TO_ARROW_MAP: "dict | None" = None +# Context variable for cycle detection in register_python_class. +# Using a ContextVar (rather than an instance attribute) keeps it thread-safe, +# coroutine-safe, and explicitly scoped to the active call chain without +# polluting the converter instance with temporary state. +_register_in_progress: contextvars.ContextVar[set[type] | None] = contextvars.ContextVar( + "_register_in_progress", default=None +) + def _get_python_to_arrow_map() -> dict: """Return the Python→Arrow type map, building it on first call.""" @@ -117,6 +116,12 @@ def _get_python_to_arrow_map() -> dict: return _PYTHON_TO_ARROW_MAP +# Cache for the set of Python types that UniversalTypeConverter handles natively. +# Built lazily by get_native_python_types() from _get_python_to_arrow_map() so +# that datetime, numpy types, and any future additions are captured automatically. +_ARROW_NATIVE_TYPE_KEYS: frozenset[type] | None = None + + def _is_optional_type(python_type: DataType) -> bool: """Return True if python_type is T | None (Optional[T]). @@ -148,12 +153,11 @@ class UniversalTypeConverter: def __init__( self, - semantic_registry: SemanticTypeRegistry | None = None, datetime_timezone: typing.Literal["strict", "coerce_utc"] = "strict", + logical_type_registry: LogicalTypeRegistry | None = None, ): """ Args: - semantic_registry: Optional registry of semantic type converters. datetime_timezone: How to handle naive (timezone-less) ``datetime`` values when converting Python → Arrow. @@ -163,9 +167,12 @@ def __init__( ``"coerce_utc"`` — silently attach ``timezone.utc`` to naive datetimes before writing to Arrow. Use this when you know that all naive datetimes in your data represent UTC. + logical_type_registry: Optional registry of ``LogicalType`` instances. + When provided, extension-type identity takes priority over the + shape-based logical type system at encoding time. """ - self.semantic_registry = semantic_registry self._datetime_timezone = datetime_timezone + self._logical_type_registry = logical_type_registry # Cache for created TypedDict classes self._struct_signature_to_typeddict: dict[pa.StructType, DataType] = {} @@ -179,7 +186,560 @@ def __init__( # Cache for type mappings self._python_to_arrow_types: dict[DataType, pa.DataType] = {} self._arrow_to_python_types: dict[pa.DataType, DataType] = {} - self._dataclass_lookup_cache: dict[str, type] = {} + + + @classmethod + def get_native_python_types(cls) -> frozenset[type]: + """Return the set of Python types that this converter handles natively. + + Derived lazily from ``_get_python_to_arrow_map()`` so that + ``datetime.datetime``, numpy scalar types, and any future additions + are captured without hard-coding them here. ``type(None)`` is always + included because ``NoneType`` is produced by ``Optional[T]`` / + ``T | None`` unwrapping but may not appear as a key in the map. + + Returns: + Frozen set of Python ``type`` objects with built-in Arrow mappings. + """ + global _ARROW_NATIVE_TYPE_KEYS + if _ARROW_NATIVE_TYPE_KEYS is None: + _ARROW_NATIVE_TYPE_KEYS = frozenset( + k for k in _get_python_to_arrow_map() if isinstance(k, type) + ) | {type(None)} + return _ARROW_NATIVE_TYPE_KEYS + + def ensure_types_registered_for_schemas(self, *schemas: Schema) -> None: + """Ensure a LogicalType is registered for every annotation in schemas. + + Calls ``register_python_class`` for each annotation, which recursively + resolves nested types and synthesises via factory if needed. + When no ``LogicalTypeRegistry`` is configured, this is a no-op. + + Args: + *schemas: One or more ``Schema`` mappings (column name → Python type). + + Raises: + TypeError: If a leaf class has no registered ``LogicalType`` and + no registered factory covers it. + """ + if self._logical_type_registry is None: + return + for schema in schemas: + for annotation in schema.values(): + self.register_python_class(annotation) + + def register_python_class(self, annotation: Any) -> "pa.DataType": + """Register a Python type annotation and return its Arrow type. + + Traverses generic annotations recursively. For each concrete class found, + either returns from the primitive map or registry (cache hit), or + synthesises via factory and registers the result. + + Cycle detection uses a ``ContextVar`` (``_register_in_progress``) rather + than instance state, so it is thread-safe, coroutine-safe, and correctly + detects cycles that cross factory call-backs (e.g. a dataclass with a + field of its own type). + + Args: + annotation: A Python type or generic alias (e.g. ``list[str]``, + ``Optional[uuid.UUID]``, a dataclass type). + + Returns: + The Arrow ``pa.DataType`` corresponding to ``annotation``. + + Raises: + TypeError: If a concrete class has no registered ``LogicalType`` and + no factory covers it, or if a circular dependency is detected. + ValueError: If a complex (non-Optional) union is encountered. + """ + in_progress = _register_in_progress.get() + if in_progress is None: + # Top-level call: initialize a fresh in-progress set and register it + # in the context so recursive calls (including factory call-backs) reuse it. + fresh: set[type] = set() + token = _register_in_progress.set(fresh) + try: + return self._register_python_class_impl(annotation, fresh) + finally: + _register_in_progress.reset(token) + # Nested call (direct recursion or factory call-back): reuse the existing set. + return self._register_python_class_impl(annotation, in_progress) + + def _register_python_class_impl(self, annotation: Any, in_progress: set[type]) -> "pa.DataType": + """Internal recursive implementation of ``register_python_class``. + + Args: + annotation: The annotation to resolve. + in_progress: The mutable cycle-detection set for the current call chain. + Shared across factory call-backs via ``_register_in_progress`` ContextVar. + """ + import types as _types_mod + + type_map = _get_python_to_arrow_map() + + # Primitive map hit + if annotation in type_map: + return type_map[annotation] + + origin = get_origin(annotation) + args = get_args(annotation) + + # Optional[T] / T | None → strip None arm + if origin is typing.Union or origin is _types_mod.UnionType: + non_none = [a for a in args if a is not type(None)] + if len(non_none) == 1: + return self.register_python_class(non_none[0]) + raise ValueError( + f"Complex unions with multiple non-None types are not supported: " + f"{annotation!r}. Only Optional[T] (T | None) is allowed." + ) + + # list[T] → pa.large_list(T). + # Raise if T resolves to an extension type: Arrow forbids extension types inside + # list value fields (ET1/ET2 in DESIGN_ISSUES.md). Fail loudly now rather than + # silently dropping type information and failing mysteriously on read. + # Native list-of-logical-type support is planned in PLT-1732 (ListLogicalType). + if origin is list: + if not args: + raise ValueError( + "Unparameterized 'list' is not supported. Use 'list[T]' with a concrete " + "element type (e.g. list[int], list[str])." + ) + inner = self.register_python_class(args[0]) + if isinstance(inner, pa.ExtensionType): + raise ValueError( + f"'list[{args[0]}]' is not yet supported: the element type maps to Arrow " + f"extension type {inner.extension_name!r}, which cannot be preserved inside " + f"a list value field due to an Arrow limitation (ET2 in DESIGN_ISSUES.md). " + f"Native list-of-logical-type support is tracked in PLT-1732." + ) + return pa.large_list(inner) + + # set[T] → pa.large_list(T). Same restriction as list[T]. + if origin is set: + if not args: + raise ValueError( + "Unparameterized 'set' is not supported. Use 'set[T]' with a concrete " + "element type (e.g. set[int], set[str])." + ) + inner = self.register_python_class(args[0]) + if isinstance(inner, pa.ExtensionType): + raise ValueError( + f"'set[{args[0]}]' is not yet supported: the element type maps to Arrow " + f"extension type {inner.extension_name!r}, which cannot be preserved inside " + f"a list value field due to an Arrow limitation (ET2 in DESIGN_ISSUES.md). " + f"Native set-of-logical-type support is tracked in PLT-1732." + ) + return pa.large_list(inner) + + # dict[K, V] → pa.large_list(struct{key: K, value: V}). + # Raise if K or V resolves to an extension type: the key/value land inside struct + # fields, which also forbids extension types (ET1 in DESIGN_ISSUES.md). + if origin is dict: + if len(args) < 2: + raise ValueError( + "Unparameterized 'dict' is not supported. Use 'dict[K, V]' with concrete " + "key and value types (e.g. dict[str, int])." + ) + key_arrow = self.register_python_class(args[0]) + val_arrow = self.register_python_class(args[1]) + if isinstance(key_arrow, pa.ExtensionType): + raise ValueError( + f"'dict[{args[0]}, ...]' is not yet supported: the key type maps to Arrow " + f"extension type {key_arrow.extension_name!r}, which cannot be preserved " + f"inside a struct field due to an Arrow limitation (ET1 in DESIGN_ISSUES.md). " + f"Native dict-of-logical-type support is tracked in PLT-1732." + ) + if isinstance(val_arrow, pa.ExtensionType): + raise ValueError( + f"'dict[..., {args[1]}]' is not yet supported: the value type maps to Arrow " + f"extension type {val_arrow.extension_name!r}, which cannot be preserved " + f"inside a struct field due to an Arrow limitation (ET1 in DESIGN_ISSUES.md). " + f"Native dict-of-logical-type support is tracked in PLT-1732." + ) + return pa.large_list( + pa.struct([pa.field("key", key_arrow), pa.field("value", val_arrow)]) + ) + + # Concrete class — registry or factory dispatch + if isinstance(annotation, type): + if self._logical_type_registry is None: + # No registry — return primitive Arrow type if available, else raise + raise TypeError( + f"No LogicalTypeRegistry configured — cannot register {annotation!r}. " + f"Provide logical_type_registry at converter construction time." + ) + + # Registry hit (already synthesised) + lt = self._logical_type_registry.get_by_python_type(annotation) + if lt is not None: + return lt.get_arrow_extension_type() + + # Cycle detection (via the shared ContextVar-backed in_progress set) + if annotation in in_progress: + raise TypeError( + f"Circular type dependency detected while synthesising " + f"LogicalType for {annotation!r}." + ) + + # Factory dispatch via MRO walk + factory = self._find_factory_for_class(annotation) + if factory is None: + raise TypeError( + f"No LogicalType or LogicalTypeFactory registered for {annotation!r}. " + f"Register a factory: converter.register_logical_type_factory(factory, " + f"python_bases=[])" + ) + + in_progress.add(annotation) + try: + lt = factory.create_for_python_type(annotation, converter=self) + self._logical_type_registry.register_logical_type(lt) + finally: + in_progress.discard(annotation) + + return lt.get_arrow_extension_type() + + raise ValueError(f"Unsupported annotation: {annotation!r}") + + def _find_factory_for_class( + self, + python_type: type, + ) -> "LogicalTypeFactoryProtocol | None": + """Find the most-specific registered factory for ``python_type``. + + Walks ``python_type.__mro__`` and returns the first factory in + ``_python_class_factories`` whose ``supports_class(python_type)`` returns True. + Falls back to an ``issubclass`` scan for ABC-registered factories. + + Args: + python_type: Concrete Python class to find a factory for. + + Returns: + The matching ``LogicalTypeFactoryProtocol``, or ``None`` if none found. + """ + factories = self._logical_type_registry._python_class_factories + + # MRO walk — most-specific base first + for base in python_type.__mro__: + factory = factories.get(base) + if factory is not None: + if hasattr(factory, "supports_class") and factory.supports_class(python_type): + return factory + elif not hasattr(factory, "supports_class"): + # Factories without supports_class are treated as unconditional matches + return factory + + # issubclass fallback for ABC-registered factories + for base, factory in factories.items(): + try: + if issubclass(python_type, base): + if hasattr(factory, "supports_class"): + if factory.supports_class(python_type): + return factory + else: + return factory + except TypeError: + continue + + return None + + def register_storage_type(self, arrow_type: "pa.DataType") -> "pa.DataType": + """Register extension types found in ``arrow_type`` and return the resolved type. + + Traverses Arrow types recursively in a bottom-up manner: + + - Primitives are returned unchanged. + - ``pa.ExtensionType`` instances that are already registered are returned as-is. + - Unregistered extension types: the storage type is resolved first (bottom-up), + then the factory dispatches on the ``"category"`` metadata key. + - Structs: each field's type is resolved; a new struct with resolved fields is returned. + - Lists: the value type is resolved; a new list type with the resolved value is returned. + + Args: + arrow_type: An Arrow type to traverse and register. + + Returns: + The resolved Arrow type with extension types embedded. + """ + # Extension type + if isinstance(arrow_type, pa.ExtensionType): + ext_name = arrow_type.extension_name + if self._logical_type_registry is not None: + lt = self._logical_type_registry.get_by_arrow_extension_name(ext_name) + if lt is not None: + return lt.get_arrow_extension_type() + # Registry miss — extract info and register + raw_meta = arrow_type.__arrow_ext_serialize__() + ext_meta = raw_meta if raw_meta else None + resolved_storage = self.register_storage_type(arrow_type.storage_type) + return self.register_arrow_extension(ext_name, ext_meta, resolved_storage) + + # Struct type — recurse into each field, preserving field-level metadata. + # Strip any extension type from field types before embedding (ET1: Arrow/Polars + # cannot construct arrays whose struct fields are pa.ExtensionType nodes). + if pa.types.is_struct(arrow_type): + resolved_fields = [] + for i in range(arrow_type.num_fields): + field = arrow_type.field(i) + resolved_type = self.register_storage_type(field.type) + if isinstance(resolved_type, pa.ExtensionType): + resolved_type = resolved_type.storage_type # strip: ET1 + resolved_fields.append( + pa.field(field.name, resolved_type, nullable=field.nullable, metadata=field.metadata) + ) + return pa.struct(resolved_fields) + + # Large list type — preserve value field metadata (used by ARROW:extension:* channel). + # Strip any extension type from the value type before embedding (ET1). + if pa.types.is_large_list(arrow_type): + vf = arrow_type.value_field + resolved_value = self.register_storage_type(vf.type) + if isinstance(resolved_value, pa.ExtensionType): + resolved_value = resolved_value.storage_type # strip: ET1 + return pa.large_list( + pa.field(vf.name, resolved_value, nullable=vf.nullable, metadata=vf.metadata) + ) + + # List type — strip any extension type from the value type (ET1). + if pa.types.is_list(arrow_type): + vf = arrow_type.value_field + resolved_value = self.register_storage_type(vf.type) + if isinstance(resolved_value, pa.ExtensionType): + resolved_value = resolved_value.storage_type # strip: ET1 + return pa.list_( + pa.field(vf.name, resolved_value, nullable=vf.nullable, metadata=vf.metadata) + ) + + # All other types (primitives, timestamps, binary, etc.) — return as-is + return arrow_type + + def apply_extension_types(self, table: "pa.Table") -> "pa.Table": + """Re-wrap *table* columns into their registered Arrow extension types. + + A convenience wrapper around the module-level ``apply_extension_types`` + function that uses this converter's own logical type registry. No-op + when the registry is absent or when the table contains no columns with + ``ARROW:extension:name`` field metadata. + + Call ``self.register_discovered_extensions(table.schema)`` first to + ensure all extension types in the schema are registered before calling + this method. + + Args: + table: Arrow table whose columns may contain ``ARROW:extension:*`` + field metadata from a Parquet/IPC read, but were loaded as plain + storage types. + + Returns: + A new ``pa.Table`` with extension-typed columns re-wrapped, or the + original *table* unchanged if no re-wrapping is needed. + """ + if self._logical_type_registry is None: + return table + from orcapod.extension_types.database_hooks import ( + apply_extension_types as _apply_ext, + ) + return _apply_ext(table, self._logical_type_registry) + + def register_discovered_extensions(self, schema: "pa.Schema") -> None: + """Register any extension types found in ``schema`` that are not yet known. + + A convenience wrapper around the module-level ``register_discovered_extensions`` + function. Walks ``schema`` recursively and registers each discovered extension + type via this converter's ``register_arrow_extension``. Already-registered types + are skipped. No-op when the schema contains no extension types. + + Call this before ``apply_extension_types`` when reading a table from Parquet or + IPC to ensure all extension types in the schema are registered: + + converter.register_discovered_extensions(table.schema) + table = converter.apply_extension_types(table) + + Args: + schema: The Arrow schema to inspect for extension types. + """ + from orcapod.extension_types.database_hooks import ( + register_discovered_extensions as _reg_disc, + ) + _reg_disc(self, schema) + + def load_extension_types(self, table: "pa.Table") -> "pa.Table": + """Register and apply extension types for *table* in one step. + + Convenience wrapper that calls ``register_discovered_extensions`` followed + by ``apply_extension_types``. Use this as the standard post-read step after + loading a table from Parquet or IPC: + + table = converter.load_extension_types(pq.read_table(path)) + + Args: + table: Arrow table as returned by a Parquet or IPC read, whose columns + may carry ``ARROW:extension:*`` field metadata but were loaded as + plain storage types. + + Returns: + A new ``pa.Table`` with extension-typed columns re-wrapped, or the + original *table* unchanged if no extension types are present. + """ + self.register_discovered_extensions(table.schema) + return self.apply_extension_types(table) + + def register_arrow_extension( + self, + arrow_extension_name: str, + extension_metadata: bytes | None, + storage_type: "pa.DataType", + ) -> "pa.DataType": + """Register an extension type from (name, metadata, storage_type) info. + + Called by ``register_storage_type`` for in-memory ``pa.ExtensionType`` objects, + and by ``register_discovered_extensions`` for the field-metadata (Parquet) channel. + The ``storage_type`` must already be resolved (nested extension types registered). + + Args: + arrow_extension_name: Arrow extension name (``ARROW:extension:name``). + extension_metadata: Raw metadata bytes, expected to be UTF-8 JSON with + at least a ``"category"`` key. ``None`` or empty bytes if absent. + storage_type: Underlying Arrow storage type (already bottom-up resolved). + + Returns: + The Arrow extension type after registration. + + Raises: + ValueError: If metadata is missing, malformed, lacks ``"category"``, or + no factory is registered for the category. + """ + import json as _json + + if self._logical_type_registry is None: + raise ValueError( + f"No LogicalTypeRegistry configured — cannot register extension type " + f"{arrow_extension_name!r}." + ) + + # Registry hit — already registered + lt = self._logical_type_registry.get_by_arrow_extension_name(arrow_extension_name) + if lt is not None: + return lt.get_arrow_extension_type() + + # Missing metadata — cannot auto-register + if not extension_metadata: + raise ValueError( + f"Extension type {arrow_extension_name!r} has no extension metadata. " + f"Types without a metadata category tag cannot be auto-registered via a factory. " + f"Pre-register them explicitly via converter.register_logical_type(lt)." + ) + + # Parse JSON metadata + try: + metadata_dict = _json.loads(extension_metadata.decode("utf-8")) + except (UnicodeDecodeError, _json.JSONDecodeError) as exc: + raise ValueError( + f"Extension type {arrow_extension_name!r} has metadata that is not valid " + f"UTF-8 JSON: {extension_metadata!r}. Parse error: {exc}." + ) from exc + + if not isinstance(metadata_dict, dict): + raise ValueError( + f"Extension type {arrow_extension_name!r} metadata decoded to a non-object " + f"JSON value: {metadata_dict!r}." + ) + + if "category" not in metadata_dict: + raise ValueError( + f"Extension type {arrow_extension_name!r} metadata has no \"category\" key: " + f"{metadata_dict}." + ) + + category = metadata_dict["category"] + if not isinstance(category, str): + raise ValueError( + f"Extension type {arrow_extension_name!r} metadata \"category\" is not a " + f"string: {category!r}." + ) + + # Look up factory by category + factory = self._logical_type_registry._category_factories.get(category) + if factory is None: + raise ValueError( + f"No LogicalTypeFactory registered for category {category!r}. " + f"Cannot register extension type {arrow_extension_name!r}." + ) + + # Reconstruct and register + logical_type = factory.reconstruct_from_arrow( + arrow_extension_name, storage_type, metadata_dict, converter=self + ) + self._logical_type_registry.register_logical_type(logical_type) + return logical_type.get_arrow_extension_type() + + def python_to_storage(self, value: Any, annotation: Any) -> Any: + """Convert a Python value to its Arrow storage representation. + + Thin wrapper over ``get_python_to_arrow_converter`` for use by + ``DataclassLogicalType`` and other logical types that delegate per-field + conversion back to the converter. + + Args: + value: A Python object. + annotation: The Python type annotation for ``value``. + + Returns: + A value in Arrow storage format. + """ + converter_fn = self.get_python_to_arrow_converter(annotation) + return converter_fn(value) + + def storage_to_python(self, storage_value: Any, annotation: Any) -> Any: + """Convert an Arrow storage value back to a Python object. + + Args: + storage_value: A scalar or element from an Arrow storage array. + annotation: The Python type annotation to convert back to. + + Returns: + A Python object of the type described by ``annotation``. + """ + arrow_type = self.python_type_to_arrow_type(annotation) + converter_fn = self.get_arrow_to_python_converter(arrow_type) + return converter_fn(storage_value) + + def register_logical_type(self, lt: "LogicalTypeProtocol") -> None: + """Register a ``LogicalTypeProtocol`` instance. + + Pass-through to the internal ``LogicalTypeRegistry``. + + Args: + lt: The logical type to register. + """ + if self._logical_type_registry is None: + raise ValueError("No LogicalTypeRegistry configured on this converter.") + self._logical_type_registry.register_logical_type(lt) + + def register_logical_type_factory( + self, + factory: "LogicalTypeFactoryProtocol", + *, + category: str | None = None, + python_bases: Iterable[type] = (), + ) -> None: + """Register a ``LogicalTypeFactoryProtocol`` instance. + + Pass-through to the internal ``LogicalTypeRegistry``. + + Args: + factory: The factory to register. + category: If given, registers factory as the read-side handler for + Arrow extension types with this ``"category"`` metadata value. + python_bases: Zero or more Python base classes to register as write-side + dispatch keys for this factory. + """ + if self._logical_type_registry is None: + raise ValueError("No LogicalTypeRegistry configured on this converter.") + self._logical_type_registry.register_logical_type_factory( + factory, category=category, python_bases=python_bases + ) def python_type_to_arrow_type(self, python_type: DataType) -> pa.DataType: """ @@ -220,14 +780,18 @@ def arrow_type_to_python_type(self, arrow_type: pa.DataType) -> DataType: This is the main entry point for Arrow → Python type conversion. Results are cached for performance. """ - # Check cache first - if arrow_type in self._arrow_to_python_types: - return self._arrow_to_python_types[arrow_type] + try: + if arrow_type in self._arrow_to_python_types: + return self._arrow_to_python_types[arrow_type] + except TypeError: + # ExtensionType instances are not always hashable — skip the cache. + return self._convert_arrow_to_python(arrow_type) - # Convert and cache result python_type = self._convert_arrow_to_python(arrow_type) - self._arrow_to_python_types[arrow_type] = python_type - + try: + self._arrow_to_python_types[arrow_type] = python_type + except TypeError: + pass # Unhashable type — skip caching. return python_type def arrow_schema_to_python_schema(self, arrow_schema: pa.Schema) -> Schema: @@ -399,8 +963,14 @@ def get_arrow_to_python_converter( This creates and caches conversion functions for optimal performance during data conversion operations. """ - if arrow_type in self._arrow_to_python_converters: - return self._arrow_to_python_converters[arrow_type] + try: + if arrow_type in self._arrow_to_python_converters: + return self._arrow_to_python_converters[arrow_type] + except TypeError: + # Some pa.DataType subclasses (e.g. pa.ExtensionType instances) are not + # hashable and will raise TypeError on dict lookup. Fall through to + # create the converter without caching. + return self._create_arrow_to_python_converter(arrow_type) # Create conversion function converter = self._create_arrow_to_python_converter(arrow_type) @@ -415,22 +985,18 @@ def _convert_python_to_arrow(self, python_type: DataType) -> pa.DataType: if python_type in type_map: return type_map[python_type] - # Check semantic registry for registered types - if self.semantic_registry: - converter = self.semantic_registry.get_converter_for_python_type( - python_type - ) - if converter: - return converter.arrow_struct_type + # Check LogicalTypeRegistry — extension-type identity takes priority over shape-based system. + # Guard with isinstance(…, type) because get_by_python_type is keyed on concrete classes; + # generic aliases (list[T], Optional[T], etc.) will never be registered there. + if self._logical_type_registry is not None and isinstance(python_type, type): + lt = self._logical_type_registry.get_by_python_type(python_type) + if lt is not None: + return lt.get_arrow_extension_type() # Handle typeddict look up if python_type in self._typeddict_to_struct_signature: return self._typeddict_to_struct_signature[python_type] - # Dataclass types → struct with __type sentinel - if dataclasses.is_dataclass(python_type) and isinstance(python_type, type): - return dataclass_to_arrow_struct_type(python_type, self) - # Check generic types origin = get_origin(python_type) args = get_args(python_type) @@ -511,6 +1077,14 @@ def _convert_arrow_to_python(self, arrow_type: pa.DataType) -> type | Any: if pa.types.is_null(arrow_type): return Any + # Check LogicalTypeRegistry for extension types + if isinstance(arrow_type, pa.ExtensionType) and self._logical_type_registry is not None: + lt = self._logical_type_registry.get_by_arrow_extension_name( + arrow_type.extension_name + ) + if lt is not None: + return lt.python_type + # Handle basic types if pa.types.is_integer(arrow_type): return int @@ -529,48 +1103,6 @@ def _convert_arrow_to_python(self, arrow_type: pa.DataType) -> type | Any: # Handle struct types elif pa.types.is_struct(arrow_type): - # Check if it's a registered semantic type first - if self.semantic_registry: - python_type = self.semantic_registry.get_python_type_for_semantic_struct_signature( - arrow_type - ) - if python_type: - return python_type - - # Dataclass structs: synthesize a concrete dataclass from the struct's - # field definitions. The sentinel field is excluded; each remaining - # field's Arrow type is recursively converted to a Python type. - # The result is cached automatically by arrow_type_to_python_type()'s - # _arrow_to_python_types dict so the same class is reused for the - # same struct schema. - if has_dataclass_type_sentinel(arrow_type): - # Respect per-field nullability: nullable Arrow fields become - # Optional[T] annotations so that the synthesized dataclass - # correctly conveys that those fields can hold None, and so - # that round-trips through python_schema_to_arrow_schema - # preserve the nullable flag. - fields = [ - ( - field.name, - self.arrow_type_to_python_type(field.type) | None - if field.nullable - else self.arrow_type_to_python_type(field.type), - ) - for field in arrow_type - if field.name != DATACLASS_TYPE_FIELD - ] - # Include nullability in the hash so that two structs with - # identical field names and Arrow types but different per-field - # nullability produce distinct class names in the lookup cache. - field_parts = [ - f"{f.name}:{'?' if f.nullable else ''}{f.type}" - for f in arrow_type - if f.name != DATACLASS_TYPE_FIELD - ] - name_hash = hashlib.md5("|".join(field_parts).encode()).hexdigest()[:8] - class_name = f"_SynthesizedDataclass_{name_hash}" - return dataclasses.make_dataclass(class_name, fields) - # Check if it is heterogeneous tuple if len(arrow_type) > 0 and all( field.name.startswith("f") and field.name[1:].isdigit() @@ -743,28 +1275,20 @@ def _create_python_to_arrow_converter( ) -> Callable[[Any], Any]: """Create a cached conversion function for Python → Arrow values.""" + # Check LogicalTypeRegistry first — extension-type identity takes priority. + # Guard with isinstance(…, type) because get_by_python_type is keyed on concrete classes; + # generic aliases (list[T], Optional[T], etc.) will never be registered there. + if self._logical_type_registry is not None and isinstance(python_type, type): + lt = self._logical_type_registry.get_by_python_type(python_type) + if lt is not None: + _lt = lt + _self = self + return lambda value: _lt.python_to_storage(value, _self) + # Get the Arrow type for this Python type # TODO: check if this step is necessary _ = self.python_type_to_arrow_type(python_type) - # Check for semantic type first - if self.semantic_registry: - converter = self.semantic_registry.get_converter_for_python_type( - python_type - ) - if converter: - return converter.python_to_struct_dict - - # Dataclass instances → struct dict with __type sentinel - if dataclasses.is_dataclass(python_type) and isinstance(python_type, type): - hints = _get_type_hints_safe(python_type) - field_converters = { - f.name: self.get_python_to_arrow_converter(hints[f.name]) - for f in dataclasses.fields(python_type) - if f.init # skip init=False fields: not part of the serialized repr - } - return lambda obj: dataclass_to_struct_dict(obj, field_converters) - # Create conversion function based on type # Without this guard, datetime would reach the `origin is None` catch-all @@ -854,23 +1378,19 @@ def _create_arrow_to_python_converter( ) -> Callable[[Any], Any]: """Create a cached conversion function for Arrow → Python values.""" + # Check LogicalTypeRegistry for extension types + if isinstance(arrow_type, pa.ExtensionType) and self._logical_type_registry is not None: + lt = self._logical_type_registry.get_by_arrow_extension_name( + arrow_type.extension_name + ) + if lt is not None: + _lt = lt + _self = self + return lambda storage_value: _lt.storage_to_python(storage_value, _self) + # Get the Python type for this Arrow type python_type = self.arrow_type_to_python_type(arrow_type) - # Check for semantic type first - if self.semantic_registry and pa.types.is_struct(arrow_type): - registered_python_type = ( - self.semantic_registry.get_python_type_for_semantic_struct_signature( - arrow_type - ) - ) - if registered_python_type: - converter = self.semantic_registry.get_converter_for_python_type( - registered_python_type - ) - if converter: - return converter.struct_dict_to_python - # Handle basic types - no conversion needed if ( pa.types.is_integer(arrow_type) @@ -935,16 +1455,6 @@ def _create_arrow_to_python_converter( # Handle struct types - heterogeneous tuple or dynamic TypedDict elif pa.types.is_struct(arrow_type): - # Dataclass structs: per-row dispatch via __type value - if has_dataclass_type_sentinel(arrow_type): - field_converters = { - field.name: self.get_arrow_to_python_converter(field.type) - for field in arrow_type - if field.name != DATACLASS_TYPE_FIELD - } - cache = self._dataclass_lookup_cache - return lambda d: struct_dict_to_dataclass(d, field_converters, cache) - # if python_type if python_type is tuple or get_origin(python_type) is tuple: n = len(get_args(python_type)) @@ -1004,7 +1514,6 @@ def clear_cache(self) -> None: self._arrow_to_python_converters.clear() self._python_to_arrow_types.clear() self._arrow_to_python_types.clear() - self._dataclass_lookup_cache.clear() def get_cache_stats(self) -> dict[str, int]: """Get statistics about cache usage (useful for debugging/optimization).""" diff --git a/superpowers/plans/2026-06-14-extension-type-registry.md b/superpowers/plans/2026-06-14-extension-type-registry.md new file mode 100644 index 00000000..14a119f5 --- /dev/null +++ b/superpowers/plans/2026-06-14-extension-type-registry.md @@ -0,0 +1,925 @@ +# ExtensionTypeRegistry Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use sensei:subagent-driven-development (recommended) or sensei:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Implement `ExtensionTypeRegistry` in `src/orcapod/extension_types/registry.py`, wiring up both PyArrow and Polars global extension type registries on each `register()` call. + +**Architecture:** A plain Python class with two internal dicts (`_by_name`, `_by_python_type`) for converter lookup, plus two module-level shadow dicts (`_ARROW_REGISTRY`, `_POLARS_REGISTRY`) that track what has been registered in the process-global PA/Polars registries. `register()` validates against both the instance dict (duplicate check) and the shadow dicts (equivalence/external-conflict check), then dynamically creates and registers `pa.ExtensionType` and `pl.BaseExtension` subclasses via `type()`. + +**Tech Stack:** Python 3.12, PyArrow ≥ 20, Polars ≥ 1.36, pytest, uv + +**Spec:** `superpowers/specs/2026-06-14-extension-type-registry-design.md` + +--- + +## File map + +| File | Action | +|---|---| +| `pyproject.toml` | Modify — restore range constraint `polars>=1.36.0` | +| `src/orcapod/extension_types/registry.py` | **Create** — `ExtensionTypeRegistry`, shadow dicts, helpers | +| `src/orcapod/extension_types/__init__.py` | Modify — export class, create module-level instance | +| `tests/test_extension_types/test_registry.py` | **Create** — full test suite | + +--- + +## Task 1: Fix `pyproject.toml` — restore Polars range constraint + +The Polars dependency was accidentally pinned to `==1.41.2` during exploration. Restore it to a range constraint. + +**Files:** +- Modify: `pyproject.toml` + +- [ ] **Step 1: Update the constraint** + +Open `pyproject.toml`. Find the line: +```toml +"polars==1.41.2", +``` +Replace with: +```toml +"polars>=1.36.0", +``` + +- [ ] **Step 2: Sync and verify** + +```bash +uv sync +uv run python -c "import polars as pl; print(pl.__version__); from polars import BaseExtension; print('BaseExtension OK')" +``` + +Expected output: +``` +1.41.2 +BaseExtension OK +``` + +- [ ] **Step 3: Commit** + +```bash +git add pyproject.toml uv.lock +git commit -m "chore(deps): restore polars>=1.36.0 range constraint (PLT-1653)" +``` + +--- + +## Task 2: Create `test_registry.py` and `registry.py` — pure-Python registry + +Write all tests that exercise the Python-only layer (dict storage, lookups, duplicate checking). No PA/Polars wiring yet — `register()` just populates the internal dicts. + +**Files:** +- Create: `tests/test_extension_types/test_registry.py` +- Create: `src/orcapod/extension_types/registry.py` + +- [ ] **Step 1: Write the failing tests** + +Create `tests/test_extension_types/test_registry.py`: + +```python +"""Tests for ExtensionTypeRegistry.""" + +from __future__ import annotations + +import uuid + +import pyarrow as pa +import pytest + +from orcapod.extension_types.protocols import ExtensionTypeConverter +from orcapod.extension_types.registry import ExtensionTypeRegistry + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _unique_name() -> str: + """Unique extension name to avoid cross-test global-registry collisions.""" + return f"test.registry.{uuid.uuid4().hex[:8]}" + + +def _make_stub( + name: str | None = None, + storage: pa.DataType | None = None, + metadata: bytes | None = b"test.category", + py_type: type = str, +) -> ExtensionTypeConverter: + """Factory for minimal ExtensionTypeConverter conforming stubs.""" + _name = name or _unique_name() + _storage = storage if storage is not None else pa.large_utf8() + _metadata = metadata + _py_type = py_type + + class _Stub: + @property + def extension_name(self) -> str: + return _name + + @property + def extension_metadata(self) -> bytes | None: + return _metadata + + @property + def storage_type(self) -> pa.DataType: + return _storage + + @property + def python_type(self) -> type: + return _py_type + + def python_to_storage(self, value): + return str(value) + + def storage_to_python(self, storage_value): + return storage_value + + return _Stub() + + +# --------------------------------------------------------------------------- +# Pure-Python registry tests (no PA/Polars global state required) +# --------------------------------------------------------------------------- + +def test_register_stores_converter(): + registry = ExtensionTypeRegistry() + conv = _make_stub() + registry.register(conv) + assert registry.get_converter_for_name(conv.extension_name) is conv + + +def test_register_duplicate_raises(): + registry = ExtensionTypeRegistry() + name = _unique_name() + registry.register(_make_stub(name=name)) + with pytest.raises(ValueError, match=name): + registry.register(_make_stub(name=name)) + + +def test_get_converter_for_name_miss(): + registry = ExtensionTypeRegistry() + assert registry.get_converter_for_name("does.not.exist") is None + + +def test_get_converter_for_python_type_exact(): + registry = ExtensionTypeRegistry() + conv = _make_stub(py_type=bytes) + registry.register(conv) + assert registry.get_converter_for_python_type(bytes) is conv + + +def test_get_converter_for_python_type_subclass(): + class _Base: + pass + + class _Child(_Base): + pass + + registry = ExtensionTypeRegistry() + conv = _make_stub(py_type=_Base) + registry.register(conv) + assert registry.get_converter_for_python_type(_Child) is conv + + +def test_get_converter_for_python_type_miss(): + registry = ExtensionTypeRegistry() + assert registry.get_converter_for_python_type(int) is None + + +def test_has_extension_name(): + registry = ExtensionTypeRegistry() + conv = _make_stub() + assert not registry.has_extension_name(conv.extension_name) + registry.register(conv) + assert registry.has_extension_name(conv.extension_name) + + +def test_has_python_type(): + registry = ExtensionTypeRegistry() + conv = _make_stub(py_type=float) + assert not registry.has_python_type(float) + registry.register(conv) + assert registry.has_python_type(float) + + +def test_list_extension_names(): + registry = ExtensionTypeRegistry() + a = _make_stub() + b = _make_stub() + registry.register(a) + registry.register(b) + assert registry.list_extension_names() == [a.extension_name, b.extension_name] + + +def test_list_python_types(): + registry = ExtensionTypeRegistry() + a = _make_stub(py_type=bytes) + b = _make_stub(py_type=float) + registry.register(a) + registry.register(b) + assert registry.list_python_types() == [bytes, float] +``` + +- [ ] **Step 2: Run to confirm ImportError (registry module does not exist yet)** + +```bash +uv run pytest tests/test_extension_types/test_registry.py -v 2>&1 | head -20 +``` + +Expected: `ModuleNotFoundError: No module named 'orcapod.extension_types.registry'` + +- [ ] **Step 3: Create `src/orcapod/extension_types/registry.py`** + +```python +"""Registry for ExtensionTypeConverter instances. + +Registering a converter automatically registers the corresponding +extension type in both PyArrow's and Polars' global registries. +""" + +from __future__ import annotations + +import re + +import pyarrow as pa +import polars as pl + +from orcapod.extension_types.protocols import ExtensionTypeConverter + +# --------------------------------------------------------------------------- +# Shadow dicts — track what *we* have registered in the global registries. +# These are module-level singletons shared across all ExtensionTypeRegistry +# instances. We use our own dicts rather than querying library internals +# because neither PyArrow nor Polars exposes a stable public API for looking +# up a previously registered extension type by name. +# +# Limitation: types registered externally (directly via +# pa.register_extension_type / pl.register_extension_type, bypassing this +# module) will not appear here. A subsequent register() call for the same +# name will detect the conflict via the library-level error and raise, +# because without knowing what was registered externally we cannot guarantee +# the same extension name maps to the same Python class and underlying +# storage type — silently proceeding risks data corruption or misrouted +# conversions at read time. +# --------------------------------------------------------------------------- + +_ARROW_REGISTRY: dict[str, tuple[pa.DataType, bytes]] = {} +# extension_name -> (storage_type, metadata_bytes) + +_POLARS_REGISTRY: dict[str, tuple[pl.DataType, str | None]] = {} +# extension_name -> (pl_storage_dtype, metadata_str) + + +def _sanitize(name: str) -> str: + return re.sub(r"[^A-Za-z0-9]", "_", name) + + +def _register_arrow_ext_type(converter: ExtensionTypeConverter) -> None: + """Register a ``pa.ExtensionType`` subclass for *converter* in PyArrow's global registry.""" + name = converter.extension_name + metadata = converter.extension_metadata or b"" + storage = converter.storage_type + + if name in _ARROW_REGISTRY: + existing_storage, existing_metadata = _ARROW_REGISTRY[name] + if existing_storage == storage and existing_metadata == metadata: + return # idempotent — safe for module reload and test-suite reuse + raise ValueError( + f"Extension type '{name}' is already registered in the PyArrow global registry " + f"with different parameters.\n" + f" Registered: storage_type={existing_storage!r}, metadata={existing_metadata!r}\n" + f" Attempted: storage_type={storage!r}, metadata={metadata!r}" + ) + + _name, _storage, _metadata = name, storage, metadata + ArrowExtType = type( + f"_ArrowExt_{_sanitize(name)}", + (pa.ExtensionType,), + { + "__init__": lambda self: pa.ExtensionType.__init__(self, _storage, _name), + "__arrow_ext_serialize__": lambda self: _metadata, + "__arrow_ext_deserialize__": classmethod(lambda cls, st, se: cls()), + }, + ) + + try: + pa.register_extension_type(ArrowExtType()) + except pa.lib.ArrowKeyError: + raise ValueError( + f"Extension type '{name}' is already registered in the PyArrow global registry " + f"by an external source. Cannot verify equivalence; orcapod requires exclusive " + f"ownership of extension type registrations to prevent data corruption or " + f"misrouted conversions. See PLT-1665 for future interop support." + ) from None + + _ARROW_REGISTRY[name] = (storage, metadata) + + +def _register_polars_ext_type(converter: ExtensionTypeConverter) -> None: + """Register a ``pl.BaseExtension`` subclass for *converter* in Polars' global registry.""" + name = converter.extension_name + metadata = converter.extension_metadata + metadata_str = metadata.decode("utf-8") if metadata else None + pl_storage = pl.from_arrow(pa.array([], type=converter.storage_type)).dtype + + if name in _POLARS_REGISTRY: + existing_storage, existing_meta = _POLARS_REGISTRY[name] + if existing_storage == pl_storage and existing_meta == metadata_str: + return # idempotent + raise ValueError( + f"Extension type '{name}' is already registered in the Polars global registry " + f"with different parameters.\n" + f" Registered: storage_dtype={existing_storage!r}, metadata={existing_meta!r}\n" + f" Attempted: storage_dtype={pl_storage!r}, metadata={metadata_str!r}" + ) + + _name, _pl_storage, _meta_str = name, pl_storage, metadata_str + PolarsExtType = type( + f"_PolarsExt_{_sanitize(name)}", + (pl.BaseExtension,), + { + "__init__": lambda self: pl.BaseExtension.__init__(self, _name, _pl_storage, _meta_str), + "ext_from_params": classmethod(lambda cls, n, s, m: cls()), + }, + ) + + try: + pl.register_extension_type(name, PolarsExtType) + except ValueError as exc: + raise ValueError( + f"Extension type '{name}' is already registered in the Polars global registry " + f"by an external source. Cannot verify equivalence; orcapod requires exclusive " + f"ownership of extension type registrations to prevent data corruption or " + f"misrouted conversions. See PLT-1665 for future interop support." + ) from exc + + _POLARS_REGISTRY[name] = (pl_storage, metadata_str) + + +class ExtensionTypeRegistry: + """Registry for ``ExtensionTypeConverter`` instances. + + Registering a converter automatically registers the corresponding + extension type in both PyArrow's and Polars' global registries. + + The primary lookup key is ``extension_name``; a secondary lookup by + ``python_type`` is provided for the write path. + + Example: + >>> registry = ExtensionTypeRegistry() + >>> registry.register(my_converter) + >>> conv = registry.get_converter_for_name("my.Type") + """ + + def __init__(self) -> None: + self._by_name: dict[str, ExtensionTypeConverter] = {} + self._by_python_type: dict[type, ExtensionTypeConverter] = {} + + def register(self, converter: ExtensionTypeConverter) -> None: + """Register *converter* and its PyArrow/Polars extension types. + + Args: + converter: An ``ExtensionTypeConverter`` instance to register. + + Raises: + ValueError: If ``converter.extension_name`` is already registered + in this registry instance. + ValueError: If the extension name is already in the PA or Polars + global registry with different parameters. + ValueError: If the extension name is already in the PA or Polars + global registry from an external source (equivalence cannot + be verified). + """ + name = converter.extension_name + if name in self._by_name: + raise ValueError( + f"Extension type '{name}' is already registered in this registry." + ) + self._by_name[name] = converter + self._by_python_type[converter.python_type] = converter + _register_arrow_ext_type(converter) + _register_polars_ext_type(converter) + + def get_converter_for_name(self, name: str) -> ExtensionTypeConverter | None: + """Return the converter registered under *name*, or ``None``.""" + return self._by_name.get(name) + + def get_converter_for_python_type(self, python_type: type) -> ExtensionTypeConverter | None: + """Return the converter for *python_type*, or ``None``. + + Checks exact match first, then falls back to an ``issubclass`` scan. + When multiple registered types are superclasses of *python_type*, the + one registered first wins (insertion-order dict, Python 3.7+). + """ + converter = self._by_python_type.get(python_type) + if converter is not None: + return converter + for registered_type, conv in self._by_python_type.items(): + if issubclass(python_type, registered_type): + return conv + return None + + def has_extension_name(self, name: str) -> bool: + """Return ``True`` if *name* is registered.""" + return name in self._by_name + + def has_python_type(self, python_type: type) -> bool: + """Return ``True`` if *python_type* (or a subclass) is registered.""" + return self.get_converter_for_python_type(python_type) is not None + + def list_extension_names(self) -> list[str]: + """Return all registered extension names in insertion order.""" + return list(self._by_name.keys()) + + def list_python_types(self) -> list[type]: + """Return all registered Python types in insertion order.""" + return list(self._by_python_type.keys()) +``` + +- [ ] **Step 4: Run the pure-Python tests** + +```bash +uv run pytest tests/test_extension_types/test_registry.py -v -k "not arrow and not polars and not round_trip and not parquet and not module_instance" +``` + +Expected: all 11 tests pass. + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/extension_types/registry.py tests/test_extension_types/test_registry.py +git commit -m "feat(extension_types): add ExtensionTypeRegistry with pure-Python lookup (PLT-1653)" +``` + +--- + +## Task 3: Add PyArrow global registration tests + +**Files:** +- Modify: `tests/test_extension_types/test_registry.py` + +- [ ] **Step 1: Add the PyArrow tests** + +Append to `tests/test_extension_types/test_registry.py`: + +```python +# --------------------------------------------------------------------------- +# PyArrow global registry tests +# --------------------------------------------------------------------------- + +def test_register_populates_arrow_registry(): + """After register(), PA global registry contains the extension type.""" + conv = _make_stub() + registry = ExtensionTypeRegistry() + registry.register(conv) + + # If the name is registered, attempting to re-register it raises ArrowKeyError. + # This is the only stable public signal PyArrow provides. + class _Probe(pa.ExtensionType): + def __init__(self): + pa.ExtensionType.__init__(self, pa.large_utf8(), conv.extension_name) + def __arrow_ext_serialize__(self): + return b"" + @classmethod + def __arrow_ext_deserialize__(cls, st, se): + return cls() + + with pytest.raises(pa.lib.ArrowKeyError): + pa.register_extension_type(_Probe()) + + +def test_register_arrow_global_collision_same_params_is_idempotent(): + """A second registry instance registering the same name+params succeeds silently.""" + name = _unique_name() + conv = _make_stub(name=name, storage=pa.large_utf8(), metadata=b"cat") + + ExtensionTypeRegistry().register(conv) # first — populates _ARROW_REGISTRY + ExtensionTypeRegistry().register(conv) # second — should not raise + + +def test_register_arrow_global_collision_different_storage_raises(): + """A second registry using the same name but different storage_type raises.""" + name = _unique_name() + ExtensionTypeRegistry().register(_make_stub(name=name, storage=pa.large_utf8())) + + with pytest.raises(ValueError, match=name): + ExtensionTypeRegistry().register(_make_stub(name=name, storage=pa.large_binary())) + + +def test_register_arrow_global_collision_different_metadata_raises(): + """A second registry using the same name but different metadata raises.""" + name = _unique_name() + ExtensionTypeRegistry().register(_make_stub(name=name, metadata=b"original")) + + with pytest.raises(ValueError, match=name): + ExtensionTypeRegistry().register(_make_stub(name=name, metadata=b"different")) + + +def test_register_arrow_external_registration_raises(): + """A name registered directly with PyArrow (bypassing our registry) raises on register().""" + name = _unique_name() + + class _External(pa.ExtensionType): + def __init__(self): + pa.ExtensionType.__init__(self, pa.large_utf8(), name) + def __arrow_ext_serialize__(self): + return b"" + @classmethod + def __arrow_ext_deserialize__(cls, st, se): + return cls() + + pa.register_extension_type(_External()) # bypass our registry + + with pytest.raises(ValueError, match="external source"): + ExtensionTypeRegistry().register(_make_stub(name=name)) +``` + +- [ ] **Step 2: Run all tests** + +```bash +uv run pytest tests/test_extension_types/test_registry.py -v +``` + +Expected: all tests pass (the PyArrow registration was already wired in Task 2). + +- [ ] **Step 3: Commit** + +```bash +git add tests/test_extension_types/test_registry.py +git commit -m "test(extension_types): add PyArrow global registry tests (PLT-1653)" +``` + +--- + +## Task 4: Add Polars global registration tests + +**Files:** +- Modify: `tests/test_extension_types/test_registry.py` + +- [ ] **Step 1: Add the Polars tests** + +Append to `tests/test_extension_types/test_registry.py`: + +```python +# --------------------------------------------------------------------------- +# Polars global registry tests +# --------------------------------------------------------------------------- + +def test_register_populates_polars_registry(): + """After register(), pl.from_arrow on an ext-type array yields a BaseExtension dtype.""" + conv = _make_stub(storage=pa.large_utf8()) + registry = ExtensionTypeRegistry() + registry.register(conv) + + # Build a PA extension array using the registered type. + # We need to get the registered ArrowExtType instance; the simplest way is + # to read it from _ARROW_REGISTRY shadow dict via the type's name in a PA array. + from orcapod.extension_types.registry import _ARROW_REGISTRY + assert conv.extension_name in _ARROW_REGISTRY + + # Create a storage array and cast it to the ext type to get a properly typed array. + # (The ArrowExtType class is not directly accessible from outside, but we can + # construct an array through the IPC round-trip or via the registered type.) + # Simplest: use pl.from_arrow on a storage array and check the dtype AFTER + # registering — the series dtype should be our BaseExtension subclass. + import warnings + arr = pa.array(["hello"], type=pa.large_utf8()) + # The ext type is registered, so building an array with it works. + # We access it via the _ARROW_REGISTRY which stores (storage_type, metadata). + # The actual class instance is what was registered; we verify Polars recognises it + # by checking the dtype returned from pl.from_arrow on an ext-typed array. + # Build ext array via cast on a pre-registered type instance from the module. + from orcapod.extension_types import registry as reg_mod + # Reconstruct the ArrowExtType by checking what _ARROW_REGISTRY has, then + # building a matching IPC array. Easiest: use the existing ArrowExtType class + # by catching it from PA global via unregister/re-register trick — but that's + # invasive. Instead, just verify via _POLARS_REGISTRY dict directly. + from orcapod.extension_types.registry import _POLARS_REGISTRY + assert conv.extension_name in _POLARS_REGISTRY + stored_storage, stored_meta = _POLARS_REGISTRY[conv.extension_name] + assert stored_storage == pl.String + assert stored_meta == "test.category" + + +def test_register_polars_global_collision_same_params_is_idempotent(): + """A second registry instance registering the same name+params succeeds silently.""" + name = _unique_name() + conv = _make_stub(name=name, storage=pa.large_utf8(), metadata=b"cat") + + ExtensionTypeRegistry().register(conv) + ExtensionTypeRegistry().register(conv) # should not raise + + +def test_register_polars_global_collision_different_storage_raises(): + """A second registry using the same name but different storage_type raises.""" + name = _unique_name() + ExtensionTypeRegistry().register(_make_stub(name=name, storage=pa.large_utf8())) + + with pytest.raises(ValueError, match=name): + ExtensionTypeRegistry().register(_make_stub(name=name, storage=pa.large_binary())) + + +def test_register_polars_external_registration_raises(): + """A name registered directly with Polars (bypassing our registry) raises on register().""" + name = _unique_name() + + class _ExternalPL(pl.BaseExtension): + def __init__(self): + super().__init__(name, pl.String, None) + @classmethod + def ext_from_params(cls, n, s, m): + return cls() + + # Also register in PA first so we don't hit the PA external-registration error + class _ExternalPA(pa.ExtensionType): + def __init__(self): + pa.ExtensionType.__init__(self, pa.large_utf8(), name) + def __arrow_ext_serialize__(self): + return b"" + @classmethod + def __arrow_ext_deserialize__(cls, st, se): + return cls() + + pa.register_extension_type(_ExternalPA()) + pl.register_extension_type(name, _ExternalPL) + + with pytest.raises(ValueError, match="external source"): + ExtensionTypeRegistry().register(_make_stub(name=name)) +``` + +- [ ] **Step 2: Run all tests** + +```bash +uv run pytest tests/test_extension_types/test_registry.py -v +``` + +Expected: all tests pass. + +- [ ] **Step 3: Commit** + +```bash +git add tests/test_extension_types/test_registry.py +git commit -m "test(extension_types): add Polars global registry tests (PLT-1653)" +``` + +--- + +## Task 5: End-to-end integration tests + +**Files:** +- Modify: `tests/test_extension_types/test_registry.py` + +- [ ] **Step 1: Add the integration tests** + +Append to `tests/test_extension_types/test_registry.py`: + +```python +# --------------------------------------------------------------------------- +# End-to-end integration tests +# --------------------------------------------------------------------------- + +import warnings +import tempfile +import pathlib +import pyarrow.parquet as pq + + +class _Color: + """Minimal Python class used to exercise the converter contract end-to-end.""" + def __init__(self, hex_str: str) -> None: + self.hex_str = hex_str + def __eq__(self, other: object) -> bool: + return isinstance(other, _Color) and self.hex_str == other.hex_str + def __repr__(self) -> str: + return f"Color({self.hex_str!r})" + + +def _make_color_converter() -> ExtensionTypeConverter: + """ExtensionTypeConverter for _Color, backed by pa.large_utf8() storage.""" + _name = _unique_name() + + class _ColorConverter: + @property + def extension_name(self) -> str: + return _name + @property + def extension_metadata(self) -> bytes | None: + return b"test.color" + @property + def storage_type(self) -> pa.DataType: + return pa.large_utf8() + @property + def python_type(self) -> type: + return _Color + def python_to_storage(self, value: _Color) -> str: + return value.hex_str + def storage_to_python(self, storage_value: str) -> _Color: + return _Color(storage_value) + + return _ColorConverter() + + +def _build_ext_array( + converter: ExtensionTypeConverter, + values: list, +) -> pa.Array: + """Build a PA extension array from Python values using the converter.""" + from orcapod.extension_types.registry import _ARROW_REGISTRY + + storage_values = [converter.python_to_storage(v) for v in values] + storage_arr = pa.array(storage_values, type=converter.storage_type) + + # Retrieve the registered ArrowExtType instance via a fresh array cast. + # We use the PA global registry indirectly: _ARROW_REGISTRY confirms + # the type is registered; we then reconstruct the ext array by building + # a new subclass instance (same extension_name → PA resolves to registered class). + import re + _name = converter.extension_name + _storage = converter.storage_type + _metadata = converter.extension_metadata or b"" + _sanitized = re.sub(r"[^A-Za-z0-9]", "_", _name) + + ArrowExtType = type( + f"_ArrowExt_{_sanitized}_probe", + (pa.ExtensionType,), + { + "__init__": lambda self: pa.ExtensionType.__init__(self, _storage, _name), + "__arrow_ext_serialize__": lambda self: _metadata, + "__arrow_ext_deserialize__": classmethod(lambda cls, st, se: cls()), + }, + ) + # This will be caught as "already registered" internally; we instantiate + # separately — PyArrow resolves the extension by name, not by class identity. + ext_type_instance = ArrowExtType() + return storage_arr.cast(ext_type_instance) + + +def test_python_class_round_trip(): + """Python objects → Arrow extension array → Python objects via converter methods.""" + conv = _make_color_converter() + registry = ExtensionTypeRegistry() + registry.register(conv) + + originals = [_Color("#ff0000"), _Color("#00ff00"), _Color("#0000ff")] + ext_arr = _build_ext_array(conv, originals) + + # Decode back + storage_back = ext_arr.cast(conv.storage_type) + recovered = [conv.storage_to_python(v.as_py()) for v in storage_back] + assert recovered == originals + + +def test_arrow_polars_round_trip(): + """PA ext array → pl.from_arrow → to_arrow() preserves extension type and values.""" + conv = _make_color_converter() + registry = ExtensionTypeRegistry() + registry.register(conv) + + originals = [_Color("#aabbcc"), _Color("#112233")] + ext_arr = _build_ext_array(conv, originals) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + pl_series = pl.from_arrow(ext_arr) + + assert isinstance(pl_series.dtype, pl.BaseExtension) + assert pl_series.dtype.ext_name() == conv.extension_name + + arr_back = pl_series.to_arrow() + assert arr_back.type.extension_name == conv.extension_name + + recovered = [conv.storage_to_python(v.as_py()) for v in arr_back.cast(conv.storage_type)] + assert recovered == originals + + +def test_parquet_round_trip(): + """PA ext array → Parquet → read back via PyArrow; extension type and values preserved.""" + conv = _make_color_converter() + registry = ExtensionTypeRegistry() + registry.register(conv) + + originals = [_Color("#deadbe"), _Color("#cafeba")] + ext_arr = _build_ext_array(conv, originals) + schema = pa.schema([pa.field("color", ext_arr.type), pa.field("id", pa.int32())]) + table = pa.table( + {"color": ext_arr, "id": pa.array([1, 2], type=pa.int32())}, + schema=schema, + ) + + with tempfile.TemporaryDirectory() as tmp: + path = pathlib.Path(tmp) / "test.parquet" + pq.write_table(table, path) + table_back = pq.read_table(path) + + assert table_back.schema.field("color").type.extension_name == conv.extension_name + recovered = [ + conv.storage_to_python(v.as_py()) + for v in table_back.column("color").cast(conv.storage_type) + ] + assert recovered == originals +``` + +- [ ] **Step 2: Run all tests** + +```bash +uv run pytest tests/test_extension_types/test_registry.py -v +``` + +Expected: all tests pass. + +- [ ] **Step 3: Commit** + +```bash +git add tests/test_extension_types/test_registry.py +git commit -m "test(extension_types): add end-to-end integration tests (PLT-1653)" +``` + +--- + +## Task 6: Update `extension_types/__init__.py` + +**Files:** +- Modify: `tests/test_extension_types/test_registry.py` +- Modify: `src/orcapod/extension_types/__init__.py` + +- [ ] **Step 1: Write the failing test** + +Append to `tests/test_extension_types/test_registry.py`: + +```python +# --------------------------------------------------------------------------- +# Module-level instance test +# --------------------------------------------------------------------------- + +def test_extension_type_registry_module_instance(): + """extension_types.extension_type_registry is an ExtensionTypeRegistry, starts empty.""" + from orcapod import extension_types + assert isinstance(extension_types.extension_type_registry, ExtensionTypeRegistry) + # PLT-1653 scope: no built-in converters registered yet (that is PLT-1656) + assert extension_types.extension_type_registry.list_extension_names() == [] +``` + +- [ ] **Step 2: Run to confirm it fails** + +```bash +uv run pytest tests/test_extension_types/test_registry.py::test_extension_type_registry_module_instance -v +``` + +Expected: `AttributeError: module 'orcapod.extension_types' has no attribute 'extension_type_registry'` + +- [ ] **Step 3: Update `src/orcapod/extension_types/__init__.py`** + +```python +"""Arrow/Polars extension type system for orcapod. + +This subpackage provides the registry and protocol for converters that map +between Python objects and their Arrow extension type storage representation. + +The module-level ``extension_type_registry`` instance is the process default. +Built-in registrations (``Path``, ``UPath``, ``UUID``) are added by PLT-1656. +``DataContext`` wiring is added by PLT-1660. +""" + +from .protocols import ExtensionTypeConverter +from .registry import ExtensionTypeRegistry + +extension_type_registry = ExtensionTypeRegistry() + +__all__ = [ + "ExtensionTypeConverter", + "ExtensionTypeRegistry", + "extension_type_registry", +] +``` + +- [ ] **Step 4: Run all tests** + +```bash +uv run pytest tests/test_extension_types/ -v +``` + +Expected: all tests pass. + +- [ ] **Step 5: Run the full test suite to check for regressions** + +```bash +uv run pytest --tb=short -q +``` + +Expected: no new failures. + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/extension_types/__init__.py tests/test_extension_types/test_registry.py +git commit -m "feat(extension_types): export ExtensionTypeRegistry and module-level instance (PLT-1653)" +``` + +--- + +## Final check + +```bash +uv run pytest tests/test_extension_types/ -v --tb=short +``` + +All tests should pass. The PR targets `dev`. diff --git a/superpowers/plans/2026-06-14-plt-1654-schema-walker.md b/superpowers/plans/2026-06-14-plt-1654-schema-walker.md new file mode 100644 index 00000000..1e4c25e1 --- /dev/null +++ b/superpowers/plans/2026-06-14-plt-1654-schema-walker.md @@ -0,0 +1,660 @@ +# PLT-1654: Schema Walker Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use sensei:subagent-driven-development (recommended) or sensei:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add `src/orcapod/extension_types/schema_walker.py` — a pure discovery utility that walks a `pa.Schema` or `pa.Field` recursively and returns all extension-typed fields as `ExtensionTypeInfo` instances. + +**Architecture:** Three-layer design: `ExtensionTypeInfo` frozen dataclass as the return type; `_detect_extension` handles single-field two-channel detection; `_collect` drives recursive container descent with inline deduplication. Two public entry points — `walk_schema` and `walk_field` — each initialise a fresh `seen` set and delegate to `_collect`. + +**Tech Stack:** PyArrow ≥ 20.0.0, Python 3.11+, pytest, uv + +--- + +## File Map + +| File | Change | +|---|---| +| `src/orcapod/extension_types/schema_walker.py` | **New** — full module | +| `src/orcapod/extension_types/__init__.py` | Additive — append three new exports | +| `tests/test_extension_types/test_schema_walker.py` | **New** — full test suite | + +No other files are touched. + +--- + +## Task 1: Core module — `ExtensionTypeInfo`, detection, top-level walk, deduplication + +This task produces the full `schema_walker.py`. Container recursion (struct/list/map) is +added in Task 2. After this task, `walk_schema` and `walk_field` work for top-level +fields only; nesting tests are left for Task 2. + +**Files:** +- Create: `src/orcapod/extension_types/schema_walker.py` +- Create: `tests/test_extension_types/test_schema_walker.py` + +--- + +- [ ] **Step 1.1: Write the failing tests** + +Create `tests/test_extension_types/test_schema_walker.py` with this content: + +```python +"""Tests for schema_walker — recursive Arrow extension type discovery.""" + +from __future__ import annotations + +import re +import uuid + +import pyarrow as pa +import pytest + +from orcapod.extension_types.schema_walker import ( + ExtensionTypeInfo, + walk_field, + walk_schema, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _unique_name() -> str: + """Return a unique extension name to avoid cross-test collisions.""" + return f"test.walker.{uuid.uuid4().hex[:8]}" + + +def _make_reg_field( + field_name: str, + ext_name: str, + storage: pa.DataType | None = None, + metadata: bytes = b"test.cat", +) -> pa.Field: + """Create a ``pa.Field`` with an in-memory ``pa.ExtensionType`` (registered channel). + + The extension type is NOT registered in PyArrow's global registry — this + is intentional. ``pa.types.is_extension(field.type)`` returns ``True`` + for any ``pa.ExtensionType`` instance regardless of global registration. + """ + _n = ext_name + _s = storage if storage is not None else pa.large_utf8() + _m = metadata + ExtType = type( + f"_RegExt_{re.sub(r'[^A-Za-z0-9]', '_', ext_name)}", + (pa.ExtensionType,), + { + "__init__": lambda self: pa.ExtensionType.__init__(self, _s, _n), + "__arrow_ext_serialize__": lambda self: _m, + "__arrow_ext_deserialize__": classmethod(lambda cls, st, se: cls()), + }, + ) + return pa.field(field_name, ExtType()) + + +def _make_unreg_field( + field_name: str, + ext_name: str, + storage: pa.DataType | None = None, + metadata: bytes = b"test.cat", +) -> pa.Field: + """Create a ``pa.Field`` with raw Arrow extension metadata (unregistered channel).""" + _s = storage if storage is not None else pa.large_utf8() + return pa.field( + field_name, + _s, + metadata={ + b"ARROW:extension:name": ext_name.encode(), + b"ARROW:extension:metadata": metadata, + }, + ) + + +# --------------------------------------------------------------------------- +# Task 1 tests: top-level detection and deduplication +# --------------------------------------------------------------------------- + + +def test_empty_schema(): + result = walk_schema(pa.schema([])) + assert result == [] + + +def test_no_extension_types(): + schema = pa.schema([ + pa.field("x", pa.int64()), + pa.field("y", pa.large_utf8()), + ]) + assert walk_schema(schema) == [] + + +def test_top_level_registered(): + name = _unique_name() + schema = pa.schema([_make_reg_field("col", name, metadata=b"my.cat")]) + result = walk_schema(schema) + assert len(result) == 1 + assert result[0].extension_name == name + assert result[0].extension_metadata == b"my.cat" + assert result[0].storage_type == pa.large_utf8() + + +def test_top_level_unregistered(): + name = _unique_name() + schema = pa.schema([_make_unreg_field("col", name, metadata=b"my.cat")]) + result = walk_schema(schema) + assert len(result) == 1 + assert result[0].extension_name == name + assert result[0].extension_metadata == b"my.cat" + assert result[0].storage_type == pa.large_utf8() + + +def test_empty_metadata_normalised_to_none_registered(): + """b'' from __arrow_ext_serialize__ is normalised to None.""" + name = _unique_name() + _n, _s = name, pa.large_utf8() + ExtType = type( + "_EmptyMetaExt", + (pa.ExtensionType,), + { + "__init__": lambda self: pa.ExtensionType.__init__(self, _s, _n), + "__arrow_ext_serialize__": lambda self: b"", + "__arrow_ext_deserialize__": classmethod(lambda cls, st, se: cls()), + }, + ) + result = walk_field(pa.field("col", ExtType())) + assert len(result) == 1 + assert result[0].extension_metadata is None + + +def test_empty_metadata_normalised_to_none_unregistered(): + """b'' ARROW:extension:metadata value is normalised to None.""" + name = _unique_name() + field = pa.field( + "col", + pa.large_utf8(), + metadata={ + b"ARROW:extension:name": name.encode(), + b"ARROW:extension:metadata": b"", + }, + ) + result = walk_field(field) + assert len(result) == 1 + assert result[0].extension_metadata is None + + +def test_walk_field_returns_single_field_result(): + name = _unique_name() + field = _make_reg_field("col", name, metadata=b"cat") + result = walk_field(field) + assert len(result) == 1 + assert result[0].extension_name == name + + +def test_deduplication(): + """Same (extension_name, extension_metadata) in two columns → one result.""" + name = _unique_name() + meta = b"test.cat" + _n, _m, _s = name, meta, pa.large_utf8() + ExtType = type( + "_DupExt", + (pa.ExtensionType,), + { + "__init__": lambda self: pa.ExtensionType.__init__(self, _s, _n), + "__arrow_ext_serialize__": lambda self: _m, + "__arrow_ext_deserialize__": classmethod(lambda cls, st, se: cls()), + }, + ) + schema = pa.schema([ + pa.field("col_a", ExtType()), + pa.field("col_b", ExtType()), + ]) + result = walk_schema(schema) + assert len(result) == 1 + assert result[0].extension_name == name + assert result[0].extension_metadata == meta +``` + +- [ ] **Step 1.2: Run tests to verify they all fail** + +```bash +cd /path/to/orcapod-python +uv run pytest tests/test_extension_types/test_schema_walker.py -v 2>&1 | head -30 +``` + +Expected: `ModuleNotFoundError` or `ImportError` — `schema_walker` does not exist yet. + +- [ ] **Step 1.3: Implement `schema_walker.py`** + +Create `src/orcapod/extension_types/schema_walker.py` with this content: + +```python +"""Recursive Arrow schema walker for extension type discovery. + +Given a ``pa.Schema`` or a single ``pa.Field``, walks the Arrow type tree +recursively and returns all extension-typed fields found at any depth of +nesting (struct, list, map, etc.). + +This is a pure discovery utility — it never triggers any registration. +""" + +from __future__ import annotations + +import dataclasses + +import pyarrow as pa + + +@dataclasses.dataclass(frozen=True) +class ExtensionTypeInfo: + """Metadata for a single Arrow extension type found in a schema. + + Attributes: + extension_name: The extension type's unique name stored as + ``ARROW:extension:name`` (e.g. ``"pathlib.Path"``). + extension_metadata: The category tag stored as + ``ARROW:extension:metadata`` (e.g. ``b"orcapod.dataclass"``). + ``None`` when absent or serialised as empty bytes. + storage_type: The underlying Arrow storage type + (e.g. ``pa.large_string()``). + """ + + extension_name: str + extension_metadata: bytes | None + storage_type: pa.DataType + + +def walk_schema(schema: pa.Schema) -> list[ExtensionTypeInfo]: + """Walk *schema* and return all extension types found, deduplicated. + + Iterates every top-level field and descends recursively into struct, + list, and map container types. The result is deduplicated by + ``(extension_name, extension_metadata)``; the first occurrence of each + pair is kept. + + Args: + schema: A PyArrow schema to inspect. + + Returns: + Deduplicated list of ``ExtensionTypeInfo`` in depth-first, + first-seen order. + """ + seen: set[tuple[str, bytes | None]] = set() + results: list[ExtensionTypeInfo] = [] + for i in range(schema.num_fields): + _collect(schema.field(i), seen, results) + return results + + +def walk_field(field: pa.Field) -> list[ExtensionTypeInfo]: + """Walk *field*'s type tree and return all extension types found, deduplicated. + + Args: + field: A PyArrow field to inspect. + + Returns: + Deduplicated list of ``ExtensionTypeInfo`` in depth-first, + first-seen order. + """ + seen: set[tuple[str, bytes | None]] = set() + results: list[ExtensionTypeInfo] = [] + _collect(field, seen, results) + return results + + +def _collect( + field: pa.Field, + seen: set[tuple[str, bytes | None]], + results: list[ExtensionTypeInfo], +) -> None: + """Recursively walk *field* and accumulate ``ExtensionTypeInfo`` into *results*. + + Mutates *seen* and *results* in place. Stops descending once a field is + identified as extension-typed — the storage type of an extension type is + not descended into. + + Args: + field: The field to inspect. + seen: Deduplication set of ``(extension_name, extension_metadata)`` + pairs already appended to *results*. + results: Accumulator list. + """ + info = _detect_extension(field) + if info is not None: + key = (info.extension_name, info.extension_metadata) + if key not in seen: + seen.add(key) + results.append(info) + return + + t = field.type + if pa.types.is_struct(t): + for i in range(t.num_fields): + _collect(t.field(i), seen, results) + elif ( + pa.types.is_list(t) + or pa.types.is_large_list(t) + or pa.types.is_fixed_size_list(t) + or pa.types.is_list_view(t) + or pa.types.is_large_list_view(t) + ): + _collect(t.value_field, seen, results) + elif pa.types.is_map(t): + key_field = getattr(t, "key_field", None) + item_field = getattr(t, "item_field", None) + if key_field is not None: + _collect(key_field, seen, results) + if item_field is not None: + _collect(item_field, seen, results) + + +def _detect_extension(field: pa.Field) -> ExtensionTypeInfo | None: + """Extract ``ExtensionTypeInfo`` from *field*, or ``None`` if not extension-typed. + + Checks two channels in order: + + 1. **Registered channel** — ``pa.types.is_extension(field.type)`` is + true. The Python type object carries the name, serialised metadata, + and storage type. + 2. **Unregistered channel** — ``field.metadata`` contains + ``b"ARROW:extension:name"``. The type survived a Parquet/IPC + round-trip without being registered in this process. + + In both cases empty bytes metadata (``b""``) is normalised to ``None``. + + Args: + field: The field to inspect. + + Returns: + ``ExtensionTypeInfo`` if the field is extension-typed, else ``None``. + """ + if pa.types.is_extension(field.type): + ext_type = field.type + raw_meta = ext_type.__arrow_ext_serialize__() + return ExtensionTypeInfo( + extension_name=ext_type.extension_name, + extension_metadata=raw_meta or None, + storage_type=ext_type.storage_type, + ) + + if field.metadata and b"ARROW:extension:name" in field.metadata: + name = field.metadata[b"ARROW:extension:name"].decode("utf-8") + raw_meta = field.metadata.get(b"ARROW:extension:metadata") + return ExtensionTypeInfo( + extension_name=name, + extension_metadata=raw_meta or None, + storage_type=field.type, + ) + + return None +``` + +- [ ] **Step 1.4: Run Task 1 tests to verify they pass** + +```bash +uv run pytest tests/test_extension_types/test_schema_walker.py -v -k "empty_schema or no_extension or top_level or empty_metadata or walk_field or deduplication" +``` + +Expected: all 8 tests PASS. + +- [ ] **Step 1.5: Commit** + +```bash +git add src/orcapod/extension_types/schema_walker.py tests/test_extension_types/test_schema_walker.py +git commit -m "feat(extension_types): add schema_walker with ExtensionTypeInfo and top-level detection (PLT-1654)" +``` + +--- + +## Task 2: Container recursion — struct, list, map, nested combinations + +This task adds the nesting tests and verifies the container recursion already present in +`_collect` (written in Task 1) handles them correctly. + +**Files:** +- Modify: `tests/test_extension_types/test_schema_walker.py` — append new tests + +--- + +- [ ] **Step 2.1: Append the nesting tests** + +Append to `tests/test_extension_types/test_schema_walker.py`: + +```python +# --------------------------------------------------------------------------- +# Task 2 tests: container recursion +# --------------------------------------------------------------------------- + + +def test_list_of_registered(): + """Registered extension type as the value field of a list.""" + name = _unique_name() + value_field = _make_reg_field("item", name, metadata=b"my.cat") + list_field = pa.field("col", pa.list_(value_field)) + result = walk_schema(pa.schema([list_field])) + assert len(result) == 1 + assert result[0].extension_name == name + + +def test_list_of_unregistered(): + """Unregistered extension type as the value field of a list.""" + name = _unique_name() + value_field = _make_unreg_field("item", name, metadata=b"my.cat") + list_field = pa.field("col", pa.list_(value_field)) + result = walk_schema(pa.schema([list_field])) + assert len(result) == 1 + assert result[0].extension_name == name + assert result[0].extension_metadata == b"my.cat" + + +def test_struct_containing_registered(): + """Registered extension type as a field inside a struct.""" + name = _unique_name() + struct_field = pa.field( + "col", + pa.struct([ + _make_reg_field("a", name, metadata=b"my.cat"), + pa.field("b", pa.int64()), + ]), + ) + result = walk_schema(pa.schema([struct_field])) + assert len(result) == 1 + assert result[0].extension_name == name + + +def test_struct_containing_unregistered(): + """Unregistered extension type as a field inside a struct.""" + name = _unique_name() + struct_field = pa.field( + "col", + pa.struct([ + _make_unreg_field("a", name, metadata=b"my.cat"), + pa.field("b", pa.int64()), + ]), + ) + result = walk_schema(pa.schema([struct_field])) + assert len(result) == 1 + assert result[0].extension_name == name + assert result[0].extension_metadata == b"my.cat" + + +def test_nested_list_struct(): + """Registered extension type nested inside list>.""" + name = _unique_name() + struct_type = pa.struct([ + _make_reg_field("x", name, metadata=b"deep.cat"), + pa.field("y", pa.int32()), + ]) + value_field = pa.field("item", struct_type) + col = pa.field("col", pa.list_(value_field)) + result = walk_schema(pa.schema([col])) + assert len(result) == 1 + assert result[0].extension_name == name + assert result[0].extension_metadata == b"deep.cat" + + +def test_map_type(): + """Extension type as the item type of a map (registered channel).""" + name = _unique_name() + _n, _m, _s = name, b"map.cat", pa.large_utf8() + # Build a pa.ExtensionType instance — it IS a pa.DataType and can be + # passed directly to pa.map_() as the item type. + ExtType = type( + "_MapItemExt", + (pa.ExtensionType,), + { + "__init__": lambda self: pa.ExtensionType.__init__(self, _s, _n), + "__arrow_ext_serialize__": lambda self: _m, + "__arrow_ext_deserialize__": classmethod(lambda cls, st, se: cls()), + }, + ) + map_field = pa.field("col", pa.map_(pa.large_utf8(), ExtType())) + result = walk_schema(pa.schema([map_field])) + # _collect uses getattr(t, "item_field") to retrieve the item pa.Field. + # pa.types.is_extension(item_field.type) will be True for the ExtType above. + assert any(r.extension_name == name for r in result) +``` + +- [ ] **Step 2.2: Run the nesting tests to verify they pass** + +```bash +uv run pytest tests/test_extension_types/test_schema_walker.py -v -k "list_of or struct_containing or nested or map_type" +``` + +Expected: all 6 tests PASS. The recursion was already written in `_collect` in Task 1. + +If `test_map_type` fails because `key_field` / `item_field` are not available on `MapType` +in this PyArrow version, skip it with `@pytest.mark.skip` and open a follow-up note. + +- [ ] **Step 2.3: Run the full test file to confirm no regressions** + +```bash +uv run pytest tests/test_extension_types/test_schema_walker.py -v +``` + +Expected: all 14 tests PASS. + +- [ ] **Step 2.4: Commit** + +```bash +git add tests/test_extension_types/test_schema_walker.py +git commit -m "test(extension_types): add nesting and map tests for schema_walker (PLT-1654)" +``` + +--- + +## Task 3: Export from `__init__.py` + +**Files:** +- Modify: `src/orcapod/extension_types/__init__.py` + +--- + +- [ ] **Step 3.1: Update `__init__.py`** + +Open `src/orcapod/extension_types/__init__.py`. It currently reads: + +```python +"""Arrow/Polars extension type system for orcapod. + +This subpackage provides the registry and protocol for converters that map +between Python objects and their Arrow extension type storage representation. + +The module-level `default_extension_type_registry` instance is the process default. +Built-in registrations (`Path`, `UPath`, `UUID`) are added by PLT-1656. +`DataContext` wiring is added by PLT-1660. +""" + +from __future__ import annotations + +from .protocols import ExtensionTypeConverter +from .registry import ExtensionTypeRegistry + +default_extension_type_registry = ExtensionTypeRegistry() + +__all__ = [ + "ExtensionTypeConverter", + "ExtensionTypeRegistry", + "default_extension_type_registry", +] +``` + +Replace the entire file with: + +```python +"""Arrow/Polars extension type system for orcapod. + +This subpackage provides the registry and protocol for converters that map +between Python objects and their Arrow extension type storage representation. + +The module-level `default_extension_type_registry` instance is the process default. +Built-in registrations (`Path`, `UPath`, `UUID`) are added by PLT-1656. +`DataContext` wiring is added by PLT-1660. +""" + +from __future__ import annotations + +from .protocols import ExtensionTypeConverter +from .registry import ExtensionTypeRegistry +from .schema_walker import ExtensionTypeInfo, walk_field, walk_schema + +default_extension_type_registry = ExtensionTypeRegistry() + +__all__ = [ + "ExtensionTypeConverter", + "ExtensionTypeRegistry", + "default_extension_type_registry", + # PLT-1654 + "ExtensionTypeInfo", + "walk_schema", + "walk_field", +] +``` + +- [ ] **Step 3.2: Verify the exports are importable** + +```bash +uv run python -c " +from orcapod.extension_types import ExtensionTypeInfo, walk_schema, walk_field +import pyarrow as pa +schema = pa.schema([pa.field('x', pa.int64())]) +print(walk_schema(schema)) # should print [] +print('OK') +" +``` + +Expected output: +``` +[] +OK +``` + +- [ ] **Step 3.3: Run the full test suite for `test_extension_types/`** + +```bash +uv run pytest tests/test_extension_types/ -v +``` + +Expected: all tests PASS (no regressions in `test_protocols.py` or `test_registry.py`). + +- [ ] **Step 3.4: Commit** + +```bash +git add src/orcapod/extension_types/__init__.py +git commit -m "feat(extension_types): export ExtensionTypeInfo, walk_schema, walk_field (PLT-1654)" +``` + +--- + +## Done + +After Task 3: + +- `src/orcapod/extension_types/schema_walker.py` is complete with `ExtensionTypeInfo`, + `walk_schema`, `walk_field`, `_collect`, and `_detect_extension`. +- `ExtensionTypeInfo`, `walk_schema`, `walk_field` are exported from + `orcapod.extension_types`. +- 14 tests in `tests/test_extension_types/test_schema_walker.py` all pass. +- No existing code was modified; no regressions in other `test_extension_types/` tests. + +Create a PR targeting the `extension-type-system` branch (not `dev`). diff --git a/superpowers/plans/2026-06-14-plt-1655-database-hooks.md b/superpowers/plans/2026-06-14-plt-1655-database-hooks.md new file mode 100644 index 00000000..e1f8f85a --- /dev/null +++ b/superpowers/plans/2026-06-14-plt-1655-database-hooks.md @@ -0,0 +1,1300 @@ +# PLT-1655: Database Hooks Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use sensei:subagent-driven-development (recommended) or sensei:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add the peek-schema → register → read pattern to both database read paths so that Arrow extension types found in stored schemas are automatically registered before any table data is returned. + +**Architecture:** A stateless `ensure_extensions_registered(schema)` hook in `database_hooks.py` walks the schema using the existing `walk_schema` utility, then delegates each discovered type to `LogicalTypeRegistry.prepare_extension_type`. The registry owns all dispatch logic: it checks its own `_by_arrow_name` dict as a per-process cache (step 1), then parses JSON metadata, dispatches to a `LogicalTypeFactory` by category string, and calls `self.register()`. Two new protocols (`LogicalTypeFactory`) and two new methods on `LogicalTypeRegistry` (`register_logical_type_factory`, `prepare_extension_type`) complete the contract. + +**Tech Stack:** Python 3.12+, PyArrow, Polars, `pytest`, `json` (stdlib) + +--- + +## File Map + +| File | Action | Responsibility | +|---|---|---| +| `src/orcapod/extension_types/protocols.py` | Modify | Add `LogicalTypeFactory` Protocol | +| `src/orcapod/extension_types/registry.py` | Modify | Add logging, `import json`, `LogicalTypeFactory` import, `_factories` dict, `register_logical_type_factory`, `prepare_extension_type`, move `default_logical_type_registry` singleton here | +| `src/orcapod/extension_types/__init__.py` | Modify | Import `default_logical_type_registry` from `.registry`, add `LogicalTypeFactory` and `ensure_extensions_registered` to exports | +| `src/orcapod/extension_types/database_hooks.py` | **Create** | `ensure_extensions_registered(schema)` stateless hook | +| `src/orcapod/databases/delta_lake_databases.py` | Modify | Add `ensure_extensions_registered` call in `_read_delta_table` | +| `src/orcapod/databases/connector_arrow_database.py` | Modify | Add `import logging`, `logger`, `ensure_extensions_registered` call in `_get_committed_table` | +| `tests/test_extension_types/test_protocols.py` | Modify | Add `LogicalTypeFactory` conformance tests | +| `tests/test_extension_types/test_registry.py` | Modify | Add `_make_stub_factory` helper + 9 tests for new registry methods | +| `tests/test_extension_types/test_database_hooks.py` | **Create** | 9 tests for `ensure_extensions_registered` | + +--- + +## Task 1: `LogicalTypeFactory` Protocol + registry logging infrastructure + +**Files:** +- Modify: `src/orcapod/extension_types/protocols.py` +- Modify: `src/orcapod/extension_types/registry.py` (lines 1–21: imports section) +- Test: `tests/test_extension_types/test_protocols.py` + +- [ ] **Step 1: Write the failing tests** + +Add to `tests/test_extension_types/test_protocols.py` — after the existing `_StubLogicalType` class: + +```python +class _StubFactory: + """Minimal conforming implementation of LogicalTypeFactory for use in tests.""" + + def create_logical_type(self, arrow_extension_name, storage_type, metadata): + return _StubLogicalType() + + +def test_logical_type_factory_protocol_is_importable(): + """LogicalTypeFactory can be imported from extension_types.protocols.""" + from orcapod.extension_types.protocols import LogicalTypeFactory + assert LogicalTypeFactory is not None + + +def test_logical_type_factory_conforming_class_satisfies_protocol(): + """A conforming class is recognized as a LogicalTypeFactory instance.""" + from orcapod.extension_types.protocols import LogicalTypeFactory + assert isinstance(_StubFactory(), LogicalTypeFactory) + + +def test_logical_type_factory_create_returns_logical_type(): + """A conforming factory returns a LogicalType from create_logical_type.""" + from orcapod.extension_types.protocols import LogicalTypeFactory, LogicalType + factory: LogicalTypeFactory = _StubFactory() + result = factory.create_logical_type( + "test.ext", pa.large_utf8(), {"category": "Test"} + ) + assert isinstance(result, LogicalType) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +``` +uv run pytest tests/test_extension_types/test_protocols.py -v -k "factory" +``` + +Expected: FAIL — `ImportError: cannot import name 'LogicalTypeFactory' from 'orcapod.extension_types.protocols'` + +- [ ] **Step 3: Add `LogicalTypeFactory` to protocols.py** + +Open `src/orcapod/extension_types/protocols.py`. After the closing `...` of `LogicalType`, append: + +```python +@runtime_checkable +class LogicalTypeFactory(Protocol): + """Protocol for factories that auto-construct ``LogicalType`` instances from Arrow schema metadata. + + A ``LogicalTypeFactory`` constructs a ``LogicalType`` from the Arrow extension + type name, its underlying storage type, and the full parsed JSON metadata dict. + The dispatch key (``"category"`` value from the metadata JSON) that routes to this + factory is declared at registration time via + ``LogicalTypeRegistry.register_logical_type_factory``; the factory itself has no + knowledge of its dispatch key but receives the full metadata dict so it can read + additional hints beyond ``"category"``. + + This protocol is ``@runtime_checkable``, consistent with ``LogicalType``. + """ + + def create_logical_type( + self, + arrow_extension_name: str, + storage_type: pa.DataType, + metadata: dict, + ) -> LogicalType: + """Construct a ``LogicalType`` for the given Arrow extension name and storage type. + + Args: + arrow_extension_name: The Arrow extension type name extracted from the + schema (i.e. the value of ``ARROW:extension:name`` field metadata). + storage_type: The underlying Arrow storage type for this extension field. + metadata: The full parsed JSON metadata dict. Always contains at least a + ``"category"`` key. May contain additional keys the factory uses (e.g. + ``"protocol"``, ``"pydantic_version"``). + + Returns: + A fully constructed ``LogicalType`` ready to be passed to + ``LogicalTypeRegistry.register()``. + + Raises: + ValueError: If this factory cannot construct a logical type for the given + extension name (e.g. the Python class cannot be resolved by name). + """ + ... +``` + +- [ ] **Step 4: Add logging infrastructure to registry.py** + +Open `src/orcapod/extension_types/registry.py`. The current imports block starts at line 1: + +```python +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +from orcapod.extension_types.protocols import LogicalType +``` + +Replace the imports block with: + +```python +from __future__ import annotations + +import json +import logging +import re +from typing import TYPE_CHECKING + +from orcapod.extension_types.protocols import LogicalType, LogicalTypeFactory +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import polars as pl + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + pl = LazyModule("polars") + +logger = logging.getLogger(__name__) +``` + +- [ ] **Step 5: Run tests to verify they pass** + +``` +uv run pytest tests/test_extension_types/test_protocols.py -v +``` + +Expected: all tests PASS (including the 3 new factory tests) + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/extension_types/protocols.py src/orcapod/extension_types/registry.py tests/test_extension_types/test_protocols.py +git commit -m "feat(extension_types): add LogicalTypeFactory protocol and registry logging setup" +``` + +--- + +## Task 2: Move `default_logical_type_registry` singleton to registry.py + +**Context:** `database_hooks.py` (Task 5) will import `default_logical_type_registry` from `registry.py`. If it imported from `orcapod.extension_types` (the package `__init__.py`), a circular import would occur because `__init__.py` will later import from `database_hooks`. Moving the singleton to `registry.py` breaks the cycle. + +**Files:** +- Modify: `src/orcapod/extension_types/registry.py` (add singleton at bottom) +- Modify: `src/orcapod/extension_types/__init__.py` (import instead of create) +- Test: `tests/test_extension_types/test_registry.py` (add one new import-path test) + +- [ ] **Step 1: Write the new import-path test** + +Add to the bottom of `tests/test_extension_types/test_registry.py` (after the existing `default_logical_type_registry` tests): + +```python +def test_default_registry_accessible_from_registry_module(): + """default_logical_type_registry imported from registry module is same object as from package.""" + from orcapod.extension_types.registry import default_logical_type_registry as from_registry + from orcapod.extension_types import default_logical_type_registry as from_package + assert from_registry is from_package +``` + +- [ ] **Step 2: Run test to verify it fails** + +``` +uv run pytest tests/test_extension_types/test_registry.py::test_default_registry_accessible_from_registry_module -v +``` + +Expected: FAIL — `ImportError: cannot import name 'default_logical_type_registry' from 'orcapod.extension_types.registry'` + +- [ ] **Step 3: Add singleton to the bottom of registry.py** + +Open `src/orcapod/extension_types/registry.py`. Append after the `LogicalTypeRegistry` class: + +```python +# Module-level singleton — per-process registry used by database_hooks and +# application code. Defined here (not in __init__.py) to avoid the circular +# import that would arise if database_hooks imported from the package __init__. +default_logical_type_registry = LogicalTypeRegistry() +``` + +- [ ] **Step 4: Update __init__.py to import singleton from registry** + +Open `src/orcapod/extension_types/__init__.py`. The current content is: + +```python +"""Arrow/Polars extension type system for orcapod. +... +""" + +from __future__ import annotations + +from .protocols import LogicalType +from .registry import LogicalTypeRegistry, make_arrow_extension_type +from .schema_walker import ExtensionTypeInfo, walk_field, walk_schema + +default_logical_type_registry = LogicalTypeRegistry() + +__all__ = [ + "LogicalType", + "LogicalTypeRegistry", + "make_arrow_extension_type", + "default_logical_type_registry", + # PLT-1654 + "ExtensionTypeInfo", + "walk_schema", + "walk_field", +] +``` + +Replace with: + +```python +"""Arrow/Polars extension type system for orcapod. + +This subpackage provides the registry and protocol for logical types that map +between Python objects and their Arrow/Polars extension type representation. + +The module-level ``default_logical_type_registry`` instance is the process default. +Built-in registrations (``Path``, ``UPath``, ``UUID``) are added by PLT-1656. +``DataContext`` wiring is added by PLT-1660. +""" + +from __future__ import annotations + +from .protocols import LogicalType, LogicalTypeFactory +from .registry import LogicalTypeRegistry, make_arrow_extension_type, default_logical_type_registry +from .schema_walker import ExtensionTypeInfo, walk_field, walk_schema + +__all__ = [ + "LogicalType", + "LogicalTypeFactory", + "LogicalTypeRegistry", + "make_arrow_extension_type", + "default_logical_type_registry", + # PLT-1654 + "ExtensionTypeInfo", + "walk_schema", + "walk_field", +] +``` + +- [ ] **Step 5: Run all extension_types tests to verify no regressions** + +``` +uv run pytest tests/test_extension_types/ -v +``` + +Expected: all existing tests PASS including the new `test_default_registry_accessible_from_registry_module` + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/extension_types/registry.py src/orcapod/extension_types/__init__.py tests/test_extension_types/test_registry.py +git commit -m "refactor(extension_types): move default_logical_type_registry singleton to registry.py" +``` + +--- + +## Task 3: `_factories` dict + `register_logical_type_factory` method + +**Files:** +- Modify: `src/orcapod/extension_types/registry.py` +- Test: `tests/test_extension_types/test_registry.py` + +- [ ] **Step 1: Write the failing tests** + +Add to `tests/test_extension_types/test_registry.py` — after the `_make_stub` helper: + +```python +def _make_stub_factory(return_lt: LogicalType | None = None) -> LogicalTypeFactory: + """Factory for minimal LogicalTypeFactory conforming stubs. + + If ``return_lt`` is given, ``create_logical_type`` returns it; otherwise + it creates a fresh stub using ``_make_stub`` keyed on the arrow name. + ``calls`` records every invocation as ``(arrow_extension_name, storage_type, metadata)``. + """ + from orcapod.extension_types.protocols import LogicalTypeFactory + _return_lt = return_lt + + class _Factory: + def __init__(self): + self.calls: list[tuple] = [] + + def create_logical_type(self, arrow_extension_name, storage_type, metadata): + self.calls.append((arrow_extension_name, storage_type, metadata)) + if _return_lt is not None: + return _return_lt + return _make_stub(arrow_name=arrow_extension_name, storage=storage_type) + + return _Factory() +``` + +Then add these tests (before the `# end-to-end integration tests` section): + +```python +# --------------------------------------------------------------------------- +# register_logical_type_factory tests +# --------------------------------------------------------------------------- + +def test_register_logical_type_factory_no_error(): + """register_logical_type_factory completes without raising.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory("TestCat", factory) # should not raise + + +def test_register_logical_type_factory_same_instance_idempotent(): + """Re-registering the same factory instance for the same category does not raise.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory("Cat", factory) + registry.register_logical_type_factory("Cat", factory) # should not raise + + +def test_register_duplicate_category_raises(): + """Registering a different factory for an already-registered category raises ValueError.""" + registry = LogicalTypeRegistry() + f1 = _make_stub_factory() + f2 = _make_stub_factory() + registry.register_logical_type_factory("Cat", f1) + with pytest.raises(ValueError, match="Cat"): + registry.register_logical_type_factory("Cat", f2) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +``` +uv run pytest tests/test_extension_types/test_registry.py -v -k "factory" --no-header 2>&1 | tail -20 +``` + +Expected: FAIL — `AttributeError: 'LogicalTypeRegistry' object has no attribute 'register_logical_type_factory'` + +- [ ] **Step 3: Add `_factories` dict and `register_logical_type_factory` to LogicalTypeRegistry** + +In `src/orcapod/extension_types/registry.py`, inside the `LogicalTypeRegistry` class, update `__init__`: + +```python + def __init__(self) -> None: + self._by_logical_name: dict[str, LogicalType] = {} + self._by_arrow_name: dict[str, LogicalType] = {} + self._by_python_type: dict[type, LogicalType] = {} + self._factories: dict[str, LogicalTypeFactory] = {} +``` + +Then add the new method after `get_by_arrow_extension_name`: + +```python + def register_logical_type_factory( + self, + category: str, + factory: LogicalTypeFactory, + ) -> None: + """Register a factory for the given metadata category string. + + When ``prepare_extension_type`` encounters an Arrow extension type whose + ``extension_metadata`` JSON contains ``{"category": "", ...}``, + it calls ``factory.create_logical_type(arrow_extension_name, storage_type, + metadata_dict)`` to construct the logical type and then registers it. + + Args: + category: The ``"category"`` value from the extension metadata JSON that + identifies this category (e.g. ``"Dataclass"``). + factory: A ``LogicalTypeFactory`` instance responsible for constructing + logical types for this category. + + Raises: + ValueError: If ``category`` is already registered to a different factory. + """ + existing = self._factories.get(category) + if existing is not None and existing is not factory: + raise ValueError( + f"Cannot register factory for category {category!r}: " + f"a different factory is already registered for this category." + ) + if existing is factory: + return + self._factories[category] = factory + logger.debug( + "registered LogicalTypeFactory for category %r: %r", category, factory + ) +``` + +- [ ] **Step 4: Run tests to verify they pass** + +``` +uv run pytest tests/test_extension_types/test_registry.py -v -k "factory" --no-header +``` + +Expected: the 3 new factory tests PASS + +- [ ] **Step 5: Run full extension_types test suite to check for regressions** + +``` +uv run pytest tests/test_extension_types/ -v --no-header 2>&1 | tail -10 +``` + +Expected: all tests PASS + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/extension_types/registry.py tests/test_extension_types/test_registry.py +git commit -m "feat(extension_types): add _factories dict and register_logical_type_factory to LogicalTypeRegistry" +``` + +--- + +## Task 4: `prepare_extension_type` — full implementation (all 7 steps) + +**Files:** +- Modify: `src/orcapod/extension_types/registry.py` +- Test: `tests/test_extension_types/test_registry.py` + +- [ ] **Step 1: Write ALL failing tests (happy path + error paths)** + +Add to `tests/test_extension_types/test_registry.py`, after the `register_logical_type_factory` tests. Note the `import json` needed at the top of the file: + +First add `import json` to the existing import block at the top of test_registry.py (after `import uuid`). + +Then add these tests: + +```python +# --------------------------------------------------------------------------- +# prepare_extension_type tests +# --------------------------------------------------------------------------- + +def test_register_logical_type_factory_dispatches_on_prepare(): + """prepare_extension_type dispatches to the registered factory and registers the result.""" + import json + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory("TestCat", factory) + + arrow_name = _unique_name() + metadata_bytes = json.dumps({"category": "TestCat"}).encode() + registry.prepare_extension_type(arrow_name, metadata_bytes, pa.large_utf8()) + + assert len(factory.calls) == 1 + assert factory.calls[0][0] == arrow_name + assert registry.get_by_arrow_extension_name(arrow_name) is not None + + +def test_factory_receives_full_metadata_dict(): + """The factory's create_logical_type receives the full parsed JSON dict, not just category.""" + import json + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory("TestCat", factory) + + arrow_name = _unique_name() + metadata_bytes = json.dumps( + {"category": "TestCat", "protocol": 5, "version": "1.0"} + ).encode() + registry.prepare_extension_type(arrow_name, metadata_bytes, pa.large_utf8()) + + assert len(factory.calls) == 1 + _, _, received_metadata = factory.calls[0] + assert received_metadata == {"category": "TestCat", "protocol": 5, "version": "1.0"} + + +def test_prepare_already_registered_noop(): + """prepare_extension_type called twice does not raise and does not call the factory again.""" + import json + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory("TestCat", factory) + + arrow_name = _unique_name() + metadata_bytes = json.dumps({"category": "TestCat"}).encode() + + registry.prepare_extension_type(arrow_name, metadata_bytes, pa.large_utf8()) + registry.prepare_extension_type(arrow_name, metadata_bytes, pa.large_utf8()) # second call + + assert len(factory.calls) == 1 # factory called exactly once + + +def test_prepare_already_registered_none_metadata_noop(): + """Type pre-registered via register(); None metadata on prepare call is a silent no-op.""" + registry = LogicalTypeRegistry() + lt = _make_stub() + registry.register(lt) + + arrow_name = lt.get_arrow_extension_type().extension_name + registry.prepare_extension_type(arrow_name, None, pa.large_utf8()) # should not raise + + +def test_prepare_none_metadata_not_registered_raises(): + """None metadata for an unregistered extension type raises ValueError.""" + registry = LogicalTypeRegistry() + arrow_name = _unique_name() + + with pytest.raises(ValueError, match="must be pre-registered explicitly"): + registry.prepare_extension_type(arrow_name, None, pa.large_utf8()) + + +def test_prepare_invalid_json_raises(): + """Non-UTF-8-JSON extension_metadata raises ValueError with raw bytes and parse error.""" + registry = LogicalTypeRegistry() + arrow_name = _unique_name() + bad_metadata = b"not-json!" + + with pytest.raises(ValueError, match="not valid UTF-8 JSON"): + registry.prepare_extension_type(arrow_name, bad_metadata, pa.large_utf8()) + + +def test_prepare_json_missing_category_raises(): + """Valid JSON metadata without a 'category' key raises ValueError.""" + import json + registry = LogicalTypeRegistry() + arrow_name = _unique_name() + no_category = json.dumps({"version": 1}).encode() + + with pytest.raises(ValueError, match='"category"'): + registry.prepare_extension_type(arrow_name, no_category, pa.large_utf8()) + + +def test_prepare_unknown_category_raises(): + """Valid JSON with 'category' but no matching factory raises ValueError.""" + import json + registry = LogicalTypeRegistry() + arrow_name = _unique_name() + unknown = json.dumps({"category": "NoSuchFactory"}).encode() + + with pytest.raises(ValueError, match="NoSuchFactory"): + registry.prepare_extension_type(arrow_name, unknown, pa.large_utf8()) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +``` +uv run pytest tests/test_extension_types/test_registry.py -v -k "prepare" --no-header 2>&1 | tail -20 +``` + +Expected: FAIL — `AttributeError: 'LogicalTypeRegistry' object has no attribute 'prepare_extension_type'` + +- [ ] **Step 3: Implement `prepare_extension_type` in registry.py** + +Add this method to `LogicalTypeRegistry` (after `register_logical_type_factory`): + +```python + def prepare_extension_type( + self, + arrow_extension_name: str, + extension_metadata: bytes | None, + storage_type: pa.DataType, + ) -> None: + """Ensure the Arrow extension type identified by ``arrow_extension_name`` + is registered as a ``LogicalType``. + + This is the single entry point called by ``ensure_extensions_registered`` + in ``database_hooks``. The registry owns all dispatch logic. + + Args: + arrow_extension_name: Arrow extension type name (``ARROW:extension:name``). + extension_metadata: Raw metadata bytes (``ARROW:extension:metadata``), + expected to be UTF-8 JSON containing at least a ``"category"`` key. + ``None`` if absent. + storage_type: Underlying Arrow storage type for this extension field. + + Raises: + ValueError: If ``extension_metadata`` is ``None`` and the type is not + already registered. + ValueError: If ``extension_metadata`` is not valid UTF-8 JSON. + ValueError: If the parsed JSON has no ``"category"`` key. + ValueError: If no factory is registered for the ``"category"`` value. + ValueError: Propagated from the factory if it cannot construct a type. + """ + # Step 1: per-process cache hit — no-op regardless of metadata content. + if self.get_by_arrow_extension_name(arrow_extension_name) is not None: + logger.debug( + "prepare_extension_type: %r already registered, skipping", + arrow_extension_name, + ) + return + + # Step 2: None metadata — cannot auto-register; must be pre-registered. + if extension_metadata is None: + raise ValueError( + f"Extension type {arrow_extension_name!r} has no extension metadata " + f"(metadata is None).\n" + f"Types without a metadata category tag cannot be auto-registered via " + f"a factory — they must be pre-registered explicitly via " + f"default_logical_type_registry.register(logical_type)." + ) + + # Step 3: Parse JSON. + try: + metadata_dict = json.loads(extension_metadata.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError) as exc: + raise ValueError( + f"Extension type {arrow_extension_name!r} has extension metadata that " + f"is not valid UTF-8 JSON: {extension_metadata!r}. " + f"Parse error: {exc}.\n" + f'Extension metadata must be a JSON object with at least a "category" ' + f'key, e.g. {{"category": "Dataclass"}}.' + ) from exc + + # Step 4: Require "category" key. + if "category" not in metadata_dict: + raise ValueError( + f"Extension type {arrow_extension_name!r} has extension metadata JSON " + f'with no "category" key: {metadata_dict}. Extension metadata must be ' + f'a JSON object with at least a "category" key, e.g. ' + f'{{"category": "Dataclass"}}.' + ) + + category = metadata_dict["category"] + + # Step 5: Look up factory. + factory = self._factories.get(category) + if factory is None: + raise ValueError( + f"No LogicalTypeFactory is registered for category {category!r}.\n" + f"Cannot prepare extension type {arrow_extension_name!r} for " + f"registration.\n" + f"Register a factory via " + f"default_logical_type_registry.register_logical_type_factory(\n" + f" {category!r}, factory\n" + f")." + ) + + # Step 6: Construct logical type via factory. + logger.debug( + "prepare_extension_type: %r not registered — dispatching to category %r factory", + arrow_extension_name, + category, + ) + logical_type = factory.create_logical_type( + arrow_extension_name, storage_type, metadata_dict + ) + + # Step 7: Register in all three bindings + PA/Polars global registries. + self.register(logical_type) + logger.debug( + "prepare_extension_type: successfully registered %r via %r factory", + arrow_extension_name, + category, + ) +``` + +- [ ] **Step 4: Run tests to verify they all pass** + +``` +uv run pytest tests/test_extension_types/test_registry.py -v --no-header 2>&1 | tail -15 +``` + +Expected: all tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/extension_types/registry.py tests/test_extension_types/test_registry.py +git commit -m "feat(extension_types): add prepare_extension_type to LogicalTypeRegistry" +``` + +--- + +## Task 5: `database_hooks.py` module + `__init__.py` exports + test suite + +**Files:** +- Create: `src/orcapod/extension_types/database_hooks.py` +- Modify: `src/orcapod/extension_types/__init__.py` +- Create: `tests/test_extension_types/test_database_hooks.py` + +- [ ] **Step 1: Write the failing test file** + +Create `tests/test_extension_types/test_database_hooks.py`: + +```python +"""Tests for ensure_extensions_registered in database_hooks.""" + +from __future__ import annotations + +import json +import uuid + +import pyarrow as pa +import pytest + +from orcapod.extension_types.registry import LogicalTypeRegistry, make_arrow_extension_type + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _unique_name() -> str: + """Unique Arrow extension name to avoid cross-test global-registry collisions.""" + return f"test.hooks.{uuid.uuid4().hex[:8]}" + + +def _make_ext_schema( + arrow_name: str, + metadata: bytes | None = None, + storage: pa.DataType | None = None, +) -> pa.Schema: + """Build a ``pa.Schema`` with one extension-typed field using ``make_arrow_extension_type``. + + Only call this when you have control over the metadata content — the resulting + field's type is an in-memory ``pa.ExtensionType`` instance, not raw field metadata. + """ + _storage = storage or pa.large_utf8() + ext_cls = make_arrow_extension_type(arrow_name, _storage, metadata=metadata) + return pa.schema([pa.field("col", ext_cls())]) + + +def _make_field_metadata_schema( + arrow_name: str, + metadata: bytes, + storage: pa.DataType | None = None, +) -> pa.Schema: + """Build a schema where the extension is described by raw Arrow field metadata. + + This simulates a Parquet/IPC read where the extension type was not registered + in the current process, so ``field.type`` is a plain Arrow storage type rather + than a ``pa.ExtensionType`` instance. + """ + _storage = storage or pa.large_utf8() + field = pa.field("col", _storage).with_metadata({ + b"ARROW:extension:name": arrow_name.encode(), + b"ARROW:extension:metadata": metadata, + }) + return pa.schema([field]) + + +def _make_stub_factory(registry: LogicalTypeRegistry): + """Return a minimal LogicalTypeFactory stub whose calls are recorded. + + The factory auto-creates a fresh ``LogicalType`` stub keyed by arrow name. + Registering this factory in *registry* causes it to also register a Polars + extension type, which requires the Arrow ext type to be in PyArrow's global + registry. To avoid cross-test collisions, each test uses a unique arrow name. + """ + class _Factory: + def __init__(self): + self.calls: list[tuple] = [] + + def create_logical_type(self, arrow_extension_name, storage_type, metadata): + import polars as pl + from orcapod.extension_types.registry import make_arrow_extension_type + + self.calls.append((arrow_extension_name, storage_type, metadata)) + + _name = arrow_extension_name + _arrow_cls = make_arrow_extension_type(_name, storage_type) + _pl_storage = pl.from_arrow(pa.array([], type=storage_type)).dtype + + class _PolarsExt(pl.BaseExtension): + def __init__(self): + super().__init__(_name, _pl_storage, None) + @classmethod + def ext_from_params(cls, ext_name, storage_dtype, metadata_str): + return cls() + + class _StubLT: + @property + def logical_type_name(self): + return _name + @property + def python_type(self): + return str + def get_arrow_extension_type(self): + return _arrow_cls() + def get_polars_extension_type(self): + return _PolarsExt() + def python_to_storage(self, value): + return str(value) + def storage_to_python(self, storage_value): + return storage_value + + return _StubLT() + + return _Factory() + + +# --------------------------------------------------------------------------- +# Fixture +# --------------------------------------------------------------------------- + +@pytest.fixture +def fresh_registry(monkeypatch): + """A fresh LogicalTypeRegistry monkeypatched into database_hooks module.""" + import orcapod.extension_types.database_hooks as hooks + registry = LogicalTypeRegistry() + monkeypatch.setattr(hooks, "default_logical_type_registry", registry) + return registry + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +def test_no_extension_types_is_noop(fresh_registry): + """Schema with only primitives — ensure_extensions_registered returns without touching registry.""" + from orcapod.extension_types.database_hooks import ensure_extensions_registered + + schema = pa.schema([ + pa.field("id", pa.int64()), + pa.field("name", pa.large_utf8()), + ]) + ensure_extensions_registered(schema) + # fresh_registry is empty — no error means no spurious lookup was triggered + assert fresh_registry.get_by_arrow_extension_name("anything") is None + + +def test_known_type_is_registered(fresh_registry): + """Schema with one extension type whose factory is registered — type is registered after call.""" + from orcapod.extension_types.database_hooks import ensure_extensions_registered + + arrow_name = _unique_name() + factory = _make_stub_factory(fresh_registry) + fresh_registry.register_logical_type_factory("TestCat", factory) + + metadata_bytes = json.dumps({"category": "TestCat"}).encode() + schema = _make_ext_schema(arrow_name, metadata=metadata_bytes) + + ensure_extensions_registered(schema) + + assert fresh_registry.get_by_arrow_extension_name(arrow_name) is not None + assert len(factory.calls) == 1 + + +def test_already_registered_is_skipped(fresh_registry): + """Calling ensure_extensions_registered twice does not raise and factory is called once.""" + from orcapod.extension_types.database_hooks import ensure_extensions_registered + + arrow_name = _unique_name() + factory = _make_stub_factory(fresh_registry) + fresh_registry.register_logical_type_factory("TestCat", factory) + + metadata_bytes = json.dumps({"category": "TestCat"}).encode() + schema = _make_ext_schema(arrow_name, metadata=metadata_bytes) + + ensure_extensions_registered(schema) + ensure_extensions_registered(schema) # second call + + assert len(factory.calls) == 1 # factory invoked exactly once + + +def test_none_metadata_already_registered_noop(fresh_registry): + """Extension type with None metadata that IS already in the registry — silent no-op.""" + from orcapod.extension_types.database_hooks import ensure_extensions_registered + + arrow_name = _unique_name() + factory = _make_stub_factory(fresh_registry) + fresh_registry.register_logical_type_factory("TestCat", factory) + + # First: register via metadata so it ends up in the registry. + metadata_bytes = json.dumps({"category": "TestCat"}).encode() + schema_with_meta = _make_ext_schema(arrow_name, metadata=metadata_bytes) + ensure_extensions_registered(schema_with_meta) + + # Now: same arrow name but with no metadata (simulates reading the schema without + # metadata — e.g. after an IPC round-trip where the type is now registered in-process). + schema_no_meta = _make_ext_schema(arrow_name, metadata=None) # metadata=None → b"" + ensure_extensions_registered(schema_no_meta) # should NOT raise + + +def test_none_metadata_not_registered_raises(fresh_registry): + """Unregistered extension type with None metadata raises ValueError.""" + from orcapod.extension_types.database_hooks import ensure_extensions_registered + + arrow_name = _unique_name() + schema = _make_ext_schema(arrow_name, metadata=None) # metadata=None → b"" → walker normalizes to None + + with pytest.raises(ValueError, match="must be pre-registered explicitly"): + ensure_extensions_registered(schema) + + +def test_metadata_not_json_raises(fresh_registry): + """Unregistered extension type with non-JSON metadata bytes raises ValueError.""" + from orcapod.extension_types.database_hooks import ensure_extensions_registered + + arrow_name = _unique_name() + schema = _make_field_metadata_schema(arrow_name, metadata=b"not-json!") + + with pytest.raises(ValueError, match="not valid UTF-8 JSON"): + ensure_extensions_registered(schema) + + +def test_metadata_json_missing_category_raises(fresh_registry): + """Unregistered extension type with valid JSON but no 'category' key raises ValueError.""" + from orcapod.extension_types.database_hooks import ensure_extensions_registered + + arrow_name = _unique_name() + schema = _make_field_metadata_schema( + arrow_name, metadata=json.dumps({"version": 1}).encode() + ) + + with pytest.raises(ValueError, match='"category"'): + ensure_extensions_registered(schema) + + +def test_unknown_metadata_raises(fresh_registry): + """Unregistered extension type with valid JSON and 'category' but no matching factory raises ValueError.""" + from orcapod.extension_types.database_hooks import ensure_extensions_registered + + arrow_name = _unique_name() + schema = _make_field_metadata_schema( + arrow_name, metadata=json.dumps({"category": "NoSuchFactory"}).encode() + ) + + with pytest.raises(ValueError, match="NoSuchFactory"): + ensure_extensions_registered(schema) + + +def test_nested_extension_type(fresh_registry): + """Extension type inside a struct column is discovered and registered.""" + from orcapod.extension_types.database_hooks import ensure_extensions_registered + + arrow_name = _unique_name() + factory = _make_stub_factory(fresh_registry) + fresh_registry.register_logical_type_factory("TestCat", factory) + + metadata_bytes = json.dumps({"category": "TestCat"}).encode() + inner_ext_cls = make_arrow_extension_type(arrow_name, pa.large_utf8(), metadata=metadata_bytes) + + struct_type = pa.struct([pa.field("inner", inner_ext_cls())]) + schema = pa.schema([pa.field("outer", struct_type)]) + + ensure_extensions_registered(schema) + + assert fresh_registry.get_by_arrow_extension_name(arrow_name) is not None + assert len(factory.calls) == 1 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +``` +uv run pytest tests/test_extension_types/test_database_hooks.py -v --no-header 2>&1 | tail -15 +``` + +Expected: FAIL — `ModuleNotFoundError: No module named 'orcapod.extension_types.database_hooks'` + +- [ ] **Step 3: Create `database_hooks.py`** + +Create `src/orcapod/extension_types/database_hooks.py`: + +```python +"""Peek-schema hook for extension type auto-registration at database read time. + +Call ``ensure_extensions_registered(schema)`` before returning any Arrow table +from a database read path. It is a no-op when the schema contains no extension +types. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from orcapod.extension_types.registry import default_logical_type_registry +from orcapod.extension_types.schema_walker import walk_schema + +if TYPE_CHECKING: + import pyarrow as pa + +logger = logging.getLogger(__name__) + + +def ensure_extensions_registered(schema: pa.Schema) -> None: + """Register any extension types found in ``schema`` that are not yet known. + + Walks ``schema`` recursively to discover all Arrow extension types at any + nesting depth. For each discovered type, delegates to + ``default_logical_type_registry.prepare_extension_type``. + + Already-registered types are detected and skipped inside the registry — + this function itself is stateless. + + Args: + schema: The Arrow schema to inspect. May contain no extension types, + in which case this call is a no-op. + + Raises: + ValueError: Propagated from the registry if an extension type's metadata + has no registered factory or is malformed. + """ + found = walk_schema(schema) + if not found: + logger.debug("ensure_extensions_registered: no extension types in schema") + return + logger.debug( + "ensure_extensions_registered: found %d extension type(s) in schema: %s", + len(found), + [info.extension_name for info in found], + ) + for info in found: + default_logical_type_registry.prepare_extension_type( + info.extension_name, + info.extension_metadata, + info.storage_type, + ) +``` + +- [ ] **Step 4: Add `ensure_extensions_registered` to `__init__.py` exports** + +In `src/orcapod/extension_types/__init__.py`, add the import and export: + +```python +from .database_hooks import ensure_extensions_registered +``` + +Add `"ensure_extensions_registered"` to `__all__`. + +The final `__init__.py` should look like: + +```python +"""Arrow/Polars extension type system for orcapod. + +This subpackage provides the registry and protocol for logical types that map +between Python objects and their Arrow/Polars extension type representation. + +The module-level ``default_logical_type_registry`` instance is the process default. +Built-in registrations (``Path``, ``UPath``, ``UUID``) are added by PLT-1656. +``DataContext`` wiring is added by PLT-1660. +""" + +from __future__ import annotations + +from .protocols import LogicalType, LogicalTypeFactory +from .registry import LogicalTypeRegistry, make_arrow_extension_type, default_logical_type_registry +from .schema_walker import ExtensionTypeInfo, walk_field, walk_schema +from .database_hooks import ensure_extensions_registered + +__all__ = [ + "LogicalType", + "LogicalTypeFactory", + "LogicalTypeRegistry", + "make_arrow_extension_type", + "default_logical_type_registry", + # PLT-1654 + "ExtensionTypeInfo", + "walk_schema", + "walk_field", + # PLT-1655 + "ensure_extensions_registered", +] +``` + +- [ ] **Step 5: Run all tests to verify they pass** + +``` +uv run pytest tests/test_extension_types/ -v --no-header 2>&1 | tail -20 +``` + +Expected: all tests PASS + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/extension_types/database_hooks.py src/orcapod/extension_types/__init__.py tests/test_extension_types/test_database_hooks.py +git commit -m "feat(extension_types): add database_hooks.ensure_extensions_registered and update exports" +``` + +--- + +## Task 6: Hook `DeltaTableDatabase._read_delta_table` + +**Context:** `delta_lake_databases.py` already has `import logging` and `logger = logging.getLogger(__name__)`. Only a new import and a single hook call are needed. + +**Files:** +- Modify: `src/orcapod/databases/delta_lake_databases.py` + +- [ ] **Step 1: Add the import for `ensure_extensions_registered`** + +In `src/orcapod/databases/delta_lake_databases.py`, find the existing imports block. The file starts with: + +```python +from __future__ import annotations + +import logging +from collections import defaultdict +from collections.abc import Collection, Mapping +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal, cast + +from orcapod.databases.utils import coerce_record_id +from orcapod.databases.storage_utils import is_cloud_uri, parse_base_path +from orcapod.utils import arrow_utils +from orcapod.utils.lazy_module import LazyModule +``` + +Add the new import after the existing `orcapod` imports: + +```python +from orcapod.extension_types.database_hooks import ensure_extensions_registered +``` + +- [ ] **Step 2: Add the hook call in `_read_delta_table`** + +Find `_read_delta_table` (around line 818). The current code after the method docstring is: + +```python + filter_expr = None + # Use to_pyarrow_dataset with as_large_types for Polars compatible arrow table loading + dataset = delta_table.to_pyarrow_dataset(as_large_types=True) + if filters and expression is None: +``` + +Replace with (adding 2 lines after the dataset assignment): + +```python + filter_expr = None + # Use to_pyarrow_dataset with as_large_types for Polars compatible arrow table loading + dataset = delta_table.to_pyarrow_dataset(as_large_types=True) + logger.debug("_read_delta_table: peeking schema for extension type registration") + ensure_extensions_registered(delta_table.schema().to_arrow()) + if filters and expression is None: +``` + +- [ ] **Step 3: Run the full test suite** + +``` +uv run pytest tests/ -v --no-header -q 2>&1 | tail -20 +``` + +Expected: all tests PASS + +- [ ] **Step 4: Commit** + +```bash +git add src/orcapod/databases/delta_lake_databases.py +git commit -m "feat(databases): call ensure_extensions_registered in DeltaTableDatabase._read_delta_table" +``` + +--- + +## Task 7: Hook `ConnectorArrowDatabase._get_committed_table` + +**Context:** `connector_arrow_database.py` currently has no `logger`. Add it alongside the hook import. + +**Files:** +- Modify: `src/orcapod/databases/connector_arrow_database.py` + +- [ ] **Step 1: Add `import logging`, `logger`, and hook import** + +In `src/orcapod/databases/connector_arrow_database.py`, the current imports block begins: + +```python +from __future__ import annotations + +import re +from collections import defaultdict +from collections.abc import Collection, Mapping +from typing import TYPE_CHECKING, Any, cast + +from orcapod.databases.utils import coerce_record_id +from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol +from orcapod.utils.lazy_module import LazyModule +``` + +Replace with: + +```python +from __future__ import annotations + +import logging +import re +from collections import defaultdict +from collections.abc import Collection, Mapping +from typing import TYPE_CHECKING, Any, cast + +from orcapod.databases.utils import coerce_record_id +from orcapod.extension_types.database_hooks import ensure_extensions_registered +from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol +from orcapod.utils.lazy_module import LazyModule + +logger = logging.getLogger(__name__) +``` + +- [ ] **Step 2: Add the hook call in `_get_committed_table`** + +Find `_get_committed_table` (around line 176). The current implementation is: + +```python + def _get_committed_table( + self, record_path: tuple[str, ...] + ) -> pa.Table | None: + """Fetch all committed records for a path from the connector.""" + table_name = self._path_to_table_name(self._path_prefix + record_path) + if table_name not in self._connector.get_table_names(): + return None + batches = list( + self._connector.iter_batches(f'SELECT * FROM "{table_name}"') + ) + if not batches: + return None + return pa.Table.from_batches(batches) +``` + +Replace with: + +```python + def _get_committed_table( + self, record_path: tuple[str, ...] + ) -> pa.Table | None: + """Fetch all committed records for a path from the connector.""" + table_name = self._path_to_table_name(self._path_prefix + record_path) + if table_name not in self._connector.get_table_names(): + return None + batches = list( + self._connector.iter_batches(f'SELECT * FROM "{table_name}"') + ) + if not batches: + return None + logger.debug("_get_committed_table: peeking schema for extension type registration") + ensure_extensions_registered(batches[0].schema) + return pa.Table.from_batches(batches) +``` + +- [ ] **Step 3: Run the full test suite** + +``` +uv run pytest tests/ -v --no-header -q 2>&1 | tail -20 +``` + +Expected: all tests PASS + +- [ ] **Step 4: Commit** + +```bash +git add src/orcapod/databases/connector_arrow_database.py +git commit -m "feat(databases): add logger and ensure_extensions_registered hook to ConnectorArrowDatabase._get_committed_table" +``` + +--- + +## Final Verification + +- [ ] **Run the complete test suite one final time** + +``` +uv run pytest tests/ -q --no-header 2>&1 | tail -5 +``` + +Expected: all tests PASS, no warnings about new code + +- [ ] **Create PR targeting `extension-type-system` branch** + +```bash +gh pr create \ + --base extension-type-system \ + --title "feat(PLT-1655): add peek-schema → register → read pattern with per-process cache" \ + --body "$(cat <<'EOF' +## Summary + +* Adds `LogicalTypeFactory` Protocol — a pure factory that constructs a `LogicalType` from an Arrow extension name, storage type, and full parsed JSON metadata dict. +* Adds `register_logical_type_factory(category, factory)` and `prepare_extension_type(arrow_extension_name, metadata, storage_type)` to `LogicalTypeRegistry`. The registry's `_by_arrow_name` dict acts as the per-process cache (step 1: already-registered → immediate no-op regardless of metadata). +* Adds stateless `ensure_extensions_registered(schema)` in `extension_types/database_hooks.py`. Walks the schema, delegates each extension type to `prepare_extension_type`. +* Wires the hook into `DeltaTableDatabase._read_delta_table` (schema peek via `DeltaTable.schema().to_arrow()`) and `ConnectorArrowDatabase._get_committed_table` (schema peek via `batches[0].schema`). +* Moves `default_logical_type_registry` singleton from `__init__.py` to `registry.py` to break the circular import that would arise with `database_hooks`. +* Sufficient DEBUG-level logging throughout: discovery, cache hit, factory dispatch, successful registration. + +## Test plan + +- [ ] `uv run pytest tests/test_extension_types/ -v` — all new unit tests pass +- [ ] `uv run pytest tests/ -q` — no regressions + +Fixes PLT-1655 +EOF +)" +``` diff --git a/superpowers/plans/2026-06-14-plt-1656-builtin-logical-types.md b/superpowers/plans/2026-06-14-plt-1656-builtin-logical-types.md new file mode 100644 index 00000000..09c01167 --- /dev/null +++ b/superpowers/plans/2026-06-14-plt-1656-builtin-logical-types.md @@ -0,0 +1,1511 @@ +# PLT-1656: Built-in LogicalType Implementations (Path, UPath, UUID) — Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use sensei:subagent-driven-development (recommended) or sensei:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Implement three built-in `LogicalType` classes (`LogicalPath`, `LogicalUPath`, `LogicalUUID`), wire them into `DataContext` via `v0.1.json`, and expose a `get_default_logical_type_registry()` convenience accessor. + +**Architecture:** Each `LogicalType` owns its Arrow/Polars extension type instances via class-level caching. A new `make_polars_extension_type` helper (parallel to the existing `make_arrow_extension_type`) synthesises `pl.BaseExtension` subclasses at runtime. The registry is populated via the existing `parse_objectspec` JSON object spec mechanism so `LogicalTypeRegistry` gains a `logical_types` constructor param. The module-level `default_logical_type_registry` in `extension_types/__init__.py` is removed — the canonical access path becomes `get_default_context().logical_type_registry`. + +**Tech Stack:** Python 3.12+, PyArrow ≥ 20, Polars ≥ 1.36.0, pytest, uv. + +--- + +## File Map + +| File | Action | Responsibility | +|---|---|---| +| `src/orcapod/extension_types/registry.py` | Modify | Add `make_polars_extension_type` helper; add `logical_types` param to `LogicalTypeRegistry.__init__` | +| `src/orcapod/extension_types/__init__.py` | Modify | Export `make_polars_extension_type`; remove `default_logical_type_registry` | +| `src/orcapod/extension_types/builtin_logical_types.py` | **New** | `LogicalPath`, `LogicalUPath`, `LogicalUUID` implementations | +| `src/orcapod/contexts/core.py` | Modify | Add `logical_type_registry: LogicalTypeRegistry` field to `DataContext` | +| `src/orcapod/contexts/registry.py` | Modify | Add `"logical_type_registry"` to required fields; pass it through in `_create_context_from_spec` | +| `src/orcapod/contexts/data/v0.1.json` | Modify | Add `logical_type_registry` object spec entry | +| `src/orcapod/contexts/data/schemas/context_schema.json` | Modify | Add `logical_type_registry` to `required` and `properties` | +| `src/orcapod/contexts/__init__.py` | Modify | Add `get_default_logical_type_registry()` convenience function | +| `tests/test_extension_types/test_registry.py` | Modify | Add tests for `make_polars_extension_type` and `logical_types` param; remove stale `default_logical_type_registry` tests | +| `tests/test_extension_types/test_builtin_logical_types.py` | **New** | Protocol conformance, property values, round-trips, default-context integration tests | + +--- + +### Task 1: `make_polars_extension_type` helper + +**Files:** +- Modify: `src/orcapod/extension_types/registry.py` +- Modify: `src/orcapod/extension_types/__init__.py` +- Modify: `tests/test_extension_types/test_registry.py` + +- [ ] **Step 1: Write the failing tests** + +Add these tests at the end of `tests/test_extension_types/test_registry.py`, before the `# default_logical_type_registry tests` section: + +```python +# --------------------------------------------------------------------------- +# make_polars_extension_type tests +# --------------------------------------------------------------------------- + +from orcapod.extension_types.registry import make_polars_extension_type + + +def test_make_polars_extension_type_returns_class(): + """make_polars_extension_type returns a pl.BaseExtension subclass.""" + cls = make_polars_extension_type("test.MakePolarsExt", pa.large_utf8()) + assert issubclass(cls, pl.BaseExtension) + + +def test_make_polars_extension_type_instance_has_correct_name(): + """Instantiating the returned class yields the correct ext_name.""" + name = _unique_name() + cls = make_polars_extension_type(name, pa.large_utf8()) + inst = cls() + assert inst.ext_name() == name + + +def test_make_polars_extension_type_ext_from_params_returns_instance(): + """ext_from_params classmethod returns an instance of the class.""" + name = _unique_name() + cls = make_polars_extension_type(name, pa.large_utf8()) + inst = cls.ext_from_params(name, pl.String, None) + assert isinstance(inst, cls) + + +def test_make_polars_extension_type_with_binary_storage(): + """make_polars_extension_type works with pa.binary(16) storage (UUID case).""" + name = _unique_name() + cls = make_polars_extension_type(name, pa.binary(16), None) + inst = cls() + assert inst.ext_name() == name + + +def test_make_polars_extension_type_with_metadata(): + """make_polars_extension_type captures metadata in the class.""" + name = _unique_name() + cls = make_polars_extension_type(name, pa.large_utf8(), "test.metadata") + # Instantiating should not raise; ext_name is correct. + inst = cls() + assert inst.ext_name() == name +``` + +- [ ] **Step 2: Run tests to verify they fail** + +```bash +cd /home/kurouto/kurouto-jobs/fccdf92d-a25e-4477-ae00-a1ee2b6dc236/orcapod-python +uv run pytest tests/test_extension_types/test_registry.py::test_make_polars_extension_type_returns_class -v +``` + +Expected: `ImportError` — `make_polars_extension_type` does not exist yet. + +- [ ] **Step 3: Implement `make_polars_extension_type` in `registry.py`** + +Add after `make_arrow_extension_type` (around line 98), before the `LogicalTypeRegistry` class: + +```python +def make_polars_extension_type( + extension_name: str, + arrow_storage_type: pa.DataType, + metadata: str | None = None, +) -> type[pl.BaseExtension]: + """Synthesise and return a ``pl.BaseExtension`` subclass. + + Derives the Polars storage dtype from *arrow_storage_type* via + ``pl.from_arrow``. Returns the *class*; callers instantiate it inside + ``get_polars_extension_type()``. + + The returned class uses the Arrow extension name as its registration name + (the same name passed to ``pl.register_extension_type``), so that Polars + correctly maps Arrow extension columns on read. + + Args: + extension_name: The extension type name used for Polars registration. + Must match the Arrow extension name so Polars can round-trip the + type through Arrow IPC. + arrow_storage_type: The Arrow storage type. Converted once to the + corresponding Polars dtype via ``pl.from_arrow``. + metadata: Optional metadata string stored as ``metadata_str`` in the + Polars extension. Defaults to ``None``. + + Returns: + A ``pl.BaseExtension`` subclass. Call it with no arguments to obtain + an instance suitable for passing to ``pl.register_extension_type`` or + returning from ``get_polars_extension_type()``. + """ + _name = extension_name + _polars_dtype = pl.from_arrow(pa.array([], type=arrow_storage_type)).dtype + _metadata = metadata + + def __init__(self: pl.BaseExtension) -> None: + pl.BaseExtension.__init__(self, _name, _polars_dtype, _metadata) + + @classmethod # type: ignore[misc] + def ext_from_params( + cls: type[pl.BaseExtension], + ext_name: str, + storage_dtype: pl.PolarsDataType, + metadata_str: str | None, + ) -> pl.BaseExtension: + return cls() + + return type( + f"_PolarsExt_{_sanitize(extension_name)}", + (pl.BaseExtension,), + { + "__init__": __init__, + "ext_from_params": ext_from_params, + }, + ) +``` + +- [ ] **Step 4: Export `make_polars_extension_type` from `extension_types/__init__.py`** + +In `src/orcapod/extension_types/__init__.py`, update the import line and `__all__`: + +```python +from .registry import LogicalTypeRegistry, make_arrow_extension_type, make_polars_extension_type +``` + +And add `"make_polars_extension_type"` to `__all__`: + +```python +__all__ = [ + "LogicalType", + "LogicalTypeRegistry", + "make_arrow_extension_type", + "make_polars_extension_type", + "default_logical_type_registry", + # PLT-1654 + "ExtensionTypeInfo", + "walk_schema", + "walk_field", +] +``` + +- [ ] **Step 5: Run tests to verify they pass** + +```bash +uv run pytest tests/test_extension_types/test_registry.py -k "polars_extension_type" -v +``` + +Expected: All 5 new tests PASS. + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/extension_types/registry.py \ + src/orcapod/extension_types/__init__.py \ + tests/test_extension_types/test_registry.py +git commit -m "feat(extension_types): add make_polars_extension_type helper" +``` + +--- + +### Task 2: `LogicalTypeRegistry` `logical_types` constructor param + +**Files:** +- Modify: `src/orcapod/extension_types/registry.py` +- Modify: `tests/test_extension_types/test_registry.py` + +- [ ] **Step 1: Write the failing tests** + +Add after the existing `test_get_by_arrow_extension_name_miss` test, before the PyArrow global registry tests section: + +```python +# --------------------------------------------------------------------------- +# LogicalTypeRegistry constructor logical_types param tests +# --------------------------------------------------------------------------- + +def test_registry_init_with_logical_types_preregisters(): + """LogicalTypeRegistry(logical_types=[lt]) makes the type immediately retrievable.""" + lt = _make_stub() + registry = LogicalTypeRegistry(logical_types=[lt]) + assert registry.get_by_logical_name(lt.logical_type_name) is lt + assert registry.get_by_python_type(lt.python_type) is lt + assert registry.get_by_arrow_extension_name(lt.get_arrow_extension_type().extension_name) is lt + + +def test_registry_init_with_none_is_empty(): + """LogicalTypeRegistry(logical_types=None) starts empty without error.""" + registry = LogicalTypeRegistry(logical_types=None) + assert registry.get_by_logical_name("anything") is None + + +def test_registry_init_with_empty_list_is_empty(): + """LogicalTypeRegistry(logical_types=[]) starts empty without error.""" + registry = LogicalTypeRegistry(logical_types=[]) + assert registry.get_by_logical_name("anything") is None + + +def test_registry_init_with_multiple_logical_types(): + """LogicalTypeRegistry(logical_types=[lt1, lt2]) registers both.""" + lt1 = _make_stub(py_type=int) + lt2 = _make_stub(py_type=float) + registry = LogicalTypeRegistry(logical_types=[lt1, lt2]) + assert registry.get_by_logical_name(lt1.logical_type_name) is lt1 + assert registry.get_by_logical_name(lt2.logical_type_name) is lt2 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +```bash +uv run pytest tests/test_extension_types/test_registry.py::test_registry_init_with_logical_types_preregisters -v +``` + +Expected: FAIL — `LogicalTypeRegistry.__init__` does not accept `logical_types` argument. + +- [ ] **Step 3: Update `LogicalTypeRegistry.__init__` in `registry.py`** + +Replace the current `__init__` method (lines 121–124): + +```python +# OLD +def __init__(self) -> None: + self._by_logical_name: dict[str, LogicalType] = {} + self._by_arrow_name: dict[str, LogicalType] = {} + self._by_python_type: dict[type, LogicalType] = {} +``` + +With: + +```python +def __init__(self, logical_types: list[LogicalType] | None = None) -> None: + self._by_logical_name: dict[str, LogicalType] = {} + self._by_arrow_name: dict[str, LogicalType] = {} + self._by_python_type: dict[type, LogicalType] = {} + for lt in (logical_types or []): + self.register(lt) +``` + +- [ ] **Step 4: Run tests to verify they pass** + +```bash +uv run pytest tests/test_extension_types/test_registry.py -k "registry_init" -v +``` + +Expected: All 4 new tests PASS. Also run the full registry suite to confirm no regressions: + +```bash +uv run pytest tests/test_extension_types/test_registry.py -v +``` + +Expected: All tests PASS (the last 6 `default_logical_type_registry` tests still reference the old module-level instance and will continue passing for now — they are removed in Task 6). + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/extension_types/registry.py \ + tests/test_extension_types/test_registry.py +git commit -m "feat(extension_types): add logical_types constructor param to LogicalTypeRegistry" +``` + +--- + +### Task 3: `LogicalPath` and `LogicalUPath` implementations + +**Files:** +- Create: `src/orcapod/extension_types/builtin_logical_types.py` +- Create: `tests/test_extension_types/test_builtin_logical_types.py` + +- [ ] **Step 1: Create the test file with failing tests for `LogicalPath` and `LogicalUPath`** + +Create `tests/test_extension_types/test_builtin_logical_types.py`: + +```python +"""Tests for built-in LogicalType implementations (LogicalPath, LogicalUPath, LogicalUUID).""" + +from __future__ import annotations + +import pathlib +import uuid as uuid_module +import warnings + +import polars as pl +import pyarrow as pa +import pytest +from upath import UPath + +from orcapod.extension_types.protocols import LogicalType +from orcapod.extension_types.registry import LogicalTypeRegistry + + +# --------------------------------------------------------------------------- +# LogicalPath tests +# --------------------------------------------------------------------------- + + +def test_logical_path_isinstance_logical_type(): + """LogicalPath() satisfies the LogicalType runtime-checkable protocol.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + + assert isinstance(LogicalPath(), LogicalType) + + +def test_logical_path_logical_type_name(): + from orcapod.extension_types.builtin_logical_types import LogicalPath + + assert LogicalPath().logical_type_name == "pathlib.Path" + + +def test_logical_path_python_type(): + from orcapod.extension_types.builtin_logical_types import LogicalPath + + assert LogicalPath().python_type is pathlib.Path + + +def test_logical_path_arrow_ext_name(): + """get_arrow_extension_type().extension_name is 'pathlib.Path'.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + + assert LogicalPath().get_arrow_extension_type().extension_name == "pathlib.Path" + + +def test_logical_path_arrow_ext_storage_type(): + """Arrow extension storage type is pa.large_string().""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + + assert LogicalPath().get_arrow_extension_type().storage_type == pa.large_string() + + +def test_logical_path_get_arrow_extension_type_is_cached(): + """get_arrow_extension_type() returns the same object on repeated calls.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + + lt = LogicalPath() + assert lt.get_arrow_extension_type() is lt.get_arrow_extension_type() + + +def test_logical_path_get_polars_extension_type_is_cached(): + """get_polars_extension_type() returns the same object on repeated calls.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + + lt = LogicalPath() + assert lt.get_polars_extension_type() is lt.get_polars_extension_type() + + +def test_logical_path_round_trip(): + """Path -> python_to_storage -> storage_to_python -> Path is identity.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + + lt = LogicalPath() + p = pathlib.Path("/tmp/foo/bar.txt") + assert lt.storage_to_python(lt.python_to_storage(p)) == p + + +def test_logical_path_python_to_storage_returns_string(): + from orcapod.extension_types.builtin_logical_types import LogicalPath + + lt = LogicalPath() + result = lt.python_to_storage(pathlib.Path("/tmp/test")) + assert isinstance(result, str) + assert result == "/tmp/test" + + +# --------------------------------------------------------------------------- +# LogicalUPath tests +# --------------------------------------------------------------------------- + + +def test_logical_upath_isinstance_logical_type(): + """LogicalUPath() satisfies the LogicalType runtime-checkable protocol.""" + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + assert isinstance(LogicalUPath(), LogicalType) + + +def test_logical_upath_logical_type_name(): + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + assert LogicalUPath().logical_type_name == "upath.UPath" + + +def test_logical_upath_python_type(): + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + assert LogicalUPath().python_type is UPath + + +def test_logical_upath_arrow_ext_name(): + """get_arrow_extension_type().extension_name is 'upath.UPath'.""" + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + assert LogicalUPath().get_arrow_extension_type().extension_name == "upath.UPath" + + +def test_logical_upath_arrow_ext_storage_type(): + """Arrow extension storage type is pa.large_string().""" + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + assert LogicalUPath().get_arrow_extension_type().storage_type == pa.large_string() + + +def test_logical_upath_get_arrow_extension_type_is_cached(): + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + lt = LogicalUPath() + assert lt.get_arrow_extension_type() is lt.get_arrow_extension_type() + + +def test_logical_upath_round_trip(): + """UPath -> python_to_storage -> storage_to_python -> UPath is identity.""" + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + lt = LogicalUPath() + up = UPath("s3://bucket/key/file.txt") + assert lt.storage_to_python(lt.python_to_storage(up)) == up + + +def test_logical_upath_python_to_storage_returns_string(): + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + lt = LogicalUPath() + result = lt.python_to_storage(UPath("s3://bucket/key")) + assert isinstance(result, str) + assert result == "s3://bucket/key" +``` + +- [ ] **Step 2: Run tests to verify they fail** + +```bash +uv run pytest tests/test_extension_types/test_builtin_logical_types.py::test_logical_path_isinstance_logical_type -v +``` + +Expected: `ModuleNotFoundError` — `builtin_logical_types` does not exist yet. + +- [ ] **Step 3: Create `src/orcapod/extension_types/builtin_logical_types.py` with `LogicalPath` and `LogicalUPath`** + +```python +"""Built-in LogicalType implementations for orcapod. + +Provides three built-in logical types registered into the default +``DataContext.logical_type_registry`` via ``contexts/data/v0.1.json``: + +- ``LogicalPath``: maps ``pathlib.Path`` ↔ Arrow large_string extension "pathlib.Path" +- ``LogicalUPath``: maps ``upath.UPath`` ↔ Arrow large_string extension "upath.UPath" +- ``LogicalUUID``: maps ``uuid.UUID`` ↔ PyArrow built-in ``pa.uuid()`` ("arrow.uuid") + +Note: + All imports from orcapod.extension_types use direct submodule paths + (e.g. ``from orcapod.extension_types.registry import ...``) rather than + the package ``__init__`` to avoid circular imports when the context system + loads this module at startup. +""" + +from __future__ import annotations + +import pathlib +import uuid as _uuid_module +from typing import TYPE_CHECKING, Any + +from upath import UPath + +from orcapod.extension_types.registry import make_arrow_extension_type, make_polars_extension_type +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import polars as pl + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + pl = LazyModule("polars") + + +class LogicalPath: + """Logical type for ``pathlib.Path``. + + Stores paths as Arrow large strings using the custom extension type + ``"pathlib.Path"`` with metadata ``b"orcapod.builtin"``. + + Example: + >>> lt = LogicalPath() + >>> lt.python_to_storage(pathlib.Path("/tmp/foo")) + '/tmp/foo' + >>> lt.storage_to_python('/tmp/foo') + PosixPath('/tmp/foo') + """ + + _arrow_ext_class = make_arrow_extension_type( + "pathlib.Path", pa.large_string(), b"orcapod.builtin" + ) + _arrow_ext: pa.ExtensionType | None = None + _polars_ext_class = make_polars_extension_type( + "pathlib.Path", pa.large_string(), "orcapod.builtin" + ) + _polars_ext: pl.BaseExtension | None = None + + logical_type_name: str = "pathlib.Path" + python_type: type = pathlib.Path + + def get_arrow_extension_type(self) -> pa.ExtensionType: + """Return the Arrow extension type for ``pathlib.Path``. + + Returns: + A cached ``pa.ExtensionType`` instance with extension name + ``"pathlib.Path"`` and storage type ``pa.large_string()``. + """ + if LogicalPath._arrow_ext is None: + LogicalPath._arrow_ext = LogicalPath._arrow_ext_class() + return LogicalPath._arrow_ext + + def get_polars_extension_type(self) -> pl.BaseExtension: + """Return the Polars extension type for ``pathlib.Path``. + + Returns: + A cached ``pl.BaseExtension`` instance registered under + ``"pathlib.Path"``. + """ + if LogicalPath._polars_ext is None: + LogicalPath._polars_ext = LogicalPath._polars_ext_class() + return LogicalPath._polars_ext + + def python_to_storage(self, value: Any) -> str: + """Convert a ``pathlib.Path`` to its string representation. + + Args: + value: A ``pathlib.Path`` instance. + + Returns: + The string form of the path (e.g. ``"/tmp/foo"``). + """ + return str(value) + + def storage_to_python(self, storage_value: Any) -> pathlib.Path: + """Reconstruct a ``pathlib.Path`` from its string representation. + + Args: + storage_value: A string path as stored in Arrow. + + Returns: + A ``pathlib.Path`` instance. + """ + return pathlib.Path(storage_value) + + +class LogicalUPath: + """Logical type for ``upath.UPath``. + + Stores paths as Arrow large strings using the custom extension type + ``"upath.UPath"`` with metadata ``b"orcapod.builtin"``. + + Example: + >>> lt = LogicalUPath() + >>> lt.python_to_storage(UPath("s3://bucket/key")) + 's3://bucket/key' + >>> lt.storage_to_python("s3://bucket/key") + UPath('s3://bucket/key') + """ + + _arrow_ext_class = make_arrow_extension_type( + "upath.UPath", pa.large_string(), b"orcapod.builtin" + ) + _arrow_ext: pa.ExtensionType | None = None + _polars_ext_class = make_polars_extension_type( + "upath.UPath", pa.large_string(), "orcapod.builtin" + ) + _polars_ext: pl.BaseExtension | None = None + + logical_type_name: str = "upath.UPath" + python_type: type = UPath + + def get_arrow_extension_type(self) -> pa.ExtensionType: + """Return the Arrow extension type for ``upath.UPath``. + + Returns: + A cached ``pa.ExtensionType`` instance with extension name + ``"upath.UPath"`` and storage type ``pa.large_string()``. + """ + if LogicalUPath._arrow_ext is None: + LogicalUPath._arrow_ext = LogicalUPath._arrow_ext_class() + return LogicalUPath._arrow_ext + + def get_polars_extension_type(self) -> pl.BaseExtension: + """Return the Polars extension type for ``upath.UPath``. + + Returns: + A cached ``pl.BaseExtension`` instance registered under + ``"upath.UPath"``. + """ + if LogicalUPath._polars_ext is None: + LogicalUPath._polars_ext = LogicalUPath._polars_ext_class() + return LogicalUPath._polars_ext + + def python_to_storage(self, value: Any) -> str: + """Convert a ``upath.UPath`` to its string representation. + + Args: + value: A ``upath.UPath`` instance. + + Returns: + The string form of the path (e.g. ``"s3://bucket/key"``). + """ + return str(value) + + def storage_to_python(self, storage_value: Any) -> UPath: + """Reconstruct a ``upath.UPath`` from its string representation. + + Args: + storage_value: A string path as stored in Arrow. + + Returns: + A ``upath.UPath`` instance. + """ + return UPath(storage_value) +``` + +- [ ] **Step 4: Run tests to verify they pass** + +```bash +uv run pytest tests/test_extension_types/test_builtin_logical_types.py -k "logical_path or logical_upath" -v +``` + +Expected: All `LogicalPath` and `LogicalUPath` tests PASS. + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/extension_types/builtin_logical_types.py \ + tests/test_extension_types/test_builtin_logical_types.py +git commit -m "feat(extension_types): implement LogicalPath and LogicalUPath" +``` + +--- + +### Task 4: `LogicalUUID` implementation + +**Files:** +- Modify: `src/orcapod/extension_types/builtin_logical_types.py` +- Modify: `tests/test_extension_types/test_builtin_logical_types.py` + +- [ ] **Step 1: Write the failing tests for `LogicalUUID`** + +Append to `tests/test_extension_types/test_builtin_logical_types.py`: + +```python +# --------------------------------------------------------------------------- +# LogicalUUID tests +# --------------------------------------------------------------------------- + + +def test_logical_uuid_isinstance_logical_type(): + """LogicalUUID() satisfies the LogicalType runtime-checkable protocol.""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + assert isinstance(LogicalUUID(), LogicalType) + + +def test_logical_uuid_logical_type_name(): + """logical_type_name is 'uuid.UUID', not the Arrow extension name.""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + assert LogicalUUID().logical_type_name == "uuid.UUID" + + +def test_logical_uuid_python_type(): + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + assert LogicalUUID().python_type is uuid_module.UUID + + +def test_logical_uuid_arrow_ext_name_is_arrow_uuid(): + """Arrow extension name is 'arrow.uuid', intentionally different from logical_type_name.""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + assert lt.get_arrow_extension_type().extension_name == "arrow.uuid" + assert lt.logical_type_name != lt.get_arrow_extension_type().extension_name + + +def test_logical_uuid_get_arrow_extension_type_returns_pa_uuid(): + """get_arrow_extension_type() returns PyArrow's built-in pa.uuid() type.""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + assert lt.get_arrow_extension_type() == pa.uuid() + + +def test_logical_uuid_get_arrow_extension_type_is_cached(): + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + assert lt.get_arrow_extension_type() is lt.get_arrow_extension_type() + + +def test_logical_uuid_get_polars_extension_type_is_cached(): + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + assert lt.get_polars_extension_type() is lt.get_polars_extension_type() + + +def test_logical_uuid_round_trip(): + """UUID -> python_to_storage -> storage_to_python -> UUID is identity.""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + u = uuid_module.uuid4() + assert lt.storage_to_python(lt.python_to_storage(u)) == u + + +def test_logical_uuid_python_to_storage_returns_bytes(): + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + u = uuid_module.UUID("12345678-1234-5678-1234-567812345678") + result = lt.python_to_storage(u) + assert isinstance(result, bytes) + assert len(result) == 16 + + +def test_logical_uuid_storage_to_python_accepts_bytes(): + """storage_to_python works when storage_value is plain bytes.""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + u = uuid_module.UUID("12345678-1234-5678-1234-567812345678") + recovered = lt.storage_to_python(u.bytes) + assert recovered == u + + +def test_logical_uuid_registration_does_not_raise(): + """Registering LogicalUUID succeeds even though pa.uuid() is already in PyArrow's registry.""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + registry = LogicalTypeRegistry() + lt = LogicalUUID() + registry.register(lt) # should NOT raise + assert registry.get_by_logical_name("uuid.UUID") is lt + assert registry.get_by_arrow_extension_name("arrow.uuid") is lt +``` + +- [ ] **Step 2: Run tests to verify they fail** + +```bash +uv run pytest tests/test_extension_types/test_builtin_logical_types.py::test_logical_uuid_isinstance_logical_type -v +``` + +Expected: `ImportError` — `LogicalUUID` does not exist yet. + +- [ ] **Step 3: Add `LogicalUUID` to `builtin_logical_types.py`** + +Append to the end of `src/orcapod/extension_types/builtin_logical_types.py`: + +```python +class LogicalUUID: + """Logical type for ``uuid.UUID``. + + Uses PyArrow's built-in ``pa.uuid()`` extension type (``"arrow.uuid"``) + which stores UUID values as 16-byte binary (``pa.binary(16)``). + + Note: + ``logical_type_name`` (``"uuid.UUID"``) intentionally differs from + the Arrow extension name (``"arrow.uuid"``). The + ``LogicalTypeRegistry`` stores both bindings so that lookups by + either key resolve to this same instance. + + Example: + >>> import uuid + >>> lt = LogicalUUID() + >>> u = uuid.uuid4() + >>> lt.storage_to_python(lt.python_to_storage(u)) == u + True + """ + + _arrow_ext: pa.ExtensionType | None = None + _polars_ext_class = make_polars_extension_type("arrow.uuid", pa.binary(16), None) + _polars_ext: pl.BaseExtension | None = None + + logical_type_name: str = "uuid.UUID" + python_type: type = _uuid_module.UUID + + def get_arrow_extension_type(self) -> pa.ExtensionType: + """Return PyArrow's built-in ``pa.uuid()`` extension type. + + Returns: + A cached ``pa.uuid()`` instance (Arrow extension name ``"arrow.uuid"``, + storage type ``pa.binary(16)``). + """ + if LogicalUUID._arrow_ext is None: + LogicalUUID._arrow_ext = pa.uuid() + return LogicalUUID._arrow_ext + + def get_polars_extension_type(self) -> pl.BaseExtension: + """Return the Polars extension type for ``arrow.uuid``. + + Returns: + A cached ``pl.BaseExtension`` instance registered under + ``"arrow.uuid"`` (matches the Arrow extension name, not the + logical type name). + """ + if LogicalUUID._polars_ext is None: + LogicalUUID._polars_ext = LogicalUUID._polars_ext_class() + return LogicalUUID._polars_ext + + def python_to_storage(self, value: Any) -> bytes: + """Convert a ``uuid.UUID`` to its 16-byte binary representation. + + Args: + value: A ``uuid.UUID`` instance. + + Returns: + A 16-byte ``bytes`` object (big-endian byte order, as per + ``uuid.UUID.bytes``). + """ + return value.bytes + + def storage_to_python(self, storage_value: Any) -> _uuid_module.UUID: + """Reconstruct a ``uuid.UUID`` from its 16-byte binary representation. + + Args: + storage_value: A bytes-like object of length 16. + + Returns: + A ``uuid.UUID`` instance. + """ + return _uuid_module.UUID(bytes=bytes(storage_value)) +``` + +- [ ] **Step 4: Run tests to verify they pass** + +```bash +uv run pytest tests/test_extension_types/test_builtin_logical_types.py -v +``` + +Expected: All tests in the file PASS (LogicalPath, LogicalUPath, and LogicalUUID). + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/extension_types/builtin_logical_types.py \ + tests/test_extension_types/test_builtin_logical_types.py +git commit -m "feat(extension_types): implement LogicalUUID" +``` + +--- + +### Task 5: Wire built-in types into `DataContext` + +**Files:** +- Modify: `src/orcapod/contexts/core.py` +- Modify: `src/orcapod/contexts/registry.py` +- Modify: `src/orcapod/contexts/data/v0.1.json` +- Modify: `src/orcapod/contexts/data/schemas/context_schema.json` +- Modify: `src/orcapod/contexts/__init__.py` +- Modify: `tests/test_extension_types/test_builtin_logical_types.py` + +This task wires everything together. The integration tests are written first, but they cannot pass until the DataContext and JSON spec are updated. Do all the sub-steps in a single commit. + +- [ ] **Step 1: Write the failing integration tests** + +Append to `tests/test_extension_types/test_builtin_logical_types.py`: + +```python +# --------------------------------------------------------------------------- +# Default context integration tests +# --------------------------------------------------------------------------- + + +def test_default_context_has_logical_type_registry(): + """DataContext has a logical_type_registry attribute.""" + from orcapod.contexts import get_default_context + + ctx = get_default_context() + assert hasattr(ctx, "logical_type_registry") + + +def test_default_context_registry_has_logical_path(): + """Default registry returns LogicalPath for 'pathlib.Path'.""" + from orcapod.contexts import get_default_context + from orcapod.extension_types.builtin_logical_types import LogicalPath + + registry = get_default_context().logical_type_registry + lt = registry.get_by_logical_name("pathlib.Path") + assert isinstance(lt, LogicalPath) + + +def test_default_context_registry_lookup_by_python_type_path(): + """Default registry routes pathlib.Path to LogicalPath.""" + from orcapod.contexts import get_default_context + from orcapod.extension_types.builtin_logical_types import LogicalPath + + registry = get_default_context().logical_type_registry + lt = registry.get_by_python_type(pathlib.Path) + assert isinstance(lt, LogicalPath) + + +def test_default_context_registry_lookup_by_arrow_name_path(): + """Default registry routes 'pathlib.Path' arrow ext name to LogicalPath.""" + from orcapod.contexts import get_default_context + from orcapod.extension_types.builtin_logical_types import LogicalPath + + registry = get_default_context().logical_type_registry + lt = registry.get_by_arrow_extension_name("pathlib.Path") + assert isinstance(lt, LogicalPath) + + +def test_default_context_registry_has_logical_upath(): + """Default registry returns LogicalUPath for 'upath.UPath'.""" + from orcapod.contexts import get_default_context + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + registry = get_default_context().logical_type_registry + lt = registry.get_by_logical_name("upath.UPath") + assert isinstance(lt, LogicalUPath) + + +def test_default_context_registry_lookup_by_python_type_upath(): + """Default registry routes UPath to LogicalUPath.""" + from orcapod.contexts import get_default_context + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + registry = get_default_context().logical_type_registry + lt = registry.get_by_python_type(UPath) + assert isinstance(lt, LogicalUPath) + + +def test_default_context_registry_has_logical_uuid(): + """Default registry returns LogicalUUID for 'uuid.UUID'.""" + from orcapod.contexts import get_default_context + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + registry = get_default_context().logical_type_registry + lt = registry.get_by_logical_name("uuid.UUID") + assert isinstance(lt, LogicalUUID) + + +def test_default_context_registry_lookup_by_arrow_name_uuid(): + """Default registry routes 'arrow.uuid' arrow ext name to LogicalUUID.""" + from orcapod.contexts import get_default_context + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + registry = get_default_context().logical_type_registry + lt = registry.get_by_arrow_extension_name("arrow.uuid") + assert isinstance(lt, LogicalUUID) + + +def test_default_context_registry_uuid_logical_name_differs_from_arrow_name(): + """The same LogicalUUID instance is found by both 'uuid.UUID' and 'arrow.uuid'.""" + from orcapod.contexts import get_default_context + + registry = get_default_context().logical_type_registry + by_logical = registry.get_by_logical_name("uuid.UUID") + by_arrow = registry.get_by_arrow_extension_name("arrow.uuid") + assert by_logical is by_arrow + + +def test_get_default_logical_type_registry_returns_same_as_context(): + """get_default_logical_type_registry() is the same object as get_default_context().logical_type_registry.""" + from orcapod.contexts import get_default_context, get_default_logical_type_registry + + assert get_default_logical_type_registry() is get_default_context().logical_type_registry + + +def test_default_context_idempotent_registry(): + """Calling get_default_context() twice returns the same LogicalTypeRegistry instance.""" + from orcapod.contexts import get_default_context + + r1 = get_default_context().logical_type_registry + r2 = get_default_context().logical_type_registry + assert r1 is r2 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +```bash +uv run pytest tests/test_extension_types/test_builtin_logical_types.py::test_default_context_has_logical_type_registry -v +``` + +Expected: FAIL — `DataContext` has no `logical_type_registry` attribute. + +- [ ] **Step 3: Add `logical_type_registry` field to `DataContext` in `core.py`** + +Current `core.py` imports (lines 1–16): + +```python +""" +Core data structures and exceptions for the OrcaPod context system. +... +""" + +from dataclasses import dataclass + +from orcapod.hashing.semantic_hashing.type_handler_registry import TypeHandlerRegistry +from orcapod.protocols.hashing_protocols import ( + ArrowHasherProtocol, + SemanticHasherProtocol, +) +from orcapod.protocols.semantic_types_protocols import TypeConverterProtocol +``` + +Add one import and one field. The final `core.py` content: + +```python +""" +Core data structures and exceptions for the OrcaPod context system. + +This module defines the basic types and exceptions used throughout +the context management system. +""" + +from dataclasses import dataclass + +from orcapod.extension_types.registry import LogicalTypeRegistry +from orcapod.hashing.semantic_hashing.type_handler_registry import TypeHandlerRegistry +from orcapod.protocols.hashing_protocols import ( + ArrowHasherProtocol, + SemanticHasherProtocol, +) +from orcapod.protocols.semantic_types_protocols import TypeConverterProtocol + + +@dataclass +class DataContext: + """ + Data context containing all versioned components needed for data interpretation. + + A DataContext represents a specific version of the OrcaPod system configuration, + including semantic type registries, hashers, and other components that affect + how data is processed and interpreted. + + Attributes: + context_key: Unique identifier (e.g., "std:v0.1:default") + version: Version string (e.g., "v0.1") + description: Human-readable description of this context + semantic_type_registry: Registry of semantic type converters + arrow_hasher: Arrow table hasher for this context + semantic_hasher: General semantic hasher for this context + type_handler_registry: Registry of TypeHandlerProtocol instances for SemanticHasherProtocol + logical_type_registry: Registry of LogicalType instances (Path, UPath, UUID, etc.) + """ + + context_key: str + version: str + description: str + type_converter: TypeConverterProtocol + arrow_hasher: ArrowHasherProtocol + semantic_hasher: SemanticHasherProtocol # this is the currently the JSON hasher + type_handler_registry: TypeHandlerRegistry + logical_type_registry: LogicalTypeRegistry + + +class ContextValidationError(Exception): + """Raised when context validation fails.""" + + pass + + +class ContextResolutionError(Exception): + """Raised when context cannot be resolved.""" + + pass +``` + +- [ ] **Step 4: Update `contexts/registry.py` — add `logical_type_registry` to required fields and `_create_context_from_spec`** + +In `_load_spec_file` (around line 148), add `"logical_type_registry"` to `required_fields`: + +```python +required_fields = [ + "context_key", + "version", + "type_converter", + "arrow_hasher", + "semantic_hasher", + "type_handler_registry", + "logical_type_registry", +] +``` + +In `_create_context_from_spec` (around line 296), add `logical_type_registry` to the `DataContext(...)` call: + +```python +return DataContext( + context_key=context_key, + version=version, + description=description, + type_converter=ref_lut["type_converter"], + arrow_hasher=ref_lut["arrow_hasher"], + semantic_hasher=ref_lut["semantic_hasher"], + type_handler_registry=ref_lut["type_handler_registry"], + logical_type_registry=ref_lut["logical_type_registry"], +) +``` + +- [ ] **Step 5: Add `logical_type_registry` entry to `v0.1.json`** + +In `src/orcapod/contexts/data/v0.1.json`, add the following JSON block before the `"metadata"` key (after the `"semantic_hasher"` block): + +```json + "logical_type_registry": { + "_class": "orcapod.extension_types.registry.LogicalTypeRegistry", + "_config": { + "logical_types": [ + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalPath", + "_config": {} + }, + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalUPath", + "_config": {} + }, + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalUUID", + "_config": {} + } + ] + } + }, +``` + +The full updated `v0.1.json` after the edit: + +```json +{ + "context_key": "std:v0.1:default", + "version": "v0.1", + "description": "Initial stable release with basic Path semantic type support", + "file_hasher": { + "_class": "orcapod.hashing.file_hashers.BasicFileHasher", + "_config": { + "algorithm": "sha256" + } + }, + "semantic_registry": { + "_class": "orcapod.semantic_types.semantic_registry.SemanticTypeRegistry", + "_config": { + "converters": { + "upath": { + "_class": "orcapod.semantic_types.semantic_struct_converters.UPathStructConverter", + "_config": { + "file_hasher": {"_ref": "file_hasher"} + } + }, + "path": { + "_class": "orcapod.semantic_types.semantic_struct_converters.PythonPathStructConverter", + "_config": { + "file_hasher": {"_ref": "file_hasher"} + } + } + } + } + }, + "arrow_hasher": { + "_class": "orcapod.hashing.arrow_hashers.StarfixArrowHasher", + "_config": { + "hasher_id": "arrow_v0.1", + "semantic_registry": { + "_ref": "semantic_registry" + } + } + }, + "type_converter": { + "_class": "orcapod.semantic_types.universal_converter.UniversalTypeConverter", + "_config": { + "semantic_registry": { + "_ref": "semantic_registry" + } + } + }, + "function_info_extractor": { + "_class": "orcapod.hashing.semantic_hashing.function_info_extractors.FunctionSignatureExtractor", + "_config": { + "include_module": true, + "include_defaults": true + } + }, + "type_handler_registry": { + "_class": "orcapod.hashing.semantic_hashing.type_handler_registry.TypeHandlerRegistry", + "_config": { + "handlers": [ + [{"_type": "builtins.bytes"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.BytesHandler", "_config": {}}], + [{"_type": "builtins.bytearray"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.BytesHandler", "_config": {}}], + [{"_type": "pathlib.Path"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.PathContentHandler", "_config": {"file_hasher": {"_ref": "file_hasher"}}}], + [{"_type": "upath.core.UPath"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.UPathContentHandler", "_config": {"file_hasher": {"_ref": "file_hasher"}}}], + [{"_type": "uuid.UUID"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.UUIDHandler", "_config": {}}], + [{"_type": "types.FunctionType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionHandler", "_config": {"function_info_extractor": {"_ref": "function_info_extractor"}}}], + [{"_type": "types.BuiltinFunctionType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionHandler", "_config": {"function_info_extractor": {"_ref": "function_info_extractor"}}}], + [{"_type": "types.MethodType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionHandler", "_config": {"function_info_extractor": {"_ref": "function_info_extractor"}}}], + [{"_type": "builtins.type"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.TypeObjectHandler", "_config": {}}], + [{"_type": "types.GenericAlias"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.GenericAliasHandler", "_config": {}}], + [{"_type": "types.UnionType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.UnionTypeHandler", "_config": {}}], + [{"_type": "typing._GenericAlias"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.GenericAliasHandler", "_config": {}}], + [{"_type": "typing._SpecialForm"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.SpecialFormHandler", "_config": {}}], + [{"_type": "pyarrow.Table"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.ArrowTableHandler", "_config": {"arrow_hasher": {"_ref": "arrow_hasher"}}}], + [{"_type": "pyarrow.RecordBatch"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.ArrowTableHandler", "_config": {"arrow_hasher": {"_ref": "arrow_hasher"}}}] + ] + } + }, + "semantic_hasher": { + "_class": "orcapod.hashing.semantic_hashing.semantic_hasher.BaseSemanticHasher", + "_config": { + "hasher_id": "semantic_v0.1", + "type_handler_registry": { + "_ref": "type_handler_registry" + } + } + }, + "logical_type_registry": { + "_class": "orcapod.extension_types.registry.LogicalTypeRegistry", + "_config": { + "logical_types": [ + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalPath", + "_config": {} + }, + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalUPath", + "_config": {} + }, + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalUUID", + "_config": {} + } + ] + } + }, + "metadata": { + "created_date": "2025-08-01", + "author": "OrcaPod Core Team", + "changelog": [ + "Initial release with Path semantic type support", + "Basic SHA-256 hashing for files and objects", + "Arrow logical serialization method", + "Introduced arrow_v0.1 StarfixArrowHasher using starfix ArrowDigester for cross-language-compatible Arrow hashing" + ] + } +} +``` + +- [ ] **Step 6: Add `logical_type_registry` to `context_schema.json`** + +In `src/orcapod/contexts/data/schemas/context_schema.json`: + +Add `"logical_type_registry"` to the `"required"` array (after `"type_handler_registry"`): + +```json +"required": [ + "context_key", + "version", + "semantic_registry", + "type_converter", + "arrow_hasher", + "semantic_hasher", + "type_handler_registry", + "logical_type_registry" +], +``` + +Add `"logical_type_registry"` entry to the `"properties"` object (after `"type_handler_registry"`): + +```json +"logical_type_registry": { + "$ref": "#/$defs/objectspec", + "description": "ObjectSpec for the LogicalTypeRegistry (Path, UPath, UUID built-ins)" +}, +``` + +- [ ] **Step 7: Add `get_default_logical_type_registry()` to `contexts/__init__.py`** + +In `src/orcapod/contexts/__init__.py`, add after `get_default_type_converter()`: + +```python +def get_default_logical_type_registry() -> "LogicalTypeRegistry": + """Get the default logical type registry. + + Returns: + ``LogicalTypeRegistry`` instance from the default context. + """ + return get_default_context().logical_type_registry +``` + +Add the import at the top of the file (after the `from orcapod.protocols` imports): + +```python +from orcapod.extension_types.registry import LogicalTypeRegistry +``` + +Add `"get_default_logical_type_registry"` to `__all__`. + +The updated `__all__` in `contexts/__init__.py`: + +```python +__all__ = [ + # Core types + "DataContext", + "ContextValidationError", + "ContextResolutionError", + # Main functions + "resolve_context", + "get_available_contexts", + "get_context_info", + "get_default_context", + # Convenience accessors + "get_default_semantic_hasher", + "get_default_arrow_hasher", + "get_default_type_converter", + "get_default_logical_type_registry", + # Management functions + "set_default_context_version", + "validate_all_contexts", + "reload_contexts", + # Advanced usage + "create_registry", + "JSONDataContextRegistry", +] +``` + +- [ ] **Step 8: Run the integration tests** + +```bash +uv run pytest tests/test_extension_types/test_builtin_logical_types.py -v +``` + +Expected: All tests PASS, including the new integration tests. + +- [ ] **Step 9: Run the full test suite to check for regressions** + +```bash +uv run pytest tests/ -v --tb=short +``` + +Expected: All previously-passing tests still PASS. The 6 `default_logical_type_registry` tests in `test_registry.py` still pass (the module-level variable is still there; we remove it next). + +- [ ] **Step 10: Commit** + +```bash +git add src/orcapod/contexts/core.py \ + src/orcapod/contexts/registry.py \ + src/orcapod/contexts/data/v0.1.json \ + src/orcapod/contexts/data/schemas/context_schema.json \ + src/orcapod/contexts/__init__.py \ + tests/test_extension_types/test_builtin_logical_types.py +git commit -m "feat(contexts): add logical_type_registry to DataContext and v0.1 context" +``` + +--- + +### Task 6: Remove `default_logical_type_registry` and clean up stale tests + +**Files:** +- Modify: `src/orcapod/extension_types/__init__.py` +- Modify: `tests/test_extension_types/test_registry.py` + +The module-level `default_logical_type_registry` in `extension_types/__init__.py` is replaced by the context-scoped registry. This task removes it and deletes the 6 tests that relied on it. + +- [ ] **Step 1: Remove `default_logical_type_registry` from `extension_types/__init__.py`** + +Replace the current content of `src/orcapod/extension_types/__init__.py`: + +```python +"""Arrow/Polars extension type system for orcapod. + +This subpackage provides the registry and protocol for logical types that map +between Python objects and their Arrow/Polars extension type representation. + +Built-in registrations (``LogicalPath``, ``LogicalUPath``, ``LogicalUUID``) are +wired into ``DataContext`` via ``contexts/data/v0.1.json``. The primary access +path for the default registry is: + +- ``get_default_context().logical_type_registry`` +- ``get_default_logical_type_registry()`` (from ``orcapod.contexts``) +""" + +from __future__ import annotations + +from .protocols import LogicalType +from .registry import LogicalTypeRegistry, make_arrow_extension_type, make_polars_extension_type +from .schema_walker import ExtensionTypeInfo, walk_field, walk_schema + +__all__ = [ + "LogicalType", + "LogicalTypeRegistry", + "make_arrow_extension_type", + "make_polars_extension_type", + # PLT-1654 + "ExtensionTypeInfo", + "walk_schema", + "walk_field", +] +``` + +- [ ] **Step 2: Remove the 6 stale `default_logical_type_registry` tests from `test_registry.py`** + +Delete the entire section at the end of `tests/test_extension_types/test_registry.py` (lines 450–532): + +```python +# --------------------------------------------------------------------------- +# default_logical_type_registry tests +# --------------------------------------------------------------------------- + +def test_logical_type_registry_module_instance(): + ... + +def test_default_registry_is_same_object_across_imports(): + ... + +def test_default_registry_register_and_lookup(): + ... + +def test_default_registry_register_idempotent(): + ... + +def test_default_registry_populates_arrow_global(): + ... + +def test_default_registry_populates_polars_global(): + ... +``` + +These tests are superseded by the integration tests in `test_builtin_logical_types.py`. + +- [ ] **Step 3: Run the full test suite** + +```bash +uv run pytest tests/ -v --tb=short +``` + +Expected: All tests PASS. The 6 removed tests no longer exist. No regressions. + +- [ ] **Step 4: Commit** + +```bash +git add src/orcapod/extension_types/__init__.py \ + tests/test_extension_types/test_registry.py +git commit -m "refactor(extension_types): remove default_logical_type_registry module-level variable" +``` + +--- + +## Self-Review + +### Spec coverage check + +| Spec requirement | Covered by | +|---|---| +| `LogicalPath` implementation | Task 3 | +| `LogicalUPath` implementation | Task 3 | +| `LogicalUUID` implementation (with `pa.uuid()`) | Task 4 | +| `make_polars_extension_type` helper | Task 1 | +| `LogicalTypeRegistry.__init__` `logical_types` param | Task 2 | +| `DataContext.logical_type_registry` field | Task 5, Step 3 | +| `v0.1.json` `logical_type_registry` entry | Task 5, Step 5 | +| `context_schema.json` update | Task 5, Step 6 | +| `get_default_logical_type_registry()` convenience function | Task 5, Step 7 | +| Remove `default_logical_type_registry` from `__init__.py` | Task 6, Step 1 | +| Protocol conformance tests | Task 3 & 4 | +| Property value tests | Task 3 & 4 | +| Conversion round-trip tests | Task 3 & 4 | +| Default context registration tests | Task 5, Step 1 | +| Pre-existing Arrow type tolerance test (`LogicalUUID`) | Task 4, Step 1 | +| Idempotence test (context caching) | Task 5, Step 1 | +| UUID `logical_type_name` ≠ Arrow ext name test | Task 4, Step 1 | +| Circular import avoidance (submodule imports) | Task 3, Step 3 (in `builtin_logical_types.py`) | +| Class-level caching for extension type instances | Task 3, Step 3 & Task 4, Step 3 | +| Export `make_polars_extension_type` from `__init__.py` | Task 1, Step 4 | + +### Type consistency check + +- `make_polars_extension_type(name, arrow_storage_type, metadata)` — used consistently in Task 1 (definition) and Task 3/4 (class-body calls). +- `LogicalTypeRegistry(logical_types=[...])` — defined in Task 2, used in Task 5 JSON spec. +- `DataContext.logical_type_registry` field — added in Task 5 Step 3, passed in `_create_context_from_spec` in Task 5 Step 4. +- `get_default_logical_type_registry()` returns `LogicalTypeRegistry`, consistent with `get_default_type_converter()` pattern. +- `LogicalUUID.logical_type_name = "uuid.UUID"` vs `get_arrow_extension_type().extension_name = "arrow.uuid"` — intentional difference, tested in Task 4. + +### No placeholder scan + +All steps contain complete code or exact commands. No "TBD", "similar to", or "add validation" phrases. diff --git a/superpowers/plans/2026-06-14-plt-1668-logical-type-redesign.md b/superpowers/plans/2026-06-14-plt-1668-logical-type-redesign.md new file mode 100644 index 00000000..f807fe9f --- /dev/null +++ b/superpowers/plans/2026-06-14-plt-1668-logical-type-redesign.md @@ -0,0 +1,980 @@ +# PLT-1668: LogicalType Redesign Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use sensei:subagent-driven-development (recommended) or sensei:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Replace `ExtensionTypeConverter`/`ExtensionTypeRegistry` with `LogicalType`/`LogicalTypeRegistry` so that each logical type owns its Arrow and Polars extension types directly via `get_arrow_extension_type()` / `get_polars_extension_type()`, and the registry enforces a three-way binding triplet `(logical_type_name, arrow_ext_name, python_type)`. + +**Architecture:** The `LogicalType` protocol gains two new methods (`get_arrow_extension_type`, `get_polars_extension_type`) and loses three flat properties (`extension_name`, `extension_metadata`, `storage_type`). The registry drops module-level shadow dicts entirely — uniqueness is enforced per-instance via three internal dicts. A new `make_arrow_extension_type(extension_name, storage_type, metadata) -> type[pa.ExtensionType]` helper replaces the dynamic synthesis that previously lived inside the registry. + +**Tech Stack:** Python 3.12+, PyArrow ≥ 20, Polars ≥ 1.36.0, pytest, uv. + +--- + +## File Map + +| File | Action | Responsibility | +|---|---|---| +| `src/orcapod/extension_types/protocols.py` | Rewrite | `LogicalType` protocol | +| `src/orcapod/extension_types/registry.py` | Rewrite | `make_arrow_extension_type` helper + `LogicalTypeRegistry` | +| `src/orcapod/extension_types/__init__.py` | Update | Export new names + `default_logical_type_registry` | +| `src/orcapod/extension_types/schema_walker.py` | **No change** | Self-contained; no protocol imports | +| `tests/test_extension_types/test_protocols.py` | Rewrite | Protocol conformance tests | +| `tests/test_extension_types/test_registry.py` | Rewrite | Stub helpers + all registry tests | + +--- + +### Task 1: Replace `ExtensionTypeConverter` with `LogicalType` in `protocols.py` + +**Files:** +- Modify: `src/orcapod/extension_types/protocols.py` + +- [ ] **Step 1: Overwrite `protocols.py` with the `LogicalType` protocol** + +```python +# src/orcapod/extension_types/protocols.py +"""Protocol definitions for the Arrow/Polars extension type system. + +This module defines ``LogicalType`` — the contract for all implementations +that bind a Python class to its Arrow and Polars extension type representation. + +Note: + This module is part of the parallel-build phase. The old + ``SemanticStructConverterProtocol`` in ``protocols/semantic_types_protocols.py`` + is untouched; it is removed in PLT-1660. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + import polars as pl + import pyarrow as pa + + +@runtime_checkable +class LogicalType(Protocol): + """Protocol for Arrow/Polars extension-type-backed logical types. + + A ``LogicalType`` is a three-way binding between a unique logical type name + (orcapod's identifier), a Python class, and Arrow/Polars extension types. + Each implementation *owns* its Arrow and Polars extension types by providing + them directly via ``get_arrow_extension_type`` and ``get_polars_extension_type``. + + This protocol is Arrow I/O only — hashing is not a logical type responsibility. + """ + + @property + def logical_type_name(self) -> str: + """Unique orcapod identifier for this logical type. + + By convention the Python FQCN (e.g. ``"uuid.UUID"``), but any unique + string is valid. Does NOT need to match the Arrow extension type name. + """ + ... + + @property + def python_type(self) -> type: + """The Python class this logical type represents.""" + ... + + def get_arrow_extension_type(self) -> pa.ExtensionType: + """Return the Arrow extension type for this logical type. + + ``storage_type``, ``extension_name``, and serialised metadata are + encapsulated inside the returned type; they are no longer top-level + properties on ``LogicalType``. + + For custom types: create and return an instance of a new + ``pa.ExtensionType`` subclass (e.g. via ``make_arrow_extension_type``). + For pre-existing types: return the existing instance directly + (e.g. ``pa.uuid()``). + """ + ... + + def get_polars_extension_type(self) -> pl.BaseExtension: + """Return an instance of the Polars extension type for this logical type. + + The registry calls ``type(instance)`` to obtain the class passed to + ``pl.register_extension_type``. + """ + ... + + def python_to_storage(self, value: Any) -> Any: + """Convert a Python value to its Arrow storage representation. + + Args: + value: A Python object of type ``python_type``. + + Returns: + A value suitable for use as an Arrow scalar or array element + matching the storage type of ``get_arrow_extension_type()``. + """ + ... + + def storage_to_python(self, storage_value: Any) -> Any: + """Convert an Arrow storage value back to a Python object. + + Args: + storage_value: A scalar or array element from the Arrow storage array. + + Returns: + A Python object of type ``python_type``. + """ + ... +``` + +- [ ] **Step 2: Verify old protocol tests now fail** + +```bash +cd /path/to/orcapod-python +uv run pytest tests/test_extension_types/test_protocols.py -v +``` + +Expected: FAIL — `ExtensionTypeConverter` import error and protocol checks fail. + +--- + +### Task 2: Update `test_protocols.py` for the new `LogicalType` protocol + +**Files:** +- Modify: `tests/test_extension_types/test_protocols.py` + +- [ ] **Step 1: Overwrite `test_protocols.py`** + +```python +# tests/test_extension_types/test_protocols.py +"""Tests for LogicalType protocol.""" + +from __future__ import annotations + +import pyarrow as pa +import polars as pl + +from orcapod.extension_types.protocols import LogicalType +from orcapod.extension_types.registry import make_arrow_extension_type + + +_StubArrowExtClass = make_arrow_extension_type( + "test.module.MyType", pa.large_string(), b"test.category" +) + + +class _StubLogicalType: + """Minimal conforming implementation of LogicalType for use in tests.""" + + @property + def logical_type_name(self) -> str: + return "test.module.MyType" + + @property + def python_type(self) -> type: + return str + + def get_arrow_extension_type(self) -> pa.ExtensionType: + return _StubArrowExtClass() + + def get_polars_extension_type(self) -> pl.BaseExtension: + class _StubPL(pl.BaseExtension): + def __init__(self) -> None: + super().__init__("test.module.MyType", pl.String, None) + + @classmethod + def ext_from_params(cls, ext_name, storage_dtype, metadata_str): + return cls() + + return _StubPL() + + def python_to_storage(self, value): + return str(value) + + def storage_to_python(self, storage_value): + return storage_value + + +def test_protocol_is_importable(): + """LogicalType can be imported from extension_types.protocols.""" + assert LogicalType is not None + + +def test_protocol_defines_required_members(): + """A conforming class is recognized as a LogicalType instance.""" + assert isinstance(_StubLogicalType(), LogicalType) + + +def test_conforming_class_satisfies_protocol(): + """A class implementing all required members works correctly via the protocol interface.""" + lt: LogicalType = _StubLogicalType() + assert lt.logical_type_name == "test.module.MyType" + assert lt.python_type is str + assert lt.get_arrow_extension_type().extension_name == "test.module.MyType" + assert isinstance(lt.get_polars_extension_type(), pl.BaseExtension) + assert lt.python_to_storage(42) == "42" + assert lt.storage_to_python("hello") == "hello" +``` + +Note: `make_arrow_extension_type` is imported from `registry.py` — this task depends on Task 3 below having the helper in place before this test file is runnable. Write the file now; run after Task 3. + +--- + +### Task 3: Add `make_arrow_extension_type` and `LogicalTypeRegistry` to `registry.py` + +**Files:** +- Modify: `src/orcapod/extension_types/registry.py` + +- [ ] **Step 1: Overwrite `registry.py` with the new implementation** + +```python +# src/orcapod/extension_types/registry.py +"""Registry for LogicalType instances. + +Registering a logical type automatically registers the corresponding +extension type in both PyArrow's and Polars' global registries. +""" + +from __future__ import annotations + +import re + +import polars as pl +import pyarrow as pa + +from orcapod.extension_types.protocols import LogicalType + + +def _sanitize(name: str) -> str: + """Replace non-alphanumeric characters with underscores. + + Used to produce a valid Python identifier for the dynamically created + ``pa.ExtensionType`` subclass name. + """ + return re.sub(r"[^A-Za-z0-9]", "_", name) + + +def make_arrow_extension_type( + extension_name: str, + storage_type: pa.DataType, + metadata: bytes | None = None, +) -> type[pa.ExtensionType]: + """Synthesise and return a ``pa.ExtensionType`` subclass. + + Returns the *class*, not an instance — callers instantiate it inside their + ``get_arrow_extension_type()`` implementation. Returning the class preserves + the option to create multiple instances or future parameterised variants from + the same class. + + This is a low-level building block. The full pattern for binding a Python + type to a specific Arrow/Polars representation — the extension type factory — + is the responsibility of each ``LogicalType`` implementation. See PLT-1656 + for the built-in implementations (``Path``, ``UPath``, ``UUID``). + + Args: + extension_name: The Arrow extension name (``ARROW:extension:name``). + storage_type: The underlying Arrow storage type. + metadata: Optional bytes stored as ``ARROW:extension:metadata``. + Defaults to ``None`` (serialised as empty bytes). + + Returns: + A ``pa.ExtensionType`` subclass. Call it with no arguments to obtain + an instance suitable for passing to ``pa.register_extension_type`` or + returning from ``get_arrow_extension_type()``. + """ + _name, _storage, _metadata = extension_name, storage_type, metadata or b"" + return type( + f"_ArrowExt_{_sanitize(extension_name)}", + (pa.ExtensionType,), + { + "__init__": lambda self: pa.ExtensionType.__init__(self, _storage, _name), + "__arrow_ext_serialize__": lambda self: _metadata, + # __arrow_ext_deserialize__ reconstructs the type descriptor from schema + # metadata (called once per IPC/Parquet read, not per value). The storage + # type and metadata are baked into the constructor via closure, so + # arguments are intentionally ignored. + "__arrow_ext_deserialize__": classmethod( + lambda cls, storage_type, serialized: cls() + ), + }, + ) + + +class LogicalTypeRegistry: + """Registry for ``LogicalType`` instances. + + Maintains a three-way binding: ``(logical_type_name, arrow_extension_name, + python_type)`` → ``LogicalType``. Each key participates in at most one + binding within a registry instance. + + Registering a logical type side-effect-registers the corresponding extension + type in PyArrow's and Polars' global registries. Pre-existing types (those + already registered externally, e.g. PyArrow's built-in ``"arrow.uuid"``) are + accepted silently — the binding is stored without error. + + The process-global ``default_logical_type_registry`` instance provides + effective process-wide uniqueness for normal use. Thread-safety is deferred. + + Example: + >>> registry = LogicalTypeRegistry() + >>> registry.register(my_logical_type) + >>> lt = registry.get_by_logical_name("uuid.UUID") + """ + + def __init__(self) -> None: + self._by_logical_name: dict[str, LogicalType] = {} + self._by_arrow_name: dict[str, LogicalType] = {} + self._by_python_type: dict[type, LogicalType] = {} + + def register(self, logical_type: LogicalType) -> None: + """Register *logical_type* and its PyArrow/Polars extension types. + + Args: + logical_type: A ``LogicalType`` instance to register. + + Raises: + ValueError: If any of the three keys (``logical_type_name``, + Arrow extension name, ``python_type``) is already bound to a + *different* ``LogicalType`` in this registry. + """ + arrow_ext_name = logical_type.get_arrow_extension_type().extension_name + py_type = logical_type.python_type + logical_name = logical_type.logical_type_name + + existing_by_logical = self._by_logical_name.get(logical_name) + existing_by_arrow = self._by_arrow_name.get(arrow_ext_name) + existing_by_python = self._by_python_type.get(py_type) + + # Triplet conflict check: raise if any key is bound to a different instance. + for existing, label, key in [ + (existing_by_logical, "logical_type_name", logical_name), + (existing_by_arrow, "arrow_extension_name", arrow_ext_name), + (existing_by_python, "python_type", py_type.__qualname__), + ]: + if existing is not None and existing is not logical_type: + raise ValueError( + f"Cannot register logical type '{logical_name}': " + f"{label} {key!r} is already bound to " + f"'{existing.logical_type_name}'." + ) + + # Idempotent check: all three keys already bound to this same instance. + if ( + existing_by_logical is logical_type + and existing_by_arrow is logical_type + and existing_by_python is logical_type + ): + return + + # Register Arrow extension type. ArrowKeyError means the name is already + # in PyArrow's global registry (pre-existing type or another registry + # instance). Accept silently — PLT-1669 adds post-error validation. + try: + pa.register_extension_type(logical_type.get_arrow_extension_type()) + except pa.lib.ArrowKeyError: + pass + + # Register Polars extension type. ValueError means already registered. + polars_ext_class = type(logical_type.get_polars_extension_type()) + try: + pl.register_extension_type(arrow_ext_name, polars_ext_class) + except ValueError: + pass + + # Store three-way binding. + self._by_logical_name[logical_name] = logical_type + self._by_arrow_name[arrow_ext_name] = logical_type + self._by_python_type[py_type] = logical_type + + def get_by_logical_name(self, name: str) -> LogicalType | None: + """Return the logical type registered under *name*, or ``None``.""" + return self._by_logical_name.get(name) + + def get_by_python_type(self, python_type: type) -> LogicalType | None: + """Return the logical type for *python_type*, or ``None``. + + Checks exact match first, then falls back to an ``issubclass`` scan. + When multiple registered types are superclasses of *python_type*, the + one registered first wins (insertion-order dict, Python 3.7+). + """ + lt = self._by_python_type.get(python_type) + if lt is not None: + return lt + for registered_type, lt in self._by_python_type.items(): + if issubclass(python_type, registered_type): + return lt + return None + + def get_by_arrow_extension_name(self, arrow_name: str) -> LogicalType | None: + """Return the logical type registered under *arrow_name*, or ``None``.""" + return self._by_arrow_name.get(arrow_name) +``` + +- [ ] **Step 2: Run protocol tests (both tasks together)** + +```bash +uv run pytest tests/test_extension_types/test_protocols.py -v +``` + +Expected: All 3 tests PASS. + +- [ ] **Step 3: Commit** + +```bash +git add src/orcapod/extension_types/protocols.py \ + src/orcapod/extension_types/registry.py \ + tests/test_extension_types/test_protocols.py +git commit -m "feat(extension_types): add LogicalType protocol and LogicalTypeRegistry (PLT-1668)" +``` + +--- + +### Task 4: Rework `test_registry.py` — stubs + basic tests + +**Files:** +- Modify: `tests/test_extension_types/test_registry.py` + +- [ ] **Step 1: Replace the imports and stub helpers at the top of `test_registry.py`** + +Replace everything from the top of the file through the `_make_stub` function definition (roughly lines 1–65 in the original) with: + +```python +"""Tests for LogicalTypeRegistry.""" + +from __future__ import annotations + +import pathlib +import tempfile +import uuid +import warnings + +import polars as pl +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +from orcapod.extension_types.protocols import LogicalType +from orcapod.extension_types.registry import LogicalTypeRegistry, make_arrow_extension_type + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _unique_name() -> str: + """Unique extension/logical name to avoid cross-test global-registry collisions.""" + return f"test.registry.{uuid.uuid4().hex[:8]}" + + +def _make_stub( + logical_name: str | None = None, + arrow_name: str | None = None, + storage: pa.DataType | None = None, + metadata: bytes | None = b"test.category", + py_type: type = str, +) -> LogicalType: + """Factory for minimal LogicalType conforming stubs. + + ``arrow_name`` defaults to ``logical_name`` when omitted. Pass separate + values to test cases that need a distinct Arrow extension name. + """ + _logical_name = logical_name or _unique_name() + _arrow_name = arrow_name or _logical_name + _storage = storage if storage is not None else pa.large_utf8() + _ArrowExt = make_arrow_extension_type(_arrow_name, _storage, metadata) + _pl_storage = pl.from_arrow(pa.array([], type=_storage)).dtype + _meta_str = metadata.decode("utf-8") if metadata else None + + class _StubPL(pl.BaseExtension): + def __init__(self) -> None: + super().__init__(_arrow_name, _pl_storage, _meta_str) + + @classmethod + def ext_from_params(cls, ext_name, storage_dtype, metadata_str): + return cls() + + class _Stub: + @property + def logical_type_name(self) -> str: + return _logical_name + + @property + def python_type(self) -> type: + return py_type + + def get_arrow_extension_type(self) -> pa.ExtensionType: + return _ArrowExt() + + def get_polars_extension_type(self) -> pl.BaseExtension: + return _StubPL() + + def python_to_storage(self, value): + return str(value) + + def storage_to_python(self, storage_value): + return storage_value + + return _Stub() +``` + +- [ ] **Step 2: Replace all basic/lookup/PA/Polars/module-level tests (lines 70–436 in the original) with the updated equivalents** + +Remove all tests that reference removed methods (`has_extension_name`, `has_python_type`, `list_extension_names`, `list_python_types`, `get_converter_for_name`, `get_converter_for_python_type`). Replace with: + +```python +# --------------------------------------------------------------------------- +# Basic registration tests +# --------------------------------------------------------------------------- + +def test_register_stores_three_way_binding(): + """After register(), all three lookup methods return the registered LogicalType.""" + stub = _make_stub() + registry = LogicalTypeRegistry() + registry.register(stub) + + arrow_name = stub.get_arrow_extension_type().extension_name + assert registry.get_by_logical_name(stub.logical_type_name) is stub + assert registry.get_by_arrow_extension_name(arrow_name) is stub + assert registry.get_by_python_type(stub.python_type) is stub + + +def test_get_by_logical_name_miss(): + registry = LogicalTypeRegistry() + assert registry.get_by_logical_name("does.not.exist") is None + + +def test_get_by_python_type_exact(): + registry = LogicalTypeRegistry() + stub = _make_stub(py_type=bytes) + registry.register(stub) + assert registry.get_by_python_type(bytes) is stub + + +def test_get_by_python_type_subclass(): + class _Base: + pass + + class _Child(_Base): + pass + + registry = LogicalTypeRegistry() + stub = _make_stub(py_type=_Base) + registry.register(stub) + assert registry.get_by_python_type(_Child) is stub + + +def test_get_by_python_type_miss(): + registry = LogicalTypeRegistry() + assert registry.get_by_python_type(int) is None + + +def test_get_by_arrow_extension_name_miss(): + registry = LogicalTypeRegistry() + assert registry.get_by_arrow_extension_name("does.not.exist") is None + + +# --------------------------------------------------------------------------- +# Idempotency +# --------------------------------------------------------------------------- + +def test_register_idempotent_same_instance(): + """Registering the same LogicalType object twice is a no-op.""" + stub = _make_stub() + registry = LogicalTypeRegistry() + registry.register(stub) + registry.register(stub) # should not raise + assert registry.get_by_logical_name(stub.logical_type_name) is stub + + +# --------------------------------------------------------------------------- +# Triplet conflict tests +# --------------------------------------------------------------------------- + +def test_triplet_conflict_same_logical_name_raises(): + """Two LogicalTypes sharing logical_type_name -> ValueError.""" + logical_name = _unique_name() + stub1 = _make_stub(logical_name=logical_name, py_type=str) + stub2 = _make_stub(logical_name=logical_name, py_type=int) + + registry = LogicalTypeRegistry() + registry.register(stub1) + with pytest.raises(ValueError, match=logical_name): + registry.register(stub2) + + +def test_triplet_conflict_same_arrow_name_raises(): + """Two LogicalTypes sharing Arrow extension name -> ValueError.""" + shared_arrow_name = _unique_name() + stub1 = _make_stub(arrow_name=shared_arrow_name, py_type=str) + stub2 = _make_stub(arrow_name=shared_arrow_name, py_type=int) + + registry = LogicalTypeRegistry() + registry.register(stub1) + with pytest.raises(ValueError, match=shared_arrow_name): + registry.register(stub2) + + +def test_triplet_conflict_same_python_type_raises(): + """Two LogicalTypes sharing python_type -> ValueError.""" + stub1 = _make_stub(py_type=float) + stub2 = _make_stub(py_type=float) + + registry = LogicalTypeRegistry() + registry.register(stub1) + with pytest.raises(ValueError, match="float"): + registry.register(stub2) + + +# --------------------------------------------------------------------------- +# Pre-existing type tolerance tests +# --------------------------------------------------------------------------- + +def test_register_preexisting_arrow_type_succeeds(): + """ArrowKeyError from PA global registry is accepted silently; binding is stored.""" + name = _unique_name() + + class _ExternalPA(pa.ExtensionType): + def __init__(self) -> None: + pa.ExtensionType.__init__(self, pa.large_utf8(), name) + + def __arrow_ext_serialize__(self): + return b"" + + @classmethod + def __arrow_ext_deserialize__(cls, st, se): + return cls() + + pa.register_extension_type(_ExternalPA()) # pre-register externally + + stub = _make_stub(arrow_name=name) + registry = LogicalTypeRegistry() + registry.register(stub) # must not raise + + assert registry.get_by_logical_name(stub.logical_type_name) is stub + assert registry.get_by_arrow_extension_name(name) is stub + assert registry.get_by_python_type(stub.python_type) is stub + + +def test_register_preexisting_polars_type_succeeds(): + """ValueError from Polars global registry is accepted silently; binding is stored.""" + name = _unique_name() + + # Pre-register in PA first to avoid PA-level conflict + class _ExternalPA(pa.ExtensionType): + def __init__(self) -> None: + pa.ExtensionType.__init__(self, pa.large_utf8(), name) + + def __arrow_ext_serialize__(self): + return b"" + + @classmethod + def __arrow_ext_deserialize__(cls, st, se): + return cls() + + pa.register_extension_type(_ExternalPA()) + + class _ExternalPL(pl.BaseExtension): + def __init__(self) -> None: + super().__init__(name, pl.String, None) + + @classmethod + def ext_from_params(cls, n, s, m): + return cls() + + pl.register_extension_type(name, _ExternalPL) + + stub = _make_stub(arrow_name=name) + registry = LogicalTypeRegistry() + registry.register(stub) # must not raise + + assert registry.get_by_logical_name(stub.logical_type_name) is stub + assert registry.get_by_arrow_extension_name(name) is stub + assert registry.get_by_python_type(stub.python_type) is stub + + +# --------------------------------------------------------------------------- +# PyArrow global registry: our type gets registered +# --------------------------------------------------------------------------- + +def test_register_populates_arrow_global_registry(): + """After register(), PA global registry contains the extension type.""" + stub = _make_stub() + registry = LogicalTypeRegistry() + registry.register(stub) + + arrow_name = stub.get_arrow_extension_type().extension_name + + class _Probe(pa.ExtensionType): + def __init__(self) -> None: + pa.ExtensionType.__init__(self, pa.large_utf8(), arrow_name) + + def __arrow_ext_serialize__(self): + return b"" + + @classmethod + def __arrow_ext_deserialize__(cls, st, se): + return cls() + + with pytest.raises(pa.lib.ArrowKeyError): + pa.register_extension_type(_Probe()) +``` + +- [ ] **Step 3: Run the basic + idempotency + triplet + pre-existing tests** + +```bash +uv run pytest tests/test_extension_types/test_registry.py -v -k "not round_trip and not parquet and not module_instance" +``` + +Expected: All newly written tests PASS. + +- [ ] **Step 4: Commit** + +```bash +git add tests/test_extension_types/test_registry.py +git commit -m "test(extension_types): rework test_registry for LogicalTypeRegistry (PLT-1668)" +``` + +--- + +### Task 5: Update end-to-end tests in `test_registry.py` + +**Files:** +- Modify: `tests/test_extension_types/test_registry.py` + +- [ ] **Step 1: Replace the `_Color`, `_make_color_converter`, `_build_ext_array`, and end-to-end test functions** + +Remove the old `_Color` / `_make_color_converter` / `_build_ext_array` block and the three round-trip tests. Replace with: + +```python +# --------------------------------------------------------------------------- +# End-to-end helpers +# --------------------------------------------------------------------------- + +class _Color: + """Minimal Python class used to exercise the logical type contract end-to-end.""" + + def __init__(self, hex_str: str) -> None: + self.hex_str = hex_str + + def __eq__(self, other: object) -> bool: + return isinstance(other, _Color) and self.hex_str == other.hex_str + + def __repr__(self) -> str: + return f"Color({self.hex_str!r})" + + +def _make_color_logical_type() -> LogicalType: + """LogicalType for _Color, backed by pa.large_utf8() storage.""" + _name = _unique_name() + _ArrowExt = make_arrow_extension_type(_name, pa.large_utf8(), b"test.color") + _pl_storage = pl.from_arrow(pa.array([], type=pa.large_utf8())).dtype + + class _ColorPL(pl.BaseExtension): + def __init__(self) -> None: + super().__init__(_name, _pl_storage, "test.color") + + @classmethod + def ext_from_params(cls, ext_name, storage_dtype, metadata_str): + return cls() + + class _ColorLogicalType: + @property + def logical_type_name(self) -> str: + return _name + + @property + def python_type(self) -> type: + return _Color + + def get_arrow_extension_type(self) -> pa.ExtensionType: + return _ArrowExt() + + def get_polars_extension_type(self) -> pl.BaseExtension: + return _ColorPL() + + def python_to_storage(self, value: _Color) -> str: + return value.hex_str + + def storage_to_python(self, storage_value: str) -> _Color: + return _Color(storage_value) + + return _ColorLogicalType() + + +def _build_ext_array(lt: LogicalType, values: list) -> pa.Array: + """Build a PA extension array from Python values using the logical type.""" + arrow_ext = lt.get_arrow_extension_type() + storage_values = [lt.python_to_storage(v) for v in values] + storage_arr = pa.array(storage_values, type=arrow_ext.storage_type) + return storage_arr.cast(arrow_ext) + + +# --------------------------------------------------------------------------- +# End-to-end integration tests +# --------------------------------------------------------------------------- + +def test_python_class_round_trip(): + """Python objects -> Arrow extension array -> Python objects via logical type methods.""" + lt = _make_color_logical_type() + registry = LogicalTypeRegistry() + registry.register(lt) + + originals = [_Color("#ff0000"), _Color("#00ff00"), _Color("#0000ff")] + ext_arr = _build_ext_array(lt, originals) + + recovered = [lt.storage_to_python(v.as_py()) for v in ext_arr.storage] + assert recovered == originals + + +def test_arrow_polars_round_trip(): + """PA ext array -> pl.from_arrow -> to_arrow() preserves extension type and values.""" + lt = _make_color_logical_type() + registry = LogicalTypeRegistry() + registry.register(lt) + + originals = [_Color("#aabbcc"), _Color("#112233")] + ext_arr = _build_ext_array(lt, originals) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + pl_series = pl.from_arrow(ext_arr) + + arrow_name = lt.get_arrow_extension_type().extension_name + assert isinstance(pl_series.dtype, pl.BaseExtension) + assert pl_series.dtype.ext_name() == arrow_name + + arr_back = pl_series.to_arrow() + assert arr_back.type.extension_name == arrow_name + + recovered = [lt.storage_to_python(v.as_py()) for v in arr_back.storage] + assert recovered == originals + + +def test_parquet_round_trip(): + """PA ext array -> Parquet -> read back; extension type and values preserved.""" + lt = _make_color_logical_type() + registry = LogicalTypeRegistry() + registry.register(lt) + + originals = [_Color("#deadbe"), _Color("#cafeba")] + ext_arr = _build_ext_array(lt, originals) + schema = pa.schema([pa.field("color", ext_arr.type), pa.field("id", pa.int32())]) + table = pa.table( + {"color": ext_arr, "id": pa.array([1, 2], type=pa.int32())}, + schema=schema, + ) + + with tempfile.TemporaryDirectory() as tmp: + path = pathlib.Path(tmp) / "test.parquet" + pq.write_table(table, path) + table_back = pq.read_table(path) + + arrow_name = lt.get_arrow_extension_type().extension_name + assert table_back.schema.field("color").type.extension_name == arrow_name + storage_arr = table_back.column("color").combine_chunks().storage + recovered = [lt.storage_to_python(v.as_py()) for v in storage_arr] + assert recovered == originals + + +# --------------------------------------------------------------------------- +# Module-level instance test +# --------------------------------------------------------------------------- + +def test_logical_type_registry_module_instance(): + """extension_types.default_logical_type_registry is a LogicalTypeRegistry, starts empty.""" + from orcapod import extension_types + + assert isinstance(extension_types.default_logical_type_registry, LogicalTypeRegistry) + # PLT-1668 scope: no built-in logical types registered yet (that is PLT-1656). + assert extension_types.default_logical_type_registry.get_by_logical_name("uuid.UUID") is None +``` + +- [ ] **Step 2: Run all registry tests** + +```bash +uv run pytest tests/test_extension_types/test_registry.py -v +``` + +Expected: All tests PASS. + +- [ ] **Step 3: Commit** + +```bash +git add tests/test_extension_types/test_registry.py +git commit -m "test(extension_types): add end-to-end and module-instance tests for LogicalTypeRegistry (PLT-1668)" +``` + +--- + +### Task 6: Update `__init__.py` exports + +**Files:** +- Modify: `src/orcapod/extension_types/__init__.py` + +- [ ] **Step 1: Overwrite `__init__.py`** + +```python +# src/orcapod/extension_types/__init__.py +"""Arrow/Polars extension type system for orcapod. + +This subpackage provides the registry and protocol for logical types that bind +Python classes to their Arrow and Polars extension type representation. + +The module-level ``default_logical_type_registry`` instance is the process default. +Built-in registrations (``Path``, ``UPath``, ``UUID``) are added by PLT-1656. +``DataContext`` wiring is added by PLT-1660. +""" + +from __future__ import annotations + +from .protocols import LogicalType +from .registry import LogicalTypeRegistry, make_arrow_extension_type +from .schema_walker import ExtensionTypeInfo, walk_field, walk_schema + +default_logical_type_registry = LogicalTypeRegistry() + +__all__ = [ + "LogicalType", + "LogicalTypeRegistry", + "make_arrow_extension_type", + "default_logical_type_registry", + # PLT-1654 + "ExtensionTypeInfo", + "walk_schema", + "walk_field", +] +``` + +- [ ] **Step 2: Run the full `test_extension_types` suite** + +```bash +uv run pytest tests/test_extension_types/ -v +``` + +Expected: All tests in `test_protocols.py`, `test_registry.py`, and `test_schema_walker.py` PASS. + +- [ ] **Step 3: Run the complete test suite to catch any regressions** + +```bash +uv run pytest --tb=short -q +``` + +Expected: All tests pass. No references to `ExtensionTypeConverter`, `ExtensionTypeRegistry`, or `default_extension_type_registry` remain outside of the deleted/replaced files. + +- [ ] **Step 4: Commit** + +```bash +git add src/orcapod/extension_types/__init__.py +git commit -m "feat(extension_types): update __init__ exports for LogicalType redesign (PLT-1668)" +``` + +--- + +## Self-Review Checklist + +After completing all tasks, verify: + +- [ ] `LogicalType` has exactly 6 members: `logical_type_name`, `python_type`, `get_arrow_extension_type`, `get_polars_extension_type`, `python_to_storage`, `storage_to_python` +- [ ] `LogicalTypeRegistry` has exactly 3 lookup methods: `get_by_logical_name`, `get_by_python_type`, `get_by_arrow_extension_name` +- [ ] No reference to `ExtensionTypeConverter`, `ExtensionTypeRegistry`, `default_extension_type_registry`, `_ARROW_REGISTRY`, `_POLARS_REGISTRY`, `_register_arrow_ext_type`, or `_register_polars_ext_type` remains anywhere in `src/` or `tests/` +- [ ] `make_arrow_extension_type` returns `type[pa.ExtensionType]` (a class, not an instance) +- [ ] Triplet conflict error messages include the conflicting key name so `pytest.raises(ValueError, match=)` works +- [ ] Pre-existing-type tests pre-register externally then call `registry.register()` — the call must not raise +- [ ] `test_schema_walker.py` still passes unchanged diff --git a/superpowers/plans/2026-06-15-plt-1670-orcapod-namespace-builtin-types.md b/superpowers/plans/2026-06-15-plt-1670-orcapod-namespace-builtin-types.md new file mode 100644 index 00000000..4f10b9ee --- /dev/null +++ b/superpowers/plans/2026-06-15-plt-1670-orcapod-namespace-builtin-types.md @@ -0,0 +1,721 @@ +# PLT-1670: Namespace Built-in Extension Types under `orcapod.*` Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use sensei:subagent-driven-development (recommended) or sensei:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Rename the three built-in Arrow extension types from upstream module-path names (`"pathlib.Path"`, `"upath.UPath"`, `"uuid.UUID"`) to Orcapod-owned namespaced names (`"orcapod.path"`, `"orcapod.upath"`, `"orcapod.uuid"`), and expose `Path`, `UPath`, `UUID` type aliases at the top-level `orcapod` namespace. + +**Architecture:** The three `LogicalType` classes in `builtin_logical_types.py` each carry a string constant used as both `logical_type_name` and the Arrow/Polars extension name — changing those constants is the entirety of the rename. Top-level aliases are simple re-exports in `__init__.py` that expose the upstream types under an Orcapod-stable symbol. Tests are updated last to match reality after the TDD red → green cycle. + +**Tech Stack:** Python, PyArrow (extension types), Polars (extension types), pytest via `uv run pytest`. + +--- + +## File Map + +| File | Action | What changes | +|------|--------|-------------| +| `tests/test_extension_types/test_builtin_logical_types.py` | Modify | Update 13 string assertions from old extension names to new `orcapod.*` names; add 5 alias tests | +| `src/orcapod/extension_types/builtin_logical_types.py` | Modify | Rename 6 string constants (2 per class: `_arrow_ext_class`, `_polars_ext_class`), 3 `logical_type_name` class attributes, update module and class docstrings | +| `src/orcapod/__init__.py` | Modify | Add `Path`, `UPath`, `UUID` re-exports with stability docstring; add to `__all__` | + +No other files need to change — the context config (`contexts/data/v0.1.json`) references classes by dotted name, not by extension name string, so it is unaffected. + +--- + +## Task 1: Update tests to assert on new `orcapod.*` extension names + +This is the TDD red step. After this task the test suite will fail with assertion errors until Task 2 fixes the implementation. + +**Files:** +- Modify: `tests/test_extension_types/test_builtin_logical_types.py` + +- [ ] **Step 1: Run the current test suite to confirm it is green before any changes** + +```bash +cd /path/to/orcapod-python +uv run pytest tests/test_extension_types/test_builtin_logical_types.py -v --tb=short 2>&1 | tail -20 +``` + +Expected: all tests pass (green baseline). + +- [ ] **Step 2: Update `test_logical_path_logical_type_name` (line 31)** + +Change: +```python +assert LogicalPath().logical_type_name == "pathlib.Path" +``` +To: +```python +assert LogicalPath().logical_type_name == "orcapod.path" +``` + +- [ ] **Step 3: Update `test_logical_path_arrow_ext_name` (line 44)** + +Change: +```python +assert LogicalPath().get_arrow_extension_type().extension_name == "pathlib.Path" +``` +To: +```python +assert LogicalPath().get_arrow_extension_type().extension_name == "orcapod.path" +``` + +- [ ] **Step 4: Update `test_logical_upath_logical_type_name` (line 103)** + +Change: +```python +assert LogicalUPath().logical_type_name == "upath.UPath" +``` +To: +```python +assert LogicalUPath().logical_type_name == "orcapod.upath" +``` + +- [ ] **Step 5: Update `test_logical_upath_arrow_ext_name` (line 116)** + +Change: +```python +assert LogicalUPath().get_arrow_extension_type().extension_name == "upath.UPath" +``` +To: +```python +assert LogicalUPath().get_arrow_extension_type().extension_name == "orcapod.upath" +``` + +- [ ] **Step 6: Update `test_logical_uuid_logical_type_name` (line 173)** + +Change: +```python +assert LogicalUUID().logical_type_name == "uuid.UUID" +``` +To: +```python +assert LogicalUUID().logical_type_name == "orcapod.uuid" +``` + +- [ ] **Step 7: Update `test_logical_uuid_arrow_ext_name` (lines 187–188)** + +Change: +```python + assert lt.get_arrow_extension_type().extension_name == "uuid.UUID" + assert lt.get_arrow_extension_type().extension_name == lt.logical_type_name +``` +To: +```python + assert lt.get_arrow_extension_type().extension_name == "orcapod.uuid" + assert lt.get_arrow_extension_type().extension_name == lt.logical_type_name +``` + +- [ ] **Step 8: Update `test_logical_uuid_registration_does_not_raise` (lines 248–249)** + +Change: +```python + assert registry.get_by_logical_name("uuid.UUID") is lt + assert registry.get_by_arrow_extension_name("uuid.UUID") is lt +``` +To: +```python + assert registry.get_by_logical_name("orcapod.uuid") is lt + assert registry.get_by_arrow_extension_name("orcapod.uuid") is lt +``` + +- [ ] **Step 9: Update default-context tests — `test_default_context_registry_has_logical_path` (line 384)** + +Change: +```python + lt = registry.get_by_logical_name("pathlib.Path") +``` +To: +```python + lt = registry.get_by_logical_name("orcapod.path") +``` + +- [ ] **Step 10: Update `test_default_context_registry_lookup_by_arrow_name_path` (line 404)** + +Change: +```python + lt = registry.get_by_arrow_extension_name("pathlib.Path") +``` +To: +```python + lt = registry.get_by_arrow_extension_name("orcapod.path") +``` + +- [ ] **Step 11: Update `test_default_context_registry_has_logical_upath` (line 414)** + +Change: +```python + lt = registry.get_by_logical_name("upath.UPath") +``` +To: +```python + lt = registry.get_by_logical_name("orcapod.upath") +``` + +- [ ] **Step 12: Update `test_default_context_registry_has_logical_uuid` (line 434)** + +Change: +```python + lt = registry.get_by_logical_name("uuid.UUID") +``` +To: +```python + lt = registry.get_by_logical_name("orcapod.uuid") +``` + +- [ ] **Step 13: Update `test_default_context_registry_lookup_by_arrow_name_uuid` (line 444)** + +Change: +```python + lt = registry.get_by_arrow_extension_name("uuid.UUID") +``` +To: +```python + lt = registry.get_by_arrow_extension_name("orcapod.uuid") +``` + +- [ ] **Step 14: Run tests to confirm they are now red** + +```bash +uv run pytest tests/test_extension_types/test_builtin_logical_types.py -v --tb=line 2>&1 | grep -E "FAILED|PASSED|ERROR" | head -30 +``` + +Expected: 13 tests now fail with `AssertionError`; all others pass. + +--- + +## Task 2: Rename extension type strings in `builtin_logical_types.py` + +This makes the red tests green. All 6 extension-name string constants, all 3 `logical_type_name` class attributes, and docstrings are updated to the `orcapod.*` namespace. + +**Files:** +- Modify: `src/orcapod/extension_types/builtin_logical_types.py` + +- [ ] **Step 1: Update the module-level docstring** + +Change the opening docstring lines 6–8 from: +```python +- ``LogicalPath``: maps ``pathlib.Path`` ↔ Arrow large_string extension "pathlib.Path" +- ``LogicalUPath``: maps ``upath.UPath`` ↔ Arrow large_string extension "upath.UPath" +- ``LogicalUUID``: maps ``uuid.UUID`` ↔ Arrow large_binary extension "uuid.UUID" +``` +To: +```python +- ``LogicalPath``: maps ``pathlib.Path`` ↔ Arrow large_string extension ``"orcapod.path"`` +- ``LogicalUPath``: maps ``upath.UPath`` ↔ Arrow large_string extension ``"orcapod.upath"`` +- ``LogicalUUID``: maps ``uuid.UUID`` ↔ Arrow large_binary extension ``"orcapod.uuid"`` +``` + +And replace the full module docstring with the updated version that adds the stability rationale note: + +```python +"""Built-in LogicalType implementations for orcapod. + +Provides three built-in logical types registered into the default +``DataContext.logical_type_registry`` via ``contexts/data/v0.1.json``: + +- ``LogicalPath``: maps ``pathlib.Path`` ↔ Arrow large_string extension ``"orcapod.path"`` +- ``LogicalUPath``: maps ``upath.UPath`` ↔ Arrow large_string extension ``"orcapod.upath"`` +- ``LogicalUUID``: maps ``uuid.UUID`` ↔ Arrow large_binary extension ``"orcapod.uuid"`` + +All three types use the ``orcapod.*`` extension name namespace rather than the upstream +module-qualified names (``"pathlib.Path"``, etc.). This gives Orcapod stable ownership of +the on-disk extension identity: even if the upstream library is renamed or restructured, +data written with these extension names continues to be readable without modification. + +Note: + All imports from orcapod.extension_types use direct submodule paths + (e.g. ``from orcapod.extension_types.registry import ...``) rather than + the package ``__init__`` to avoid circular imports when the context system + loads this module at startup. +""" +``` + +- [ ] **Step 2: Update `LogicalPath` class — class attributes and docstrings** + +Replace the `LogicalPath` class definition (lines 30–94) with: + +```python +class LogicalPath: + """Logical type for ``pathlib.Path``. + + Stores paths as Arrow large strings using the custom extension type + ``"orcapod.path"``. + + The extension name ``"orcapod.path"`` is Orcapod-owned and stable; it does not + depend on the upstream ``pathlib`` module path. Use ``orcapod.Path`` (a top-level + alias for ``pathlib.Path``) as the preferred way to reference this type in user code. + + Example: + >>> lt = LogicalPath() + >>> lt.python_to_storage(pathlib.Path("/tmp/foo")) + '/tmp/foo' + >>> lt.storage_to_python('/tmp/foo') + PosixPath('/tmp/foo') + """ + + _arrow_ext_class = make_arrow_extension_type("orcapod.path", pa.large_string()) + _arrow_ext: pa.ExtensionType | None = None + _polars_ext_class = make_polars_extension_type("orcapod.path", pa.large_string()) + _polars_ext: pl.BaseExtension | None = None + + logical_type_name: str = "orcapod.path" + python_type: type = pathlib.Path + + def get_arrow_extension_type(self) -> pa.ExtensionType: + """Return the Arrow extension type for ``pathlib.Path``. + + Returns: + A cached ``pa.ExtensionType`` instance with extension name + ``"orcapod.path"`` and storage type ``pa.large_string()``. + """ + if LogicalPath._arrow_ext is None: + LogicalPath._arrow_ext = LogicalPath._arrow_ext_class() + return LogicalPath._arrow_ext + + def get_polars_extension_type(self) -> pl.BaseExtension: + """Return the Polars extension type for ``pathlib.Path``. + + Returns: + A cached ``pl.BaseExtension`` instance registered under + ``"orcapod.path"``. + """ + if LogicalPath._polars_ext is None: + LogicalPath._polars_ext = LogicalPath._polars_ext_class() + return LogicalPath._polars_ext + + def python_to_storage(self, value: Any) -> str: + """Convert a ``pathlib.Path`` to its string representation. + + Args: + value: A ``pathlib.Path`` instance. + + Returns: + The string form of the path (e.g. ``"/tmp/foo"``). + """ + return str(value) + + def storage_to_python(self, storage_value: Any) -> pathlib.Path: + """Reconstruct a ``pathlib.Path`` from its string representation. + + Args: + storage_value: A string path as stored in Arrow. + + Returns: + A ``pathlib.Path`` instance. + """ + return pathlib.Path(storage_value) +``` + +- [ ] **Step 3: Update `LogicalUPath` class — class attributes and docstrings** + +Replace the `LogicalUPath` class definition (lines 97–161) with: + +```python +class LogicalUPath: + """Logical type for ``upath.UPath``. + + Stores paths as Arrow large strings using the custom extension type + ``"orcapod.upath"``. + + The extension name ``"orcapod.upath"`` is Orcapod-owned and stable; it does not + depend on the upstream ``upath`` module path. Use ``orcapod.UPath`` (a top-level + alias for ``upath.UPath``) as the preferred way to reference this type in user code. + + Example: + >>> lt = LogicalUPath() + >>> lt.python_to_storage(UPath("s3://bucket/key")) + 's3://bucket/key' + >>> lt.storage_to_python("s3://bucket/key") + UPath('s3://bucket/key') + """ + + _arrow_ext_class = make_arrow_extension_type("orcapod.upath", pa.large_string()) + _arrow_ext: pa.ExtensionType | None = None + _polars_ext_class = make_polars_extension_type("orcapod.upath", pa.large_string()) + _polars_ext: pl.BaseExtension | None = None + + logical_type_name: str = "orcapod.upath" + python_type: type = UPath + + def get_arrow_extension_type(self) -> pa.ExtensionType: + """Return the Arrow extension type for ``upath.UPath``. + + Returns: + A cached ``pa.ExtensionType`` instance with extension name + ``"orcapod.upath"`` and storage type ``pa.large_string()``. + """ + if LogicalUPath._arrow_ext is None: + LogicalUPath._arrow_ext = LogicalUPath._arrow_ext_class() + return LogicalUPath._arrow_ext + + def get_polars_extension_type(self) -> pl.BaseExtension: + """Return the Polars extension type for ``upath.UPath``. + + Returns: + A cached ``pl.BaseExtension`` instance registered under + ``"orcapod.upath"``. + """ + if LogicalUPath._polars_ext is None: + LogicalUPath._polars_ext = LogicalUPath._polars_ext_class() + return LogicalUPath._polars_ext + + def python_to_storage(self, value: Any) -> str: + """Convert a ``upath.UPath`` to its string representation. + + Args: + value: A ``upath.UPath`` instance. + + Returns: + The string form of the path (e.g. ``"s3://bucket/key"``). + """ + return str(value) + + def storage_to_python(self, storage_value: Any) -> UPath: + """Reconstruct a ``upath.UPath`` from its string representation. + + Args: + storage_value: A string path as stored in Arrow. + + Returns: + A ``upath.UPath`` instance. + """ + return UPath(storage_value) +``` + +- [ ] **Step 4: Update `LogicalUUID` class — class attributes and docstrings** + +Replace the `LogicalUUID` class definition (lines 164–236) with: + +```python +class LogicalUUID: + """Logical type for ``uuid.UUID``. + + Stores UUIDs as Arrow binary (16 bytes) using the custom extension type + ``"orcapod.uuid"``. Both the Arrow extension name and ``logical_type_name`` + are ``"orcapod.uuid"``, consistent with ``LogicalPath`` and ``LogicalUPath``. + + The extension name ``"orcapod.uuid"`` is Orcapod-owned and stable, replacing + the previous ``"uuid.UUID"`` name that mirrored PyArrow's ``"arrow.uuid"`` + territory. Use ``orcapod.UUID`` (a top-level alias for ``uuid.UUID``) as the + preferred way to reference this type in user code. + + The storage type is ``pa.large_binary()`` (variable-length binary), using + big-endian byte order as returned by ``uuid.UUID.bytes``. ``large_binary`` + is used rather than ``pa.binary(16)`` (fixed-size) because Polars maps + fixed-size binary to variable-length on the round-trip, which would + conflict with the deserializer's storage type check. + + Example: + >>> import uuid + >>> lt = LogicalUUID() + >>> u = uuid.uuid4() + >>> lt.storage_to_python(lt.python_to_storage(u)) == u + True + """ + + _arrow_ext_class = make_arrow_extension_type("orcapod.uuid", pa.large_binary()) + _arrow_ext: pa.ExtensionType | None = None + _polars_ext_class = make_polars_extension_type("orcapod.uuid", pa.large_binary()) + _polars_ext: pl.BaseExtension | None = None + + logical_type_name: str = "orcapod.uuid" + python_type: type = _uuid_module.UUID + + def get_arrow_extension_type(self) -> pa.ExtensionType: + """Return the Arrow extension type for ``uuid.UUID``. + + Returns: + A cached ``pa.ExtensionType`` instance with extension name + ``"orcapod.uuid"`` and storage type ``pa.large_binary()``. + """ + if LogicalUUID._arrow_ext is None: + LogicalUUID._arrow_ext = LogicalUUID._arrow_ext_class() + return LogicalUUID._arrow_ext + + def get_polars_extension_type(self) -> pl.BaseExtension: + """Return the Polars extension type for ``uuid.UUID``. + + Returns: + A cached ``pl.BaseExtension`` instance registered under + ``"orcapod.uuid"``. + """ + if LogicalUUID._polars_ext is None: + LogicalUUID._polars_ext = LogicalUUID._polars_ext_class() + return LogicalUUID._polars_ext + + def python_to_storage(self, value: Any) -> bytes: + """Convert a ``uuid.UUID`` to its 16-byte binary representation. + + Args: + value: A ``uuid.UUID`` instance. + + Returns: + A 16-byte ``bytes`` object (big-endian byte order, as per + ``uuid.UUID.bytes``). + """ + return value.bytes + + def storage_to_python(self, storage_value: Any) -> _uuid_module.UUID: + """Reconstruct a ``uuid.UUID`` from its 16-byte binary representation. + + Args: + storage_value: A bytes-like object of length 16. + + Returns: + A ``uuid.UUID`` instance. + """ + return _uuid_module.UUID(bytes=bytes(storage_value)) +``` + +- [ ] **Step 5: Run the failing tests to confirm they are now green** + +```bash +uv run pytest tests/test_extension_types/test_builtin_logical_types.py -v --tb=short 2>&1 | tail -20 +``` + +Expected: all tests pass. + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/extension_types/builtin_logical_types.py \ + tests/test_extension_types/test_builtin_logical_types.py +git commit -m "feat(extension_types): rename built-in extension types to orcapod.* namespace + +LogicalPath: 'pathlib.Path' -> 'orcapod.path' +LogicalUPath: 'upath.UPath' -> 'orcapod.upath' +LogicalUUID: 'uuid.UUID' -> 'orcapod.uuid' + +Orcapod now owns the canonical extension identity for all three built-in +types, decoupling on-disk names from upstream library module paths." +``` + +--- + +## Task 3: Add tests for top-level `orcapod.Path`, `orcapod.UPath`, `orcapod.UUID` aliases + +TDD red step for the alias feature. These tests will fail until Task 4 adds the aliases. + +**Files:** +- Modify: `tests/test_extension_types/test_builtin_logical_types.py` + +- [ ] **Step 1: Append the alias test block at the end of the test file** + +Add the following to the end of `tests/test_extension_types/test_builtin_logical_types.py`: + +```python +# --------------------------------------------------------------------------- +# Top-level orcapod namespace alias tests +# --------------------------------------------------------------------------- + + +def test_orcapod_path_alias_is_pathlib_path(): + """orcapod.Path is the same object as pathlib.Path.""" + import pathlib + + import orcapod + + assert orcapod.Path is pathlib.Path + + +def test_orcapod_upath_alias_is_upath_upath(): + """orcapod.UPath is the same object as upath.UPath.""" + from upath import UPath + + import orcapod + + assert orcapod.UPath is UPath + + +def test_orcapod_uuid_alias_is_uuid_uuid(): + """orcapod.UUID is the same object as uuid.UUID.""" + import uuid + + import orcapod + + assert orcapod.UUID is uuid.UUID + + +def test_orcapod_path_alias_in_all(): + """orcapod.Path appears in orcapod.__all__.""" + import orcapod + + assert "Path" in orcapod.__all__ + + +def test_orcapod_upath_alias_in_all(): + """orcapod.UPath appears in orcapod.__all__.""" + import orcapod + + assert "UPath" in orcapod.__all__ + + +def test_orcapod_uuid_alias_in_all(): + """orcapod.UUID appears in orcapod.__all__.""" + import orcapod + + assert "UUID" in orcapod.__all__ +``` + +- [ ] **Step 2: Run the new tests to confirm they are red** + +```bash +uv run pytest tests/test_extension_types/test_builtin_logical_types.py -v -k "alias" --tb=short 2>&1 +``` + +Expected: 6 tests fail with `AttributeError: module 'orcapod' has no attribute 'Path'` (or similar). + +--- + +## Task 4: Add `Path`, `UPath`, `UUID` aliases to `src/orcapod/__init__.py` + +**Files:** +- Modify: `src/orcapod/__init__.py` + +- [ ] **Step 1: Add the alias imports and `__all__` entries** + +Replace the entire content of `src/orcapod/__init__.py` with: + +```python +from .config import ( + DEFAULT_CONFIG, + DisplayConfig, + HashingConfig, + OrcapodConfig, + load_config, +) +from .core.function_pod import ( + FunctionPod, + function_pod, +) +from .core.nodes.source_node import SourceNode +from .pipeline import Pipeline, PipelineJob +from .semantic_types.dataclass_encoding import register_dataclass + +# Subpackage re-exports for clean public API +from . import databases # noqa: F401 +from . import nodes # noqa: F401 +from . import operators # noqa: F401 +from . import sources # noqa: F401 +from . import streams # noqa: F401 +from . import types # noqa: F401 + +# Stable type aliases — preferred over importing directly from pathlib/upath/uuid. +# +# These aliases are the recommended way to reference these types in orcapod user code. +# Even if an upstream library is renamed or restructured, these symbols remain stable +# at ``orcapod.Path``, ``orcapod.UPath``, and ``orcapod.UUID``. Their Arrow extension +# types are registered under the ``orcapod.*`` namespace (``"orcapod.path"``, +# ``"orcapod.upath"``, ``"orcapod.uuid"``), so on-disk identity is also decoupled +# from upstream module paths. +from pathlib import Path +from upath import UPath +from uuid import UUID + +__all__ = [ + "DEFAULT_CONFIG", + "DisplayConfig", + "HashingConfig", + "OrcapodConfig", + "load_config", + "FunctionPod", + "function_pod", + "Pipeline", + "PipelineJob", + "SourceNode", + "register_dataclass", + "databases", + "nodes", + "operators", + "sources", + "streams", + "types", + # Stable type aliases + "Path", + "UPath", + "UUID", +] +``` + +- [ ] **Step 2: Run the alias tests to confirm they are now green** + +```bash +uv run pytest tests/test_extension_types/test_builtin_logical_types.py -v -k "alias" --tb=short 2>&1 +``` + +Expected: all 6 alias tests pass. + +- [ ] **Step 3: Run the full builtin logical types test suite** + +```bash +uv run pytest tests/test_extension_types/test_builtin_logical_types.py -v --tb=short 2>&1 | tail -20 +``` + +Expected: all tests pass (the full suite, not just alias tests). + +- [ ] **Step 4: Run the broader extension_types test suite to check for regressions** + +```bash +uv run pytest tests/test_extension_types/ -v --tb=short 2>&1 | tail -30 +``` + +Expected: all tests pass. + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/__init__.py \ + tests/test_extension_types/test_builtin_logical_types.py +git commit -m "feat(orcapod): expose Path, UPath, UUID as stable top-level aliases + +Adds orcapod.Path, orcapod.UPath, orcapod.UUID as re-exports of +pathlib.Path, upath.UPath, and uuid.UUID respectively. These are the +preferred symbols for user code — stable even if upstream libraries +rename their types or module paths." +``` + +--- + +## Task 5: Final verification — full test suite + +- [ ] **Step 1: Run the complete test suite** + +```bash +uv run pytest tests/ -x --tb=short 2>&1 | tail -40 +``` + +Expected: all tests pass (no regressions in any other test module). + +- [ ] **Step 2: Verify the branch is clean and ready for PR** + +```bash +git status +git log --oneline origin/extension-type-system..HEAD +``` + +Expected: 2 commits ahead of `extension-type-system`, working tree clean. + +--- + +## Self-Review Checklist + +**Spec coverage:** + +| Requirement | Task that covers it | +|-------------|-------------------| +| `LogicalPath` registers under `"orcapod.path"` | Task 2 Step 2 | +| `LogicalUPath` registers under `"orcapod.upath"` | Task 2 Step 3 | +| `LogicalUUID` registers under `"orcapod.uuid"` | Task 2 Step 4 | +| `orcapod.uuid` no longer conflicts with `arrow.uuid` | Task 2 Step 4 (new name `"orcapod.uuid"` vs PyArrow's `"arrow.uuid"`) | +| `orcapod.Path` alias exposed at top-level | Task 4 Step 1 | +| `orcapod.UPath` alias exposed at top-level | Task 4 Step 1 | +| `orcapod.UUID` alias exposed at top-level | Task 4 Step 1 | +| Aliases documented as preferred + stability rationale | Task 4 Step 1 (comment block) | +| Stability rationale in module docstring | Task 2 Step 1 | +| Existing round-trip behavior continues to work | Task 5 Step 1 | +| Unit tests updated to assert `orcapod.*` names | Task 1 + Task 3 | + +**No placeholders:** All steps contain exact code. No "TBD" or "similar to above" references. + +**Type consistency:** `logical_type_name` constants and `extension_name` strings are consistent across Tasks 1, 2, 3, and 4 — `"orcapod.path"`, `"orcapod.upath"`, `"orcapod.uuid"` throughout. diff --git a/superpowers/plans/2026-06-15-plt-1672-write-side-logical-type-factory.md b/superpowers/plans/2026-06-15-plt-1672-write-side-logical-type-factory.md new file mode 100644 index 00000000..931e78c7 --- /dev/null +++ b/superpowers/plans/2026-06-15-plt-1672-write-side-logical-type-factory.md @@ -0,0 +1,1482 @@ +# PLT-1672: Write-Side Logical Type Factory Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use sensei:subagent-driven-development (recommended) or sensei:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add Python-class-keyed write-side factory dispatch to `LogicalTypeRegistry` and wire it into `UniversalTypeConverter` and `_FunctionPodBase` so that unregistered Python types are auto-registered via a factory at function pod declaration time. + +**Architecture:** Two new factory dispatch axes (category-keyed for reads, python-class-keyed for writes) are unified in `LogicalTypeRegistry`'s `ensure_logical_type_for_python_class` with a shared MRO resolution algorithm. A recursive `_extract_leaf_classes` unwrapper in a new `type_utils.py` feeds the write-side trigger in `_FunctionPodBase.__init__`. `UniversalTypeConverter` is extended with a one-line priority check so registered extension types take precedence over the old shape-based system at encoding time. + +**Tech Stack:** Python 3.12+, PyArrow, Polars, `typing.get_origin`/`get_args` for generic annotation unwrapping. All tests via `uv run pytest`. + +--- + +## File Map + +| File | Action | Responsibility | +|---|---|---| +| `src/orcapod/extension_types/protocols.py` | Modify | Rename `create_logical_type` → `reconstruct_from_arrow`; add `create_for_python_type` | +| `src/orcapod/extension_types/registry.py` | Modify | Rename `_factories` → `_category_factories`; add `_python_class_factories`; extend `register_logical_type_factory`; add `ensure_logical_type_for_python_class` | +| `src/orcapod/extension_types/type_utils.py` | Create | `_extract_leaf_classes(annotation)` — recursive generic annotation unwrapper | +| `src/orcapod/extension_types/__init__.py` | Modify | Export `_extract_leaf_classes` | +| `src/orcapod/semantic_types/universal_converter.py` | Modify | Add `_logical_type_registry` attribute; insert priority check before `semantic_registry` in `_convert_python_to_arrow` | +| `src/orcapod/contexts/core.py` | Modify | Add `DataContext.__post_init__` to wire `logical_type_registry` into `type_converter` | +| `src/orcapod/core/function_pod.py` | Modify | Add `_ARROW_NATIVE_TYPES`, `_trigger_write_side_registration`; call from `_FunctionPodBase.__init__` | +| `tests/test_extension_types/test_protocols.py` | Modify | Update `_StubFactory` stub; add `create_for_python_type` conformance test | +| `tests/test_extension_types/test_registry.py` | Modify | Update all `register_logical_type_factory` call sites; add `ensure_logical_type_for_python_class` tests | +| `tests/test_extension_types/test_type_utils.py` | Create | Tests for `_extract_leaf_classes` | +| `tests/test_semantic_types/test_universal_converter.py` | Modify | Add `_logical_type_registry` priority check tests | +| `tests/test_core/function_pod/test_write_side_registration.py` | Create | End-to-end pod-declaration trigger tests | + +--- + +## Task 1: Rename `create_logical_type` → `reconstruct_from_arrow` in `LogicalTypeFactoryProtocol` + +**Files:** +- Modify: `src/orcapod/extension_types/protocols.py` +- Modify: `src/orcapod/extension_types/registry.py` (call site) +- Modify: `tests/test_extension_types/test_protocols.py` +- Modify: `tests/test_extension_types/test_registry.py` (all uses of `create_logical_type`) + +- [ ] **Step 1: Update `_StubFactory` in test_protocols.py to use the new name** + +Edit `tests/test_extension_types/test_protocols.py`. Replace the `_StubFactory` class body: + +```python +class _StubFactory: + """Minimal conforming implementation of LogicalTypeFactoryProtocol for use in tests.""" + + def reconstruct_from_arrow(self, arrow_extension_name, storage_type, metadata): + return _StubLogicalType() +``` + +Also update `test_logical_type_factory_create_returns_logical_type` to call `reconstruct_from_arrow`: + +```python +def test_logical_type_factory_create_returns_logical_type(): + """A conforming factory returns a LogicalTypeProtocol from reconstruct_from_arrow.""" + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol, LogicalTypeProtocol + factory: LogicalTypeFactoryProtocol = _StubFactory() + result = factory.reconstruct_from_arrow( + "test.ext", pa.large_utf8(), {"category": "Test"} + ) + assert isinstance(result, LogicalTypeProtocol) +``` + +- [ ] **Step 2: Run the conformance test to confirm it fails (Protocol still expects `create_logical_type`)** + +```bash +uv run pytest tests/test_extension_types/test_protocols.py::test_logical_type_factory_conforming_class_satisfies_protocol -v +``` + +Expected: FAIL — `_StubFactory` is no longer recognized as `LogicalTypeFactoryProtocol` because it lacks `create_logical_type`. + +- [ ] **Step 3: Rename the method in `LogicalTypeFactoryProtocol`** + +In `src/orcapod/extension_types/protocols.py`, rename `create_logical_type` to `reconstruct_from_arrow` in the `LogicalTypeFactoryProtocol` class. The full updated method: + +```python + def reconstruct_from_arrow( + self, + arrow_extension_name: str, + storage_type: pa.DataType, + metadata: dict[str, Any], + ) -> LogicalTypeProtocol: + """Reconstruct a LogicalType from Arrow schema metadata (read path). + + Called by the registry when a schema walk encounters an extension type + whose metadata ``"category"`` value matches this factory's registered + category. All Arrow schema information is already known. + + Args: + arrow_extension_name: The Arrow extension type name from the schema. + storage_type: The underlying Arrow storage type. + metadata: Full parsed metadata JSON dict. Always contains ``"category"``. + + Returns: + A fully constructed ``LogicalTypeProtocol`` ready for registration. + + Raises: + ValueError: If this factory cannot reconstruct a type for the given name. + """ + ... +``` + +- [ ] **Step 4: Update the call site in `registry.py`** + +In `src/orcapod/extension_types/registry.py`, find `ensure_extension_type`. Replace: + +```python + logical_type = factory.create_logical_type( + arrow_extension_name, storage_type, metadata_dict + ) +``` + +with: + +```python + logical_type = factory.reconstruct_from_arrow( + arrow_extension_name, storage_type, metadata_dict + ) +``` + +- [ ] **Step 5: Update `_make_stub_factory` in `test_registry.py`** + +In `tests/test_extension_types/test_registry.py`, find `_make_stub_factory`. Replace `create_logical_type` with `reconstruct_from_arrow` in the inner `_Factory` class: + +```python + def reconstruct_from_arrow(self, arrow_extension_name, storage_type, metadata): + self.calls.append((arrow_extension_name, storage_type, metadata)) + if _return_lt is not None: + return _return_lt + return _make_stub(arrow_name=arrow_extension_name, storage=storage_type) +``` + +- [ ] **Step 6: Run the full test suite for extension_types to confirm all pass** + +```bash +uv run pytest tests/test_extension_types/ -v +``` + +Expected: All previously passing tests still pass. + +- [ ] **Step 7: Commit** + +```bash +git add src/orcapod/extension_types/protocols.py \ + src/orcapod/extension_types/registry.py \ + tests/test_extension_types/test_protocols.py \ + tests/test_extension_types/test_registry.py +git commit -m "refactor(extension_types): rename create_logical_type to reconstruct_from_arrow in LogicalTypeFactoryProtocol" +``` + +--- + +## Task 2: Add `create_for_python_type` to `LogicalTypeFactoryProtocol` + +**Files:** +- Modify: `src/orcapod/extension_types/protocols.py` +- Modify: `tests/test_extension_types/test_protocols.py` + +- [ ] **Step 1: Write the failing conformance test** + +Add to `tests/test_extension_types/test_protocols.py`. First update `_StubFactory` to add the new method: + +```python +class _StubFactory: + """Minimal conforming implementation of LogicalTypeFactoryProtocol for use in tests.""" + + def reconstruct_from_arrow(self, arrow_extension_name, storage_type, metadata): + return _StubLogicalType() + + def create_for_python_type(self, python_type): + return _StubLogicalType() +``` + +Then add the test: + +```python +def test_factory_create_for_python_type_conformance(): + """A conforming factory implements create_for_python_type and returns LogicalTypeProtocol.""" + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol, LogicalTypeProtocol + factory: LogicalTypeFactoryProtocol = _StubFactory() + assert isinstance(factory, LogicalTypeFactoryProtocol) + result = factory.create_for_python_type(str) + assert isinstance(result, LogicalTypeProtocol) +``` + +- [ ] **Step 2: Run to confirm it fails (Protocol does not yet require `create_for_python_type`)** + +```bash +uv run pytest tests/test_extension_types/test_protocols.py::test_factory_create_for_python_type_conformance -v +``` + +Expected: FAIL — `LogicalTypeFactoryProtocol` does not yet define `create_for_python_type`, so the `isinstance` check passes but calling an undefined method would fail; or the test passes vacuously — either way, add the method to the Protocol so it becomes structurally required. + +- [ ] **Step 3: Add `create_for_python_type` to `LogicalTypeFactoryProtocol` in `protocols.py`** + +```python + def create_for_python_type( + self, + python_type: type, + ) -> LogicalTypeProtocol: + """Synthesize a LogicalType for the given Python class (write path). + + Called by the registry when pod declaration encounters an unregistered + class whose MRO intersects this factory's registered ``python_bases``. + The factory derives all Arrow metadata (extension name, storage type, + metadata dict) from the Python class itself. + + The returned LogicalType must round-trip: the extension name and metadata + it produces must route back to this same factory's ``reconstruct_from_arrow`` + on a subsequent read. + + Args: + python_type: The concrete Python class to synthesize a LogicalType for. + + Returns: + A fully constructed ``LogicalTypeProtocol`` ready for registration. + + Raises: + ValueError: If this factory cannot construct a type for the given class. + """ + ... +``` + +- [ ] **Step 4: Run to confirm the test passes** + +```bash +uv run pytest tests/test_extension_types/test_protocols.py -v +``` + +Expected: All pass. + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/extension_types/protocols.py \ + tests/test_extension_types/test_protocols.py +git commit -m "feat(extension_types): add create_for_python_type to LogicalTypeFactoryProtocol" +``` + +--- + +## Task 3: Extend `LogicalTypeRegistry` internals and `register_logical_type_factory` + +**Files:** +- Modify: `src/orcapod/extension_types/registry.py` +- Modify: `tests/test_extension_types/test_registry.py` + +This task renames `_factories` → `_category_factories`, adds `_python_class_factories`, changes the `register_logical_type_factory` signature, and updates all existing call sites. + +- [ ] **Step 1: Write the new `register_logical_type_factory` tests** + +Add to `tests/test_extension_types/test_registry.py`: + +```python +# ── register_logical_type_factory extended API ─────────────────────────────── + +def test_register_logical_type_factory_keyword_category(): + """register_logical_type_factory accepts factory as first arg, category as keyword.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, category="TestCat") # no error + + +def test_register_logical_type_factory_keyword_python_bases(): + """register_logical_type_factory accepts python_bases as keyword.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, python_bases=[str]) # no error + + +def test_register_logical_type_factory_both_axes(): + """register_logical_type_factory accepts both category and python_bases.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, category="Cat", python_bases=[str, int]) + + +def test_register_logical_type_factory_no_axes_raises(): + """register_logical_type_factory raises ValueError when called with no axes.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + with pytest.raises(ValueError, match="At least one of"): + registry.register_logical_type_factory(factory) + + +def test_register_logical_type_factory_python_base_duplicate_different_factory_raises(): + """Registering a different factory for the same python_base raises ValueError.""" + registry = LogicalTypeRegistry() + f1 = _make_stub_factory() + f2 = _make_stub_factory() + registry.register_logical_type_factory(f1, python_bases=[str]) + with pytest.raises(ValueError): + registry.register_logical_type_factory(f2, python_bases=[str]) + + +def test_register_logical_type_factory_python_base_same_factory_idempotent(): + """Registering the same factory twice for the same python_base is a no-op.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, python_bases=[str]) + registry.register_logical_type_factory(factory, python_bases=[str]) # no error +``` + +- [ ] **Step 2: Run to confirm new tests fail** + +```bash +uv run pytest tests/test_extension_types/test_registry.py -k "keyword_category or keyword_python_bases or both_axes or no_axes or python_base" -v +``` + +Expected: FAIL — `register_logical_type_factory` currently takes `(category, factory)` positionally. + +- [ ] **Step 3: Update existing `register_logical_type_factory` call sites in test_registry.py** + +Search for all existing calls to `register_logical_type_factory` that use the old positional signature `(category, factory)` and update them to the new keyword form `(factory, category=...)`. + +Run this to find them: +```bash +grep -n "register_logical_type_factory" tests/test_extension_types/test_registry.py +``` + +For each occurrence of the form `registry.register_logical_type_factory("SomeCategory", factory)`, replace with `registry.register_logical_type_factory(factory, category="SomeCategory")`. + +- [ ] **Step 4: Update `_make_stub_factory` to also add `create_for_python_type`** + +In `test_registry.py`, update `_make_stub_factory` so the inner `_Factory` class also implements `create_for_python_type` (required by the updated `LogicalTypeFactoryProtocol`): + +```python +def _make_stub_factory(return_lt: LogicalTypeProtocol | None = None) -> LogicalTypeFactoryProtocol: + """Factory for minimal LogicalTypeFactoryProtocol conforming stubs.""" + _return_lt = return_lt + + class _Factory: + def __init__(self): + self.calls: list[tuple] = [] + self.python_type_calls: list[type] = [] + + def reconstruct_from_arrow(self, arrow_extension_name, storage_type, metadata): + self.calls.append((arrow_extension_name, storage_type, metadata)) + if _return_lt is not None: + return _return_lt + return _make_stub(arrow_name=arrow_extension_name, storage=storage_type) + + def create_for_python_type(self, python_type): + self.python_type_calls.append(python_type) + if _return_lt is not None: + return _return_lt + return _make_stub(py_type=python_type) + + return _Factory() +``` + +- [ ] **Step 5: Implement the changes in `registry.py`** + +In `LogicalTypeRegistry.__init__`, rename `_factories` → `_category_factories` and add `_python_class_factories`: + +```python + def __init__(self, logical_types: list[LogicalTypeProtocol] | None = None) -> None: + self._by_logical_name: dict[str, LogicalTypeProtocol] = {} + self._by_arrow_name: dict[str, LogicalTypeProtocol] = {} + self._by_python_type: dict[type, LogicalTypeProtocol] = {} + self._category_factories: dict[str, LogicalTypeFactoryProtocol] = {} + self._python_class_factories: dict[type, LogicalTypeFactoryProtocol] = {} + for lt in (logical_types or []): + self.register_logical_type(lt) +``` + +Replace `register_logical_type_factory` with the new signature. Find the existing method and replace it entirely: + +```python + def register_logical_type_factory( + self, + factory: LogicalTypeFactoryProtocol, + *, + category: str | None = None, + python_bases: Iterable[type] = (), + ) -> None: + """Register a factory on one or both dispatch axes. + + Args: + factory: The factory to register. + category: If given, registers factory as the read-side handler for Arrow + extension types whose metadata contains this category string. Raises + ``ValueError`` if a different factory is already registered for this + category. + python_bases: Zero or more Python base classes. Registers factory as the + write-side handler for each. Raises ``ValueError`` if a different + factory is already registered for a given base. + + Raises: + ValueError: If neither ``category`` nor ``python_bases`` is provided. + ValueError: If a different factory is already registered for a given key. + """ + if category is None and not python_bases: + raise ValueError( + "At least one of 'category' or 'python_bases' must be provided." + ) + if category is not None: + existing = self._category_factories.get(category) + if existing is not None and existing is not factory: + raise ValueError( + f"Cannot register factory for category {category!r}: " + f"a different factory is already registered for this category." + ) + if existing is not factory: + self._category_factories[category] = factory + logger.debug( + "registered LogicalTypeFactory for category %r: %r", category, factory + ) + for base in python_bases: + existing = self._python_class_factories.get(base) + if existing is not None and existing is not factory: + raise ValueError( + f"Cannot register factory for python base {base!r}: " + f"a different factory is already registered for this base." + ) + if existing is not factory: + self._python_class_factories[base] = factory + logger.debug( + "registered LogicalTypeFactory for python base %r: %r", base, factory + ) +``` + +Also update the `ensure_extension_type` method: replace any reference to `self._factories` with `self._category_factories`. + +- [ ] **Step 6: Run all registry tests to confirm they pass** + +```bash +uv run pytest tests/test_extension_types/test_registry.py -v +``` + +Expected: All pass. (Any test using the old positional signature was updated in Step 3.) + +- [ ] **Step 7: Run the full extension_types test suite** + +```bash +uv run pytest tests/test_extension_types/ -v +``` + +Expected: All pass. + +- [ ] **Step 8: Commit** + +```bash +git add src/orcapod/extension_types/registry.py \ + tests/test_extension_types/test_registry.py +git commit -m "feat(extension_types): add python_class_factories axis to LogicalTypeRegistry; extend register_logical_type_factory" +``` + +--- + +## Task 4: Add `ensure_logical_type_for_python_class` to `LogicalTypeRegistry` + +**Files:** +- Modify: `src/orcapod/extension_types/registry.py` +- Modify: `tests/test_extension_types/test_registry.py` + +- [ ] **Step 1: Write all failing tests for `ensure_logical_type_for_python_class`** + +Add this block to `tests/test_extension_types/test_registry.py`: + +```python +# ── ensure_logical_type_for_python_class tests ─────────────────────────────── + +class _A: + pass + + +class _B(_A): + pass + + +class _C(_B): + pass + + +def test_ensure_for_python_class_concrete_exact_match(): + """Returns the concrete LogicalType when exact Python type is registered.""" + registry = LogicalTypeRegistry() + lt = _make_stub(py_type=_A) + registry.register_logical_type(lt) + result = registry.ensure_logical_type_for_python_class(_A) + assert result is lt + + +def test_ensure_for_python_class_concrete_mro_match(): + """Returns concrete LogicalType registered for a parent class via MRO walk.""" + registry = LogicalTypeRegistry() + lt = _make_stub(py_type=_A) + registry.register_logical_type(lt) + result = registry.ensure_logical_type_for_python_class(_C) + assert result is lt + + +def test_ensure_for_python_class_factory_synthesis(): + """Calls factory.create_for_python_type and registers the result.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, python_bases=[_A]) + result = registry.ensure_logical_type_for_python_class(_C) + assert len(factory.python_type_calls) == 1 + assert factory.python_type_calls[0] is _C + # Synthesized type is now registered — second call hits cache + cached = registry.ensure_logical_type_for_python_class(_C) + assert cached is result + assert len(factory.python_type_calls) == 1 # factory NOT called again + + +def test_ensure_for_python_class_concrete_beats_factory_same_mro_level(): + """When concrete type and factory are registered for the same class, concrete wins.""" + registry = LogicalTypeRegistry() + lt = _make_stub(py_type=_A) + registry.register_logical_type(lt) + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, python_bases=[_A]) + result = registry.ensure_logical_type_for_python_class(_A) + assert result is lt + assert len(factory.python_type_calls) == 0 # factory never called + + +def test_ensure_for_python_class_factory_more_specific_than_concrete(): + """Factory registered for a subclass beats concrete registered for a parent.""" + registry = LogicalTypeRegistry() + lt_a = _make_stub(py_type=_A) + registry.register_logical_type(lt_a) # concrete for _A + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, python_bases=[_B]) # factory for _B + # Query _C: factory at _B (MRO index 1) beats concrete at _A (MRO index 2) + result = registry.ensure_logical_type_for_python_class(_C) + assert len(factory.python_type_calls) == 1 + assert factory.python_type_calls[0] is _C + + +def test_ensure_for_python_class_concrete_more_specific_than_factory(): + """Concrete registered for a subclass beats factory registered for a parent.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, python_bases=[_A]) # factory for _A + lt_b = _make_stub(py_type=_B) + registry.register_logical_type(lt_b) # concrete for _B + # Query _C: concrete at _B (MRO index 1) beats factory at _A (MRO index 2) + result = registry.ensure_logical_type_for_python_class(_C) + assert result is lt_b + assert len(factory.python_type_calls) == 0 + + +def test_ensure_for_python_class_abc_subclasshook(): + """issubclass fallback scan catches ABCs with __subclasshook__.""" + from abc import ABCMeta + + class _StructuralABC(metaclass=ABCMeta): + @classmethod + def __subclasshook__(cls, C): + return hasattr(C, "_MARKER") + + class _MarkedClass: + _MARKER = True + + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, python_bases=[_StructuralABC]) + result = registry.ensure_logical_type_for_python_class(_MarkedClass) + assert len(factory.python_type_calls) == 1 + assert factory.python_type_calls[0] is _MarkedClass + + +def test_ensure_for_python_class_no_match_raises_type_error(): + """TypeError raised when no LogicalType and no factory match the type.""" + registry = LogicalTypeRegistry() + + with pytest.raises(TypeError, match="No LogicalType or LogicalTypeFactory"): + registry.ensure_logical_type_for_python_class(_C) +``` + +- [ ] **Step 2: Run to confirm all fail** + +```bash +uv run pytest tests/test_extension_types/test_registry.py -k "ensure_for_python_class" -v +``` + +Expected: All FAIL with `AttributeError: 'LogicalTypeRegistry' has no attribute 'ensure_logical_type_for_python_class'`. + +- [ ] **Step 3: Implement `ensure_logical_type_for_python_class` in `registry.py`** + +Add the method to `LogicalTypeRegistry` after `ensure_extension_type`: + +```python + def ensure_logical_type_for_python_class( + self, + python_type: type, + ) -> LogicalTypeProtocol: + """Ensure a LogicalType exists for python_type, synthesizing via factory if needed. + + Resolution algorithm: + 1. Walk ``python_type.__mro__``. Track the first (most-specific) hit in + ``_by_python_type`` (concrete) and ``_python_class_factories`` (factory) + separately, recording the MRO index of each. + 2. After the MRO walk, if no factory was found, do a fallback ``issubclass`` + scan over ``_python_class_factories`` keys to catch ABCs with + ``__subclasshook__``. Assign these the least-specific MRO index + (len of __mro__) so they lose to any direct MRO match. + 3. Resolution rule: if both concrete and factory found, compare MRO indices — + lower index wins. Ties (same class) → concrete wins. + 4. If factory wins (or only factory found): call + ``factory.create_for_python_type(python_type)``, register the result, + and return it. The registration caches it in ``_by_python_type[python_type]``. + 5. If nothing found: raise ``TypeError``. + + Args: + python_type: The Python class to resolve. + + Returns: + The registered or newly synthesized ``LogicalTypeProtocol``. + + Raises: + TypeError: If no ``LogicalType`` and no factory is found. + """ + best_concrete_idx: int | None = None + best_concrete: LogicalTypeProtocol | None = None + best_factory_idx: int | None = None + best_factory: LogicalTypeFactoryProtocol | None = None + + # Step 1: Walk MRO + for i, base in enumerate(python_type.__mro__): + if best_concrete is None and base in self._by_python_type: + best_concrete_idx = i + best_concrete = self._by_python_type[base] + if best_factory is None and base in self._python_class_factories: + best_factory_idx = i + best_factory = self._python_class_factories[base] + if best_concrete is not None and best_factory is not None: + break + + # Step 2: issubclass fallback scan for ABCs with __subclasshook__ + if best_factory is None: + for base_class, factory in self._python_class_factories.items(): + try: + if issubclass(python_type, base_class): + best_factory = factory + # ABC match — less specific than any direct MRO hit + best_factory_idx = len(python_type.__mro__) + break + except TypeError: + continue + + # Step 3: Resolution + if best_concrete is None and best_factory is None: + raise TypeError( + f"No LogicalType or LogicalTypeFactory is registered for type " + f"{python_type!r}.\n" + f"To handle this type, register a factory for its base class:\n" + f" registry.register_logical_type_factory(\n" + f" factory, python_bases=[]\n" + f" )\n" + f"Or register a concrete LogicalType directly:\n" + f" registry.register_logical_type(my_logical_type)" + ) + + if best_factory is None: + # Only concrete found + assert best_concrete is not None + return best_concrete + + if best_concrete is None: + # Only factory found — synthesize + assert best_factory is not None + lt = best_factory.create_for_python_type(python_type) + self.register_logical_type(lt) + logger.debug( + "ensure_logical_type_for_python_class: synthesized %r for %r", + lt.logical_type_name, + python_type, + ) + return lt + + # Both found — compare specificity (lower MRO index = more specific) + assert best_concrete_idx is not None + assert best_factory_idx is not None + if best_concrete_idx <= best_factory_idx: + # Concrete is same level (ties → concrete wins) or more specific + return best_concrete + else: + # Factory is more specific — synthesize + lt = best_factory.create_for_python_type(python_type) + self.register_logical_type(lt) + logger.debug( + "ensure_logical_type_for_python_class: synthesized %r for %r via more-specific factory", + lt.logical_type_name, + python_type, + ) + return lt +``` + +- [ ] **Step 4: Run the new tests** + +```bash +uv run pytest tests/test_extension_types/test_registry.py -k "ensure_for_python_class" -v +``` + +Expected: All pass. + +- [ ] **Step 5: Run the full extension_types suite** + +```bash +uv run pytest tests/test_extension_types/ -v +``` + +Expected: All pass. + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/extension_types/registry.py \ + tests/test_extension_types/test_registry.py +git commit -m "feat(extension_types): add ensure_logical_type_for_python_class with unified MRO resolution" +``` + +--- + +## Task 5: Add `_extract_leaf_classes` in `type_utils.py` + +**Files:** +- Create: `src/orcapod/extension_types/type_utils.py` +- Modify: `src/orcapod/extension_types/__init__.py` +- Create: `tests/test_extension_types/test_type_utils.py` + +- [ ] **Step 1: Write the failing tests** + +Create `tests/test_extension_types/test_type_utils.py`: + +```python +"""Tests for extension_types.type_utils helpers.""" + +from __future__ import annotations + +from typing import Optional, Union + +from orcapod.extension_types.type_utils import _extract_leaf_classes + + +class _A: + pass + + +class _B: + pass + + +def test_plain_class(): + assert list(_extract_leaf_classes(int)) == [int] + + +def test_plain_custom_class(): + assert list(_extract_leaf_classes(_A)) == [_A] + + +def test_list_of_class(): + assert list(_extract_leaf_classes(list[int])) == [int] + + +def test_dict_of_classes(): + result = set(_extract_leaf_classes(dict[str, int])) + assert result == {str, int} + + +def test_optional_unwraps_none(): + """Optional[X] yields X but not NoneType.""" + result = list(_extract_leaf_classes(Optional[int])) + assert result == [int] + + +def test_union_yields_all_non_none(): + result = set(_extract_leaf_classes(Union[int, str])) + assert result == {int, str} + + +def test_union_with_none_excludes_none(): + result = set(_extract_leaf_classes(Union[int, None])) + assert type(None) not in result + assert int in result + + +def test_nested_list_of_dict(): + """list[dict[_A, list[_B]]] yields _A and _B.""" + result = set(_extract_leaf_classes(list[dict[_A, list[_B]]])) + assert result == {_A, _B} + + +def test_deeply_nested(): + """list[dict[str, list[dict[int, _A]]]] yields str, int, _A.""" + result = set(_extract_leaf_classes(list[dict[str, list[dict[int, _A]]]])) + assert result == {str, int, _A} + + +def test_non_generic_non_type_is_skipped(): + """Annotations that are not types and not generic aliases yield nothing.""" + # e.g. a string annotation that failed resolution — should not crash + result = list(_extract_leaf_classes("unresolved_string")) + assert result == [] + + +def test_none_type_plain(): + """type(None) itself yields type(None) as a leaf (not filtered at this level).""" + result = list(_extract_leaf_classes(type(None))) + assert result == [type(None)] +``` + +- [ ] **Step 2: Run to confirm all fail** + +```bash +uv run pytest tests/test_extension_types/test_type_utils.py -v +``` + +Expected: All FAIL with `ModuleNotFoundError` or `ImportError`. + +- [ ] **Step 3: Create `src/orcapod/extension_types/type_utils.py`** + +```python +"""Utility helpers for Python type annotation inspection. + +Used by the write-side registration trigger to extract leaf Python classes from +complex generic annotations like ``list[dict[A, list[B]]]``. +""" + +from __future__ import annotations + +import typing +from typing import Any, Iterator + + +def _extract_leaf_classes(annotation: Any) -> Iterator[type]: + """Recursively yield all concrete leaf Python classes from a type annotation. + + Unwraps generic aliases (``list[T]``, ``dict[K, V]``, ``Optional[T]``, + ``Union[A, B]``, etc.) using ``typing.get_origin`` and ``typing.get_args`` + and yields every non-generic leaf found. ``NoneType`` (from ``Optional`` + and ``Union[..., None]``) is yielded as-is — callers that want to skip it + should filter on ``type(None)``. + + Non-type, non-generic values (e.g. unresolved string annotations) are + silently skipped. + + Args: + annotation: A Python type or generic alias to inspect. + + Yields: + Concrete Python ``type`` objects found at leaf positions. + + Examples: + >>> list(_extract_leaf_classes(list[int])) + [] + >>> set(_extract_leaf_classes(dict[str, list[MyClass]])) + {, } + """ + origin = typing.get_origin(annotation) + + if origin is None: + # Not a generic alias. Yield only if it is a plain type. + if isinstance(annotation, type): + yield annotation + return + + # Generic alias — recurse into every type argument. + for arg in typing.get_args(annotation): + yield from _extract_leaf_classes(arg) +``` + +- [ ] **Step 4: Export from `__init__.py`** + +In `src/orcapod/extension_types/__init__.py`, add to the imports and `__all__`: + +```python +from .type_utils import _extract_leaf_classes +``` + +And add `"_extract_leaf_classes"` to `__all__`. + +- [ ] **Step 5: Run to confirm all tests pass** + +```bash +uv run pytest tests/test_extension_types/test_type_utils.py -v +``` + +Expected: All pass. + +- [ ] **Step 6: Run the full extension_types suite** + +```bash +uv run pytest tests/test_extension_types/ -v +``` + +Expected: All pass. + +- [ ] **Step 7: Commit** + +```bash +git add src/orcapod/extension_types/type_utils.py \ + src/orcapod/extension_types/__init__.py \ + tests/test_extension_types/test_type_utils.py +git commit -m "feat(extension_types): add _extract_leaf_classes for recursive generic annotation unwrapping" +``` + +--- + +## Task 6: Wire `LogicalTypeRegistry` into `UniversalTypeConverter` and `DataContext` + +**Files:** +- Modify: `src/orcapod/semantic_types/universal_converter.py` +- Modify: `src/orcapod/contexts/core.py` +- Modify: `tests/test_semantic_types/test_universal_converter.py` + +- [ ] **Step 1: Write the failing tests** + +Add to `tests/test_semantic_types/test_universal_converter.py`: + +```python +# ── LogicalTypeRegistry priority tests ─────────────────────────────────────── + +import pyarrow as pa +import polars as pl + +from orcapod.extension_types.registry import ( + LogicalTypeRegistry, + make_arrow_extension_type, + make_polars_extension_type, +) +from orcapod.semantic_types.universal_converter import UniversalTypeConverter + + +def _make_logical_type_stub(py_type: type, arrow_name: str) -> object: + """Return a minimal LogicalTypeProtocol conforming stub.""" + _ArrowExtClass = make_arrow_extension_type(arrow_name, pa.large_string()) + _pl_dtype = pl.String + + class _PolarsExt(pl.BaseExtension): + def __init__(self): + super().__init__(arrow_name, _pl_dtype, None) + @classmethod + def ext_from_params(cls, ext_name, storage_dtype, metadata_str): + return cls() + + class _Stub: + logical_type_name = arrow_name + python_type = py_type + + def get_arrow_extension_type(self): + return _ArrowExtClass() + + def get_polars_extension_type(self): + return _PolarsExt() + + def python_to_storage(self, value): + return str(value) + + def storage_to_python(self, storage_value): + return storage_value + + return _Stub() + + +class _MyCustomClass: + pass + + +def test_converter_uses_logical_type_registry_for_registered_type(): + """When a LogicalType is registered, converter returns its Arrow extension type.""" + import uuid as _uuid + arrow_name = f"test.MyCustomClass.{_uuid.uuid4().hex[:8]}" + lt = _make_logical_type_stub(_MyCustomClass, arrow_name) + + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + converter = UniversalTypeConverter() + converter._logical_type_registry = registry + + result = converter.python_type_to_arrow_type(_MyCustomClass) + expected_ext = lt.get_arrow_extension_type() + assert result == expected_ext + + +def test_converter_falls_through_for_unregistered_type(): + """If type not in LogicalTypeRegistry, converter falls through to old system (int → int64).""" + registry = LogicalTypeRegistry() + converter = UniversalTypeConverter() + converter._logical_type_registry = registry + + result = converter.python_type_to_arrow_type(int) + assert result == pa.int64() + + +def test_converter_without_registry_unchanged(): + """With no _logical_type_registry set, converter behaves exactly as before.""" + converter = UniversalTypeConverter() + assert converter.python_type_to_arrow_type(str) == pa.large_string() + + +def test_data_context_wires_registry_into_converter(): + """DataContext.__post_init__ wires logical_type_registry into type_converter.""" + from orcapod.contexts import get_default_context + ctx = get_default_context() + assert hasattr(ctx.type_converter, "_logical_type_registry") + assert ctx.type_converter._logical_type_registry is ctx.logical_type_registry +``` + +- [ ] **Step 2: Run to confirm tests fail** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py -k "logical_type_registry or data_context_wires" -v +``` + +Expected: FAIL — `UniversalTypeConverter` has no `_logical_type_registry` attribute. + +- [ ] **Step 3: Add `_logical_type_registry` to `UniversalTypeConverter.__init__`** + +In `src/orcapod/semantic_types/universal_converter.py`, update `__init__`: + +```python + def __init__( + self, + semantic_registry: SemanticTypeRegistry | None = None, + datetime_timezone: typing.Literal["strict", "coerce_utc"] = "strict", + ): + """ + Args: + semantic_registry: Optional registry of semantic type converters. + datetime_timezone: How to handle naive (timezone-less) ``datetime`` + values when converting Python → Arrow. + + ``"strict"`` (default) — raise ``ValueError`` immediately so + callers are forced to be explicit about timezone semantics. + + ``"coerce_utc"`` — silently attach ``timezone.utc`` to naive + datetimes before writing to Arrow. Use this when you know that + all naive datetimes in your data represent UTC. + """ + self.semantic_registry = semantic_registry + self._datetime_timezone = datetime_timezone + self._logical_type_registry = None # set by DataContext.__post_init__ + # ... rest of existing __init__ unchanged ... +``` + +- [ ] **Step 4: Insert the priority check in `_convert_python_to_arrow`** + +In `src/orcapod/semantic_types/universal_converter.py`, find `_convert_python_to_arrow` (around line 411). After the `type_map` check and before the `semantic_registry` check, insert: + +```python + # Check LogicalTypeRegistry first — extension-type identity takes priority + if self._logical_type_registry is not None: + lt = self._logical_type_registry.get_by_python_type(python_type) + if lt is not None: + return lt.get_arrow_extension_type() +``` + +The surrounding context should look like: + +```python + def _convert_python_to_arrow(self, python_type: DataType) -> pa.DataType: + """Core Python → Arrow type conversion logic.""" + type_map = _get_python_to_arrow_map() + if python_type in type_map: + return type_map[python_type] + + # Check LogicalTypeRegistry first — extension-type identity takes priority + if self._logical_type_registry is not None: + lt = self._logical_type_registry.get_by_python_type(python_type) + if lt is not None: + return lt.get_arrow_extension_type() + + # Check semantic registry for registered types + if self.semantic_registry: + converter = self.semantic_registry.get_converter_for_python_type(python_type) + if converter: + return converter.arrow_struct_type + # ... rest unchanged ... +``` + +- [ ] **Step 5: Add `DataContext.__post_init__` in `contexts/core.py`** + +In `src/orcapod/contexts/core.py`, add a `__post_init__` method to `DataContext`: + +```python + def __post_init__(self) -> None: + """Wire components together after dataclass construction. + + Injects ``logical_type_registry`` into ``type_converter`` so that + registered ``LogicalType`` instances take priority over the old + shape-based ``semantic_registry`` at encoding time. + """ + if hasattr(self.type_converter, "_logical_type_registry"): + self.type_converter._logical_type_registry = self.logical_type_registry +``` + +- [ ] **Step 6: Run the new tests** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py -k "logical_type_registry or data_context_wires" -v +``` + +Expected: All pass. + +- [ ] **Step 7: Run the full test suite to confirm no regressions** + +```bash +uv run pytest tests/ -v --tb=short -q +``` + +Expected: All previously passing tests still pass. + +- [ ] **Step 8: Commit** + +```bash +git add src/orcapod/semantic_types/universal_converter.py \ + src/orcapod/contexts/core.py \ + tests/test_semantic_types/test_universal_converter.py +git commit -m "feat(extension_types): wire LogicalTypeRegistry into UniversalTypeConverter and DataContext" +``` + +--- + +## Task 7: Add write-side trigger to `_FunctionPodBase` + +**Files:** +- Modify: `src/orcapod/core/function_pod.py` +- Create: `tests/test_core/function_pod/test_write_side_registration.py` + +- [ ] **Step 1: Write the failing tests** + +Create `tests/test_core/function_pod/test_write_side_registration.py`: + +```python +"""Tests for write-side LogicalType auto-registration at function pod declaration. + +These tests verify that _FunctionPodBase.__init__ triggers factory synthesis for +any non-native Python types in the pod's input/output schemas, and raises TypeError +at declaration time when no factory is registered. +""" + +from __future__ import annotations + +import dataclasses +import pathlib +import uuid as _uuid_module +from typing import Optional + +import pyarrow as pa +import polars as pl +import pytest + +from orcapod.contexts import get_default_context +from orcapod.core.data_function import PythonDataFunction +from orcapod.core.function_pod import FunctionPod +from orcapod.extension_types.protocols import LogicalTypeProtocol +from orcapod.extension_types.registry import ( + LogicalTypeRegistry, + make_arrow_extension_type, + make_polars_extension_type, +) + + +# ── Helpers ────────────────────────────────────────────────────────────────── + +def _make_registry_with_factory(target_base: type) -> tuple[LogicalTypeRegistry, list]: + """Return a registry with a factory for target_base and a call log.""" + call_log: list[type] = [] + + def _make_lt(py_type: type) -> LogicalTypeProtocol: + arrow_name = f"{py_type.__module__}.{py_type.__qualname__}.{_uuid_module.uuid4().hex[:6]}" + ArrowExt = make_arrow_extension_type(arrow_name, pa.large_string()) + pl_dtype = pl.String + + class _PolarsExt(pl.BaseExtension): + def __init__(self): + super().__init__(arrow_name, pl_dtype, None) + @classmethod + def ext_from_params(cls, ext_name, storage_dtype, metadata_str): + return cls() + + class _LT: + logical_type_name = arrow_name + python_type = py_type + def get_arrow_extension_type(self): return ArrowExt() + def get_polars_extension_type(self): return _PolarsExt() + def python_to_storage(self, v): return str(v) + def storage_to_python(self, v): return v + + return _LT() + + class _Factory: + def reconstruct_from_arrow(self, name, storage, meta): + return _make_lt(object) # unused in these tests + + def create_for_python_type(self, python_type): + call_log.append(python_type) + return _make_lt(python_type) + + registry = LogicalTypeRegistry() + registry.register_logical_type_factory(_Factory(), python_bases=[target_base]) + return registry, call_log + + +# ── Custom classes used in tests ───────────────────────────────────────────── + +class _MyBase: + pass + + +class _MyChild(_MyBase): + pass + + +# ── Tests ──────────────────────────────────────────────────────────────────── + +def test_pod_declaration_triggers_factory_for_unregistered_class(): + """Declaring a FunctionPod with an unregistered type causes factory synthesis.""" + registry, call_log = _make_registry_with_factory(_MyBase) + from orcapod.contexts.core import DataContext + from orcapod.contexts import get_default_context + # Build a context with our test registry + base_ctx = get_default_context() + ctx = DataContext( + context_key="test", + version="test", + description="test", + type_converter=base_ctx.type_converter, + arrow_hasher=base_ctx.arrow_hasher, + semantic_hasher=base_ctx.semantic_hasher, + type_handler_registry=base_ctx.type_handler_registry, + logical_type_registry=registry, + ) + + def my_func(x: _MyChild) -> str: + return str(x) + + # Pod declaration should trigger factory for _MyChild + pod = FunctionPod( + func=my_func, + output_keys=["result"], + data_context=ctx, + ) + assert _MyChild in call_log + # The synthesized LogicalType is now in the registry + assert registry.get_by_python_type(_MyChild) is not None + + +def test_pod_declaration_with_nested_list_type(): + """list[_MyChild] in the schema causes factory synthesis for _MyChild.""" + registry, call_log = _make_registry_with_factory(_MyBase) + from orcapod.contexts.core import DataContext + from orcapod.contexts import get_default_context + base_ctx = get_default_context() + ctx = DataContext( + context_key="test", + version="test", + description="test", + type_converter=base_ctx.type_converter, + arrow_hasher=base_ctx.arrow_hasher, + semantic_hasher=base_ctx.semantic_hasher, + type_handler_registry=base_ctx.type_handler_registry, + logical_type_registry=registry, + ) + + def my_func(items: list[_MyChild]) -> str: + return "" + + FunctionPod(func=my_func, output_keys=["result"], data_context=ctx) + assert _MyChild in call_log + + +def test_pod_declaration_native_types_no_factory_call(): + """Pods using only native types (int, str, etc.) never trigger factory lookup.""" + registry = LogicalTypeRegistry() + + class _NeverCalledFactory: + def reconstruct_from_arrow(self, *a): ... + def create_for_python_type(self, pt): + raise AssertionError(f"factory called for {pt!r}") + + registry.register_logical_type_factory( + _NeverCalledFactory(), python_bases=[object] + ) + from orcapod.contexts.core import DataContext + from orcapod.contexts import get_default_context + base_ctx = get_default_context() + ctx = DataContext( + context_key="test", version="test", description="test", + type_converter=base_ctx.type_converter, + arrow_hasher=base_ctx.arrow_hasher, + semantic_hasher=base_ctx.semantic_hasher, + type_handler_registry=base_ctx.type_handler_registry, + logical_type_registry=registry, + ) + + def my_func(x: int, y: str) -> float: + return 0.0 + + # Should not raise — int, str, float are native + FunctionPod(func=my_func, output_keys=["result"], data_context=ctx) + + +def test_pod_declaration_raises_type_error_for_unhandled_class(): + """Pod with a type that has no registered factory raises TypeError at declaration.""" + registry = LogicalTypeRegistry() # empty — no factories + from orcapod.contexts.core import DataContext + from orcapod.contexts import get_default_context + base_ctx = get_default_context() + ctx = DataContext( + context_key="test", version="test", description="test", + type_converter=base_ctx.type_converter, + arrow_hasher=base_ctx.arrow_hasher, + semantic_hasher=base_ctx.semantic_hasher, + type_handler_registry=base_ctx.type_handler_registry, + logical_type_registry=registry, + ) + + def my_func(x: _MyChild) -> str: + return "" + + with pytest.raises(TypeError, match="No LogicalType or LogicalTypeFactory"): + FunctionPod(func=my_func, output_keys=["result"], data_context=ctx) + + +def test_pod_declaration_already_registered_type_no_factory_call(): + """Pre-registered types are not passed to the factory.""" + registry, call_log = _make_registry_with_factory(_MyBase) + # Pre-register _MyChild directly + from orcapod.extension_types.registry import make_arrow_extension_type + ArrowExt = make_arrow_extension_type(f"test.MyChild.{_uuid_module.uuid4().hex[:6]}", pa.large_string()) + + class _PreLT: + logical_type_name = f"test.{_uuid_module.uuid4().hex[:6]}" + python_type = _MyChild + def get_arrow_extension_type(self): return ArrowExt() + def get_polars_extension_type(self): + class P(pl.BaseExtension): + def __init__(self): super().__init__(self.logical_type_name, pl.String, None) + @classmethod + def ext_from_params(cls, *a): return cls() + return P() + def python_to_storage(self, v): return str(v) + def storage_to_python(self, v): return v + + registry.register_logical_type(_PreLT()) + from orcapod.contexts.core import DataContext + from orcapod.contexts import get_default_context + base_ctx = get_default_context() + ctx = DataContext( + context_key="test", version="test", description="test", + type_converter=base_ctx.type_converter, + arrow_hasher=base_ctx.arrow_hasher, + semantic_hasher=base_ctx.semantic_hasher, + type_handler_registry=base_ctx.type_handler_registry, + logical_type_registry=registry, + ) + + def my_func(x: _MyChild) -> str: + return "" + + FunctionPod(func=my_func, output_keys=["result"], data_context=ctx) + # Factory was NOT called — _MyChild was already registered + assert _MyChild not in call_log +``` + +- [ ] **Step 2: Run to confirm all fail** + +```bash +uv run pytest tests/test_core/function_pod/test_write_side_registration.py -v +``` + +Expected: All FAIL — the trigger does not exist yet. + +- [ ] **Step 3: Implement the trigger in `function_pod.py`** + +Add imports at the top of `src/orcapod/core/function_pod.py` (with the existing imports): + +```python +from orcapod.extension_types.type_utils import _extract_leaf_classes +from orcapod.extension_types.registry import LogicalTypeRegistry +``` + +Add the module-level constant and helper function before the `_FunctionPodBase` class definition: + +```python +# Python types that Arrow handles natively — no LogicalType registration needed. +_ARROW_NATIVE_TYPES: frozenset[type] = frozenset({ + int, float, str, bytes, bool, type(None), +}) + + +def _trigger_write_side_registration( + input_schema: Schema, + output_schema: Schema, + registry: LogicalTypeRegistry | None, +) -> None: + """Ensure a LogicalType is registered for every non-native leaf class in the schemas. + + Called once at pod declaration time. Recursively unwraps generic annotations + (``list[T]``, ``dict[K, V]``, etc.) to find leaf classes. Skips Arrow-native + types and already-registered types. Raises ``TypeError`` at declaration time + if no factory is registered for a leaf class — this is intentional. + + Args: + input_schema: The pod's input data schema (column name → Python type annotation). + output_schema: The pod's output data schema. + registry: The ``LogicalTypeRegistry`` from the pod's ``DataContext``. + If ``None``, this function is a no-op. + """ + if registry is None: + return + for schema in (input_schema, output_schema): + for annotation in schema.values(): + for leaf_class in _extract_leaf_classes(annotation): + if leaf_class in _ARROW_NATIVE_TYPES: + continue + if registry.get_by_python_type(leaf_class) is not None: + continue # already registered — O(1) cache hit + registry.ensure_logical_type_for_python_class(leaf_class) + # TypeError propagates if no factory matches — intentional hard error +``` + +In `_FunctionPodBase.__init__`, add the trigger call after `self._data_function = data_function`: + +```python + self._data_function = data_function + _trigger_write_side_registration( + data_function.input_data_schema, + data_function.output_data_schema, + self.data_context.logical_type_registry, + ) +``` + +- [ ] **Step 4: Run the new tests** + +```bash +uv run pytest tests/test_core/function_pod/test_write_side_registration.py -v +``` + +Expected: All pass. + +- [ ] **Step 5: Run the full test suite** + +```bash +uv run pytest tests/ -v --tb=short -q +``` + +Expected: All previously passing tests still pass. The trigger is a no-op for native types and already-registered built-ins (Path, UPath, UUID), so existing pod tests are unaffected. + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/core/function_pod.py \ + tests/test_core/function_pod/test_write_side_registration.py +git commit -m "feat(extension_types): add write-side registration trigger in _FunctionPodBase.__init__" +``` + +--- + +## Self-Review Checklist + +**Spec coverage:** + +| Spec section | Covered by task | +|---|---| +| `reconstruct_from_arrow` rename | Task 1 | +| `create_for_python_type` new method | Task 2 | +| `_category_factories` rename, `_python_class_factories`, extended `register_logical_type_factory` | Task 3 | +| `ensure_logical_type_for_python_class` with unified MRO resolution, caching, TypeError | Task 4 | +| `_extract_leaf_classes` for complex nested annotations | Task 5 | +| `UniversalTypeConverter` priority check + `DataContext` wiring | Task 6 | +| `_trigger_write_side_registration`, `_ARROW_NATIVE_TYPES`, `_FunctionPodBase.__init__` call | Task 7 | +| Failure mode: hard TypeError at declaration time | Task 7 tests | +| Symmetry with read side (protocol contract documented) | Task 2 docstring | +| Built-in types unaffected | Task 7 tests (native types test, pre-registered test) | + +**Type consistency across tasks:** +- `reconstruct_from_arrow` defined in Task 1, used in Task 3 (factory stub) — consistent ✓ +- `create_for_python_type` defined in Task 2, tested in Task 4 (`python_type_calls`) — consistent ✓ +- `_category_factories` introduced in Task 3, referenced in `ensure_logical_type_for_python_class` Task 4 — consistent ✓ +- `_python_class_factories` introduced in Task 3, used in Task 4 — consistent ✓ +- `_extract_leaf_classes` created in Task 5, imported in Task 7 — consistent ✓ +- `_logical_type_registry` attribute name defined in Task 6, checked in Task 6's DataContext test — consistent ✓ +- `LogicalTypeRegistry` import added to `function_pod.py` in Task 7 type annotation — consistent ✓ diff --git a/superpowers/plans/2026-06-16-plt-1705-type-registration-spine-refactor.md b/superpowers/plans/2026-06-16-plt-1705-type-registration-spine-refactor.md new file mode 100644 index 00000000..1505b8ba --- /dev/null +++ b/superpowers/plans/2026-06-16-plt-1705-type-registration-spine-refactor.md @@ -0,0 +1,2897 @@ +# PLT-1705 Type Registration Spine Refactor Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use sensei:subagent-driven-development (recommended) or sensei:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Make `UniversalTypeConverter` the single re-entry point for Python ↔ Arrow type registration, move `LogicalTypeRegistry` inside the converter as a private implementation detail, and implement `DataclassHandlerFactory` on the refined architecture. + +**Architecture:** `register_python_class(annotation)` handles write-side recursive traversal; `register_storage_type(arrow_type)` handles read-side bottom-up traversal. Factories and logical types receive `converter` instead of `registry`, so all delegation flows through the converter. `DataContext.logical_type_registry` is removed entirely. + +**Tech Stack:** Python 3.12, PyArrow, Polars, `dataclasses`, `typing.get_type_hints` + +--- + +## File Map + +| File | Action | What changes | +|---|---|---| +| `src/orcapod/extension_types/protocols.py` | Modify | Add `TypeConverterProtocol`; add `supports_class` + `converter` param to factory protocol; add `converter` param to logical type protocol | +| `src/orcapod/extension_types/builtin_logical_types.py` | Modify | Add `converter` param (accept, ignore) to `python_to_storage` / `storage_to_python` | +| `src/orcapod/semantic_types/universal_converter.py` | Modify | Add `register_python_class`, `register_storage_type`, `python_to_storage`, `storage_to_python`, `register_logical_type`, `register_logical_type_factory`; update `_create_python_to_arrow_converter`/`_create_arrow_to_python_converter` to pass `converter=self`; simplify `ensure_types_registered_for_schemas`; remove `semantic_registry` usage; remove `dataclass_encoding` imports | +| `src/orcapod/extension_types/registry.py` | Modify | Remove `ensure_logical_type_for_python_class`, `ensure_extension_type` | +| `src/orcapod/extension_types/dataclass_handler.py` | **Create** | `DataclassLogicalType` + `DataclassHandlerFactory` | +| `src/orcapod/semantic_types/dataclass_encoding.py` | **Delete** | Superseded by `DataclassHandlerFactory` | +| `src/orcapod/extension_types/type_utils.py` | Modify | Rename `extract_leaf_classes` → `_extract_leaf_classes` (private) | +| `src/orcapod/extension_types/database_hooks.py` | Modify | `register_discovered_extensions` takes `converter` instead of `registry`; uses schema_walker + `converter._ensure_extension_type_info` | +| `src/orcapod/databases/extension_aware_database.py` | Modify | Takes `converter` instead of `registry`; passes `converter._registry` to `apply_extension_types` | +| `src/orcapod/contexts/core.py` | Modify | Remove `logical_type_registry` field from `DataContext` | +| `src/orcapod/contexts/__init__.py` | Modify | Remove `get_default_logical_type_registry` | +| `src/orcapod/contexts/registry.py` | Modify | Remove `"logical_type_registry"` from `required_fields`; stop passing it to `DataContext` | +| `src/orcapod/contexts/data/v0.1.json` | Modify | Remove top-level `logical_type_registry`; move registry construction inside `type_converter._config`; remove `semantic_registry` ref from `type_converter._config` | +| `src/orcapod/contexts/data/schemas/context_schema.json` | Modify | Remove `logical_type_registry` from `required` and `properties` | +| `src/orcapod/extension_types/__init__.py` | Modify | Update docstring | +| `tests/test_extension_types/test_protocols.py` | Modify | Update stubs for new signatures; add `TypeConverterProtocol` conformance test | +| `tests/test_extension_types/test_registry.py` | Modify | Remove `ensure_*` tests; add converter pass-through tests | +| `tests/test_extension_types/test_builtin_logical_types.py` | Modify | Pass a stub converter to `python_to_storage` / `storage_to_python` calls | +| `tests/test_extension_types/test_dataclass_handler.py` | **Create** | Full unit tests for `DataclassLogicalType` and `DataclassHandlerFactory` | +| `tests/test_semantic_types/test_universal_converter.py` | Modify | Add `register_python_class` and `register_storage_type` tests | +| `tests/test_extension_types/test_database_hooks.py` | Modify | Switch from registry to converter | +| `tests/test_core/function_pod/test_write_side_registration.py` | Modify | Update `DataContext` construction (no `logical_type_registry`) | + +--- + +## Task 1: Update `TypeConverterProtocol` and factory/logical-type protocols + +**Files:** +- Modify: `src/orcapod/extension_types/protocols.py` +- Modify: `tests/test_extension_types/test_protocols.py` + +- [ ] **Step 1: Write failing protocol conformance tests** + +Add to `tests/test_extension_types/test_protocols.py`: + +```python +# Add at the top of the file: +# from orcapod.extension_types.protocols import TypeConverterProtocol + +def test_type_converter_protocol_is_importable(): + from orcapod.extension_types.protocols import TypeConverterProtocol + assert TypeConverterProtocol is not None + + +def test_factory_supports_class_method_required(): + """LogicalTypeFactoryProtocol requires supports_class.""" + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol + + class _BadFactory: + def reconstruct_from_arrow(self, name, storage_type, metadata, converter): + pass + def create_for_python_type(self, python_type, converter): + pass + # Missing supports_class + + assert not isinstance(_BadFactory(), LogicalTypeFactoryProtocol) + + +def test_factory_with_supports_class_satisfies_protocol(): + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol + + class _GoodFactory: + def supports_class(self, python_type): + return True + def reconstruct_from_arrow(self, name, storage_type, metadata, converter): + pass + def create_for_python_type(self, python_type, converter): + pass + + assert isinstance(_GoodFactory(), LogicalTypeFactoryProtocol) + + +def test_logical_type_python_to_storage_accepts_converter(): + """LogicalTypeProtocol.python_to_storage now requires converter param.""" + from orcapod.extension_types.protocols import LogicalTypeProtocol + + class _GoodLT: + @property + def logical_type_name(self): return "test.lt" + @property + def python_type(self): return str + def get_arrow_extension_type(self): pass + def get_polars_extension_type(self): pass + def python_to_storage(self, value, converter): return value + def storage_to_python(self, storage_value, converter): return storage_value + + assert isinstance(_GoodLT(), LogicalTypeProtocol) +``` + +- [ ] **Step 2: Run tests to confirm failures** + +```bash +uv run pytest tests/test_extension_types/test_protocols.py -v -k "type_converter or supports_class or accepts_converter" 2>&1 | tail -30 +``` +Expected: ImportError or AttributeError failures. + +- [ ] **Step 3: Update `protocols.py`** + +Replace the entire file: + +```python +"""Protocol definitions for the Arrow/Polars extension type system. + +This module defines ``TypeConverterProtocol``, ``LogicalTypeProtocol``, and +``LogicalTypeFactoryProtocol`` — the contracts for the converter, for logical +type implementations that bind a Python class to its Arrow and Polars extension +type representation, and for factories that auto-construct such implementations +from Arrow schema metadata. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + import polars as pl + import pyarrow as pa + + +@runtime_checkable +class TypeConverterProtocol(Protocol): + """Minimal protocol exposing what factories and logical types need from the converter. + + Placed in ``extension_types/protocols.py`` to avoid circular imports. + ``UniversalTypeConverter`` is the canonical implementation. + """ + + def register_python_class(self, annotation: Any) -> "pa.DataType": + """Traverse a Python annotation and return its Arrow type, registering as needed.""" + ... + + def register_storage_type(self, arrow_type: "pa.DataType") -> "pa.DataType": + """Traverse an Arrow type bottom-up, registering extension types, and return resolved type.""" + ... + + def python_to_storage(self, value: Any, annotation: Any) -> Any: + """Convert a Python value to its Arrow storage representation.""" + ... + + def storage_to_python(self, storage_value: Any, annotation: Any) -> Any: + """Convert an Arrow storage value back to a Python object.""" + ... + + +@runtime_checkable +class LogicalTypeProtocol(Protocol): + """Protocol for Arrow/Polars extension-type-backed logical types. + + A ``LogicalTypeProtocol`` is a three-way binding between a unique logical type name + (orcapod's identifier), a Python class, and Arrow/Polars extension types. + Each implementation *owns* its Arrow and Polars extension types by providing + them directly via ``get_arrow_extension_type`` and ``get_polars_extension_type``. + + This protocol is Arrow I/O only — hashing is not a logical type responsibility. + """ + + @property + def logical_type_name(self) -> str: + """Unique orcapod identifier for this logical type (e.g. ``"orcapod.uuid"``).""" + ... + + @property + def python_type(self) -> type: + """The Python class this logical type represents.""" + ... + + def get_arrow_extension_type(self) -> "pa.ExtensionType": + """Return the Arrow extension type for this logical type.""" + ... + + def get_polars_extension_type(self) -> "pl.BaseExtension": + """Return an instance of the Polars extension type for this logical type.""" + ... + + def python_to_storage(self, value: Any, converter: TypeConverterProtocol) -> Any: + """Convert a Python value to its Arrow storage representation. + + Args: + value: A Python object of type ``python_type``. + converter: The active ``TypeConverterProtocol`` for recursive delegation. + + Returns: + A value suitable for Arrow storage. + """ + ... + + def storage_to_python(self, storage_value: Any, converter: TypeConverterProtocol) -> Any: + """Convert an Arrow storage value back to a Python object. + + Args: + storage_value: A scalar or array element from the Arrow storage array. + converter: The active ``TypeConverterProtocol`` for recursive delegation. + + Returns: + A Python object of type ``python_type``. + """ + ... + + +@runtime_checkable +class LogicalTypeFactoryProtocol(Protocol): + """Protocol for factories that synthesize or reconstruct ``LogicalTypeProtocol`` instances. + + Bridges two directions: the write path (``create_for_python_type``) and the read + path (``reconstruct_from_arrow``). Both methods receive ``converter`` instead of + ``registry`` so all traversal flows through the converter. + """ + + def supports_class(self, python_type: type) -> bool: + """Return True if this factory can synthesize a LogicalType for ``python_type``. + + Used as a probe during write-side MRO dispatch in ``register_python_class``. + + Args: + python_type: The Python class to probe. + + Returns: + True if this factory handles ``python_type``. + """ + ... + + def create_for_python_type( + self, + python_type: type, + converter: TypeConverterProtocol, + ) -> LogicalTypeProtocol: + """Synthesize a LogicalType for the given Python class (write path). + + Args: + python_type: The concrete Python class to synthesize a LogicalType for. + converter: The active converter for recursive field-type resolution. + + Returns: + A fully constructed ``LogicalTypeProtocol`` ready for registration. + + Raises: + ValueError: If this factory cannot construct a type for the given class. + """ + ... + + def reconstruct_from_arrow( + self, + arrow_extension_name: str, + storage_type: "pa.DataType", + metadata: dict[str, Any], + converter: TypeConverterProtocol, + ) -> LogicalTypeProtocol: + """Reconstruct a LogicalType from Arrow schema metadata (read path). + + Args: + arrow_extension_name: The Arrow extension type name from the schema. + storage_type: The underlying Arrow storage type (already resolved bottom-up). + metadata: Full parsed metadata JSON dict. Always contains ``"category"``. + converter: The active converter for recursive field-type resolution. + + Returns: + A fully constructed ``LogicalTypeProtocol`` ready for registration. + + Raises: + ValueError: If this factory cannot reconstruct a type for the given name. + """ + ... +``` + +- [ ] **Step 4: Run tests to confirm they pass** + +```bash +uv run pytest tests/test_extension_types/test_protocols.py -v 2>&1 | tail -30 +``` +Expected: All tests pass (some existing tests about the OLD signatures will now fail — that's expected and will be fixed in Task 2). + +- [ ] **Step 5: Update existing stubs in `test_protocols.py` to use new signatures** + +Replace `_StubLogicalType` and `_StubFactory` in `tests/test_extension_types/test_protocols.py`: + +```python +class _StubLogicalType: + """Minimal conforming implementation of LogicalTypeProtocol for use in tests.""" + + _ArrowExtClass = make_arrow_extension_type("test.module.MyType", pa.large_string()) + + @property + def logical_type_name(self) -> str: + return "test.module.MyType" + + @property + def python_type(self) -> type: + return str + + def get_arrow_extension_type(self) -> pa.ExtensionType: + return self._ArrowExtClass() + + def get_polars_extension_type(self) -> pl.BaseExtension: + class _PolarsExt(pl.BaseExtension): + def __init__(self): + super().__init__("test.module.MyType", pl.String, None) + @classmethod + def ext_from_params(cls, ext_name, storage_dtype, metadata_str): + return cls() + return _PolarsExt() + + def python_to_storage(self, value, converter): # converter param added + return str(value) + + def storage_to_python(self, storage_value, converter): # converter param added + return storage_value + + +class _StubFactory: + """Minimal conforming implementation of LogicalTypeFactoryProtocol for use in tests.""" + + def supports_class(self, python_type): # new method + return True + + def reconstruct_from_arrow(self, arrow_extension_name, storage_type, metadata, converter): + return _StubLogicalType() + + def create_for_python_type(self, python_type, converter): + return _StubLogicalType() +``` + +Also update the test that calls the old signatures: +```python +def test_conforming_class_satisfies_protocol(): + lt: LogicalTypeProtocol = _StubLogicalType() + assert lt.logical_type_name == "test.module.MyType" + assert lt.python_type is str + assert lt.get_arrow_extension_type().extension_name == "test.module.MyType" + assert isinstance(lt.get_polars_extension_type(), pl.BaseExtension) + assert lt.python_to_storage(42, None) == "42" # pass converter=None + assert lt.storage_to_python("hello", None) == "hello" # pass converter=None + + +def test_logical_type_factory_create_returns_logical_type(): + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol, LogicalTypeProtocol + factory: LogicalTypeFactoryProtocol = _StubFactory() + result = factory.reconstruct_from_arrow( + "test.ext", pa.large_utf8(), {"category": "Test"}, converter=None + ) + assert isinstance(result, LogicalTypeProtocol) + + +def test_factory_create_for_python_type_conformance(): + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol, LogicalTypeProtocol + factory: LogicalTypeFactoryProtocol = _StubFactory() + assert isinstance(factory, LogicalTypeFactoryProtocol) + result = factory.create_for_python_type(str, converter=None) + assert isinstance(result, LogicalTypeProtocol) +``` + +- [ ] **Step 6: Run all protocol tests** + +```bash +uv run pytest tests/test_extension_types/test_protocols.py -v 2>&1 | tail -20 +``` +Expected: All pass. + +- [ ] **Step 7: Commit** + +```bash +git add src/orcapod/extension_types/protocols.py tests/test_extension_types/test_protocols.py +git commit -m "feat(extension_types): add TypeConverterProtocol; update factory/logical-type protocols with converter param and supports_class" +``` + +--- + +## Task 2: Update built-in logical types for protocol conformance + +**Files:** +- Modify: `src/orcapod/extension_types/builtin_logical_types.py` +- Modify: `tests/test_extension_types/test_builtin_logical_types.py` + +- [ ] **Step 1: Write failing tests** + +Add to `tests/test_extension_types/test_builtin_logical_types.py`: + +```python +def test_logical_path_python_to_storage_accepts_converter(): + """python_to_storage now accepts a converter param (ignored).""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + lt = LogicalPath() + import pathlib + result = lt.python_to_storage(pathlib.Path("/tmp/foo"), converter=None) + assert result == "/tmp/foo" + + +def test_logical_uuid_python_to_storage_accepts_converter(): + from orcapod.extension_types.builtin_logical_types import LogicalUUID + import uuid as uuid_module + lt = LogicalUUID() + u = uuid_module.UUID("12345678-1234-5678-1234-567812345678") + result = lt.python_to_storage(u, converter=None) + assert result == u.bytes + + +def test_logical_upath_storage_to_python_accepts_converter(): + from orcapod.extension_types.builtin_logical_types import LogicalUPath + lt = LogicalUPath() + from upath import UPath + result = lt.storage_to_python("s3://bucket/key", converter=None) + assert isinstance(result, UPath) +``` + +- [ ] **Step 2: Run to confirm failures** + +```bash +uv run pytest tests/test_extension_types/test_builtin_logical_types.py -v -k "accepts_converter" 2>&1 | tail -20 +``` +Expected: TypeError — unexpected keyword argument. + +- [ ] **Step 3: Update all three classes in `builtin_logical_types.py`** + +For `LogicalPath`: +```python +def python_to_storage(self, value: Any, converter: Any = None) -> str: + return str(value) + +def storage_to_python(self, storage_value: Any, converter: Any = None) -> pathlib.Path: + return pathlib.Path(storage_value) +``` + +For `LogicalUPath`: +```python +def python_to_storage(self, value: Any, converter: Any = None) -> str: + return str(value) + +def storage_to_python(self, storage_value: Any, converter: Any = None) -> UPath: + return UPath(storage_value) +``` + +For `LogicalUUID`: +```python +def python_to_storage(self, value: Any, converter: Any = None) -> bytes: + return value.bytes + +def storage_to_python(self, storage_value: Any, converter: Any = None) -> _uuid_module.UUID: + return _uuid_module.UUID(bytes=bytes(storage_value)) +``` + +Also add `TYPE_CHECKING` import for `TypeConverterProtocol` in the type hint: +```python +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from orcapod.extension_types.protocols import TypeConverterProtocol +``` + +And use in signatures: +```python +def python_to_storage(self, value: Any, converter: "TypeConverterProtocol | None" = None) -> str: +``` + +- [ ] **Step 4: Run new tests** + +```bash +uv run pytest tests/test_extension_types/test_builtin_logical_types.py -v 2>&1 | tail -20 +``` +Expected: All pass. + +- [ ] **Step 5: Also update test call sites that call without converter** + +Search for existing direct calls to `python_to_storage` / `storage_to_python` in the test file (they have no `converter` arg — that's fine since we added `converter=None` default). + +```bash +uv run pytest tests/test_extension_types/test_builtin_logical_types.py -v 2>&1 | tail -5 +``` +Expected: All pass (defaults handle existing calls). + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/extension_types/builtin_logical_types.py tests/test_extension_types/test_builtin_logical_types.py +git commit -m "feat(extension_types): add converter param to built-in logical type python_to_storage/storage_to_python" +``` + +--- + +## Task 3: Add `register_python_class` to `UniversalTypeConverter` + +**Files:** +- Modify: `src/orcapod/semantic_types/universal_converter.py` +- Modify: `tests/test_semantic_types/test_universal_converter.py` + +- [ ] **Step 1: Write failing tests** + +Add to `tests/test_semantic_types/test_universal_converter.py`: + +```python +import dataclasses +import uuid as _uuid_module +import pathlib +from typing import Optional + +import pyarrow as pa +import pytest + +from orcapod.extension_types.registry import LogicalTypeRegistry, make_arrow_extension_type, make_polars_extension_type +from orcapod.semantic_types.universal_converter import UniversalTypeConverter + + +# ── Helpers ───────────────────────────────────────────────────────────────── + +def _make_registry_with_builtins() -> LogicalTypeRegistry: + """Registry with LogicalPath, LogicalUUID, LogicalUPath pre-registered.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath, LogicalUUID, LogicalUPath + return LogicalTypeRegistry(logical_types=[LogicalPath(), LogicalUUID(), LogicalUPath()]) + + +def _make_converter(registry: LogicalTypeRegistry | None = None) -> UniversalTypeConverter: + if registry is None: + registry = _make_registry_with_builtins() + return UniversalTypeConverter(logical_type_registry=registry) + + +# ── register_python_class tests ────────────────────────────────────────────── + +def test_register_python_class_primitive_int(): + converter = _make_converter() + assert converter.register_python_class(int) == pa.int64() + + +def test_register_python_class_primitive_str(): + converter = _make_converter() + assert converter.register_python_class(str) == pa.large_string() + + +def test_register_python_class_list_of_int(): + converter = _make_converter() + result = converter.register_python_class(list[int]) + assert result == pa.large_list(pa.int64()) + + +def test_register_python_class_optional_str(): + converter = _make_converter() + result = converter.register_python_class(Optional[str]) + assert result == pa.large_string() + + +def test_register_python_class_dict_str_int(): + converter = _make_converter() + result = converter.register_python_class(dict[str, int]) + expected = pa.large_list(pa.struct([pa.field("key", pa.large_string()), pa.field("value", pa.int64())])) + assert result == expected + + +def test_register_python_class_set_of_str(): + converter = _make_converter() + result = converter.register_python_class(set[str]) + assert result == pa.large_list(pa.large_string()) + + +def test_register_python_class_registry_hit_path(): + """pathlib.Path is pre-registered → returns the orcapod.path extension type.""" + converter = _make_converter() + result = converter.register_python_class(pathlib.Path) + assert isinstance(result, pa.ExtensionType) + assert result.extension_name == "orcapod.path" + + +def test_register_python_class_uuid_registry_hit(): + converter = _make_converter() + result = converter.register_python_class(_uuid_module.UUID) + assert isinstance(result, pa.ExtensionType) + assert result.extension_name == "orcapod.uuid" + + +def test_register_python_class_factory_dispatch(): + """A custom class triggers factory synthesis and caches the result.""" + import uuid as _u + import polars as pl + + class _Base: + pass + + class _Child(_Base): + pass + + ext_name = f"test.custom.{_u.uuid4().hex[:8]}" + ArrowExt = make_arrow_extension_type(ext_name, pa.large_string()) + PolarsExt = make_polars_extension_type(ext_name, pa.large_string()) + synthesized_calls = [] + + class _Factory: + def supports_class(self, python_type): + return issubclass(python_type, _Base) + def create_for_python_type(self, python_type, converter): + synthesized_calls.append(python_type) + class _LT: + logical_type_name = ext_name + python_type_ = _Child + python_type = _Child + def get_arrow_extension_type(self): return ArrowExt() + def get_polars_extension_type(self): return PolarsExt() + def python_to_storage(self, v, c=None): return str(v) + def storage_to_python(self, v, c=None): return v + return _LT() + def reconstruct_from_arrow(self, name, storage, meta, converter): pass + + registry = _make_registry_with_builtins() + registry.register_logical_type_factory(_Factory(), python_bases=[_Base]) + converter = _make_converter(registry) + + result = converter.register_python_class(_Child) + assert isinstance(result, pa.ExtensionType) + assert result.extension_name == ext_name + assert _Child in synthesized_calls + + # Second call is a registry hit — factory NOT called again + result2 = converter.register_python_class(_Child) + assert result2 == result + assert len(synthesized_calls) == 1 + + +def test_register_python_class_cycle_detection(): + """Cyclic type synthesis raises TypeError.""" + import uuid as _u + import polars as pl + + class _CycleClass: + pass + + class _CycleFactory: + def supports_class(self, python_type): + return python_type is _CycleClass + def create_for_python_type(self, python_type, converter): + # Intentionally trigger a cycle + converter.register_python_class(_CycleClass) + def reconstruct_from_arrow(self, name, storage, meta, converter): pass + + registry = _make_registry_with_builtins() + registry.register_logical_type_factory(_CycleFactory(), python_bases=[_CycleClass]) + converter = _make_converter(registry) + + with pytest.raises(TypeError, match="[Cc]ircular"): + converter.register_python_class(_CycleClass) +``` + +- [ ] **Step 2: Run to confirm failures** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py -v -k "register_python_class" 2>&1 | tail -30 +``` +Expected: AttributeError — `UniversalTypeConverter` has no attribute `register_python_class`. + +- [ ] **Step 3: Implement `register_python_class` in `UniversalTypeConverter`** + +Add these methods to `UniversalTypeConverter` (after `__init__`): + +```python +def register_python_class(self, annotation: Any) -> "pa.DataType": + """Register a Python type annotation and return its Arrow type. + + Traverses generic annotations recursively. For each concrete class found, + either returns from the primitive map or registry (cache hit), or + synthesises via factory and registers the result. + + Args: + annotation: A Python type or generic alias (e.g. ``list[str]``, + ``Optional[uuid.UUID]``, a dataclass type). + + Returns: + The Arrow ``pa.DataType`` corresponding to ``annotation``. + + Raises: + TypeError: If a concrete class has no registered ``LogicalType`` and + no factory covers it, or if a circular dependency is detected. + ValueError: If a complex (non-Optional) union is encountered. + """ + import types as _types_mod + + type_map = _get_python_to_arrow_map() + + # Primitive map hit + if annotation in type_map: + return type_map[annotation] + + origin = get_origin(annotation) + args = get_args(annotation) + + # Optional[T] / T | None → strip None arm + if origin is typing.Union or origin is _types_mod.UnionType: + non_none = [a for a in args if a is not type(None)] + if len(non_none) == 1: + return self.register_python_class(non_none[0]) + raise ValueError( + f"Complex unions with multiple non-None types are not supported: " + f"{annotation!r}. Only Optional[T] (T | None) is allowed." + ) + + # list[T] → pa.large_list(T) + if origin is list: + return pa.large_list(self.register_python_class(args[0])) + + # set[T] → pa.large_list(T) + if origin is set: + return pa.large_list(self.register_python_class(args[0])) + + # dict[K, V] → pa.large_list(struct{key: K, value: V}) + if origin is dict: + key_arrow = self.register_python_class(args[0]) + val_arrow = self.register_python_class(args[1]) + return pa.large_list( + pa.struct([pa.field("key", key_arrow), pa.field("value", val_arrow)]) + ) + + # Concrete class — registry or factory dispatch + if isinstance(annotation, type): + if self._logical_type_registry is None: + raise TypeError( + f"No LogicalTypeRegistry configured — cannot register {annotation!r}. " + f"Provide logical_type_registry at converter construction time." + ) + + # Registry hit (already synthesised) + lt = self._logical_type_registry.get_by_python_type(annotation) + if lt is not None: + return lt.get_arrow_extension_type() + + # Cycle detection + if annotation in self._in_progress: + raise TypeError( + f"Circular type dependency detected while synthesising " + f"LogicalType for {annotation!r}." + ) + + # Factory dispatch via MRO walk + factory = self._find_factory_for_class(annotation) + if factory is None: + raise TypeError( + f"No LogicalType or LogicalTypeFactory registered for {annotation!r}. " + f"Register a factory: converter.register_logical_type_factory(factory, " + f"python_bases=[])" + ) + + self._in_progress.add(annotation) + try: + lt = factory.create_for_python_type(annotation, converter=self) + self._logical_type_registry.register_logical_type(lt) + finally: + self._in_progress.discard(annotation) + + return lt.get_arrow_extension_type() + + raise ValueError(f"Unsupported annotation: {annotation!r}") + +def _find_factory_for_class( + self, + python_type: type, +) -> "LogicalTypeFactoryProtocol | None": + """Find the most-specific registered factory for ``python_type``. + + Walks ``python_type.__mro__`` and returns the first factory in + ``_python_class_factories`` whose ``supports_class(python_type)`` returns True. + Falls back to an ``issubclass`` scan for ABC-registered factories. + + Args: + python_type: Concrete Python class to find a factory for. + + Returns: + The matching ``LogicalTypeFactoryProtocol``, or ``None`` if none found. + """ + factories = self._logical_type_registry._python_class_factories + + # MRO walk — most-specific base first + for base in python_type.__mro__: + factory = factories.get(base) + if factory is not None: + if hasattr(factory, "supports_class") and factory.supports_class(python_type): + return factory + elif not hasattr(factory, "supports_class"): + # Factories without supports_class are treated as unconditional matches + return factory + + # issubclass fallback for ABC-registered factories + for base, factory in factories.items(): + try: + if issubclass(python_type, base): + if hasattr(factory, "supports_class"): + if factory.supports_class(python_type): + return factory + else: + return factory + except TypeError: + continue + + return None +``` + +Also add `_in_progress: set[type] = set()` to `__init__`: + +```python +# In __init__, after the existing cache initializations: +self._in_progress: set[type] = set() +``` + +And add `TYPE_CHECKING` import for `LogicalTypeFactoryProtocol`: +```python +if TYPE_CHECKING: + import pyarrow as pa + from orcapod.extension_types.registry import LogicalTypeRegistry + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol +``` + +- [ ] **Step 4: Run tests** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py -v -k "register_python_class" 2>&1 | tail -30 +``` +Expected: All `register_python_class` tests pass. + +- [ ] **Step 5: Run full test suite for this module** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py -v 2>&1 | tail -20 +``` +Expected: Existing tests still pass. + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/semantic_types/universal_converter.py tests/test_semantic_types/test_universal_converter.py +git commit -m "feat(universal_converter): add register_python_class with recursive traversal, factory dispatch, and cycle detection" +``` + +--- + +## Task 4: Add `register_storage_type` to `UniversalTypeConverter` + +**Files:** +- Modify: `src/orcapod/semantic_types/universal_converter.py` +- Modify: `tests/test_semantic_types/test_universal_converter.py` + +- [ ] **Step 1: Write failing tests** + +Add to `tests/test_semantic_types/test_universal_converter.py`: + +```python +# ── register_storage_type tests ────────────────────────────────────────────── + +def test_register_storage_type_primitive_int(): + converter = _make_converter() + assert converter.register_storage_type(pa.int64()) == pa.int64() + + +def test_register_storage_type_primitive_large_string(): + converter = _make_converter() + assert converter.register_storage_type(pa.large_string()) == pa.large_string() + + +def test_register_storage_type_extension_type_registry_hit(): + """An already-registered extension type is returned unchanged (no-op).""" + converter = _make_converter() + # orcapod.uuid is pre-registered in the builtin registry + from orcapod.extension_types.builtin_logical_types import LogicalUUID + uuid_ext = LogicalUUID().get_arrow_extension_type() + result = converter.register_storage_type(uuid_ext) + assert isinstance(result, pa.ExtensionType) + assert result.extension_name == "orcapod.uuid" + + +def test_register_storage_type_struct_recurses(): + """Structs are traversed field by field; resolved field types are returned.""" + converter = _make_converter() + struct_type = pa.struct([pa.field("name", pa.large_string()), pa.field("count", pa.int64())]) + result = converter.register_storage_type(struct_type) + assert pa.types.is_struct(result) + assert result.field("name").type == pa.large_string() + assert result.field("count").type == pa.int64() + + +def test_register_storage_type_large_list_recurses(): + converter = _make_converter() + list_type = pa.large_list(pa.int32()) + result = converter.register_storage_type(list_type) + assert pa.types.is_large_list(result) + assert result.value_type == pa.int32() + + +def test_register_storage_type_extension_miss_dispatches_to_factory(): + """An unregistered extension type triggers factory.reconstruct_from_arrow.""" + import json + import uuid as _u + import polars as pl + + ext_name = f"test.reconstruct.{_u.uuid4().hex[:8]}" + category = "test.reconstruct" + metadata = json.dumps({"category": category}).encode() + ArrowExt = make_arrow_extension_type(ext_name, pa.large_string(), metadata=metadata) + PolarsExt = make_polars_extension_type(ext_name, pa.large_string()) + + class _LT: + logical_type_name = ext_name + python_type = str + def get_arrow_extension_type(self): return ArrowExt() + def get_polars_extension_type(self): return PolarsExt() + def python_to_storage(self, v, c=None): return str(v) + def storage_to_python(self, v, c=None): return v + + class _Factory: + def supports_class(self, t): return False + def create_for_python_type(self, t, converter): pass + def reconstruct_from_arrow(self, name, storage_type, meta, converter): + return _LT() + + registry = _make_registry_with_builtins() + registry.register_logical_type_factory(_Factory(), category=category) + converter = _make_converter(registry) + + ext_instance = ArrowExt() + result = converter.register_storage_type(ext_instance) + assert isinstance(result, pa.ExtensionType) + assert result.extension_name == ext_name + + # Second call: registry hit → same result, factory NOT called again + result2 = converter.register_storage_type(ext_instance) + assert result2.extension_name == ext_name + + +def test_register_storage_type_nested_struct_with_extension(): + """Extension type nested inside a struct field is resolved bottom-up.""" + import json + import uuid as _u + import polars as pl + + ext_name = f"test.nested.{_u.uuid4().hex[:8]}" + category = "test.nested" + metadata = json.dumps({"category": category}).encode() + ArrowExt = make_arrow_extension_type(ext_name, pa.large_string(), metadata=metadata) + PolarsExt = make_polars_extension_type(ext_name, pa.large_string()) + + class _LT: + logical_type_name = ext_name + python_type = str + def get_arrow_extension_type(self): return ArrowExt() + def get_polars_extension_type(self): return PolarsExt() + def python_to_storage(self, v, c=None): return str(v) + def storage_to_python(self, v, c=None): return v + + class _Factory: + def supports_class(self, t): return False + def create_for_python_type(self, t, converter): pass + def reconstruct_from_arrow(self, name, storage_type, meta, converter): + return _LT() + + registry = _make_registry_with_builtins() + registry.register_logical_type_factory(_Factory(), category=category) + converter = _make_converter(registry) + + ext_instance = ArrowExt() + struct_with_ext = pa.struct([pa.field("id", pa.int64()), pa.field("tag", ext_instance)]) + result = converter.register_storage_type(struct_with_ext) + + assert pa.types.is_struct(result) + assert result.field("id").type == pa.int64() + assert isinstance(result.field("tag").type, pa.ExtensionType) + assert result.field("tag").type.extension_name == ext_name +``` + +- [ ] **Step 2: Run to confirm failures** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py -v -k "register_storage_type" 2>&1 | tail -30 +``` +Expected: AttributeError — `register_storage_type` not defined. + +- [ ] **Step 3: Implement `register_storage_type` and `_ensure_extension_type_info` in `UniversalTypeConverter`** + +```python +def register_storage_type(self, arrow_type: "pa.DataType") -> "pa.DataType": + """Register extension types found in ``arrow_type`` and return the resolved type. + + Traverses Arrow types recursively in a bottom-up manner: + - Primitives are returned unchanged. + - ``pa.ExtensionType`` instances that are already registered are returned as-is. + - Unregistered extension types: the storage type is resolved first (bottom-up), + then the factory dispatches on the ``"category"`` metadata key. + - Structs: each field's type is resolved; a new struct with resolved fields is returned. + - Lists: the value type is resolved; a new list type with the resolved value is returned. + + Args: + arrow_type: An Arrow type to traverse and register. + + Returns: + The resolved Arrow type with extension types embedded. + """ + # Extension type + if isinstance(arrow_type, pa.ExtensionType): + ext_name = arrow_type.extension_name + if self._logical_type_registry is not None: + lt = self._logical_type_registry.get_by_arrow_extension_name(ext_name) + if lt is not None: + return lt.get_arrow_extension_type() + # Registry miss — extract info and register + raw_meta = arrow_type.__arrow_ext_serialize__() + ext_meta = raw_meta if raw_meta else None + resolved_storage = self.register_storage_type(arrow_type.storage_type) + return self._ensure_extension_type_info(ext_name, ext_meta, resolved_storage) + + # Struct type — recurse into each field + if pa.types.is_struct(arrow_type): + resolved_fields = [] + for i in range(arrow_type.num_fields): + field = arrow_type.field(i) + resolved_type = self.register_storage_type(field.type) + resolved_fields.append(pa.field(field.name, resolved_type, nullable=field.nullable)) + return pa.struct(resolved_fields) + + # Large list type + if pa.types.is_large_list(arrow_type): + resolved_value = self.register_storage_type(arrow_type.value_type) + return pa.large_list(resolved_value) + + # List type + if pa.types.is_list(arrow_type): + resolved_value = self.register_storage_type(arrow_type.value_type) + return pa.list_(resolved_value) + + # All other types (primitives, timestamps, binary, etc.) — return as-is + return arrow_type + +def _ensure_extension_type_info( + self, + arrow_extension_name: str, + extension_metadata: bytes | None, + storage_type: "pa.DataType", +) -> "pa.DataType": + """Register an extension type from (name, metadata, storage_type) info. + + Called by ``register_storage_type`` for in-memory ``pa.ExtensionType`` objects, + and by ``register_discovered_extensions`` for the field-metadata (Parquet) channel. + The ``storage_type`` must already be resolved (nested extension types registered). + + Args: + arrow_extension_name: Arrow extension name (``ARROW:extension:name``). + extension_metadata: Raw metadata bytes, expected to be UTF-8 JSON with + at least a ``"category"`` key. ``None`` or empty bytes if absent. + storage_type: Underlying Arrow storage type (already bottom-up resolved). + + Returns: + The Arrow extension type after registration. + + Raises: + ValueError: If metadata is missing, malformed, lacks ``"category"``, or + no factory is registered for the category. + """ + import json as _json + + if self._logical_type_registry is None: + raise ValueError( + f"No LogicalTypeRegistry configured — cannot register extension type " + f"{arrow_extension_name!r}." + ) + + # Registry hit — already registered + lt = self._logical_type_registry.get_by_arrow_extension_name(arrow_extension_name) + if lt is not None: + return lt.get_arrow_extension_type() + + # Missing metadata — cannot auto-register + if not extension_metadata: + raise ValueError( + f"Extension type {arrow_extension_name!r} has no extension metadata. " + f"Types without a metadata category tag cannot be auto-registered via a factory. " + f"Pre-register them explicitly via converter.register_logical_type(lt)." + ) + + # Parse JSON metadata + try: + metadata_dict = _json.loads(extension_metadata.decode("utf-8")) + except (UnicodeDecodeError, _json.JSONDecodeError) as exc: + raise ValueError( + f"Extension type {arrow_extension_name!r} has metadata that is not valid " + f"UTF-8 JSON: {extension_metadata!r}. Parse error: {exc}." + ) from exc + + if not isinstance(metadata_dict, dict): + raise ValueError( + f"Extension type {arrow_extension_name!r} metadata decoded to a non-object " + f"JSON value: {metadata_dict!r}." + ) + + if "category" not in metadata_dict: + raise ValueError( + f"Extension type {arrow_extension_name!r} metadata has no \"category\" key: " + f"{metadata_dict}." + ) + + category = metadata_dict["category"] + if not isinstance(category, str): + raise ValueError( + f"Extension type {arrow_extension_name!r} metadata \"category\" is not a " + f"string: {category!r}." + ) + + # Look up factory by category + factory = self._logical_type_registry._category_factories.get(category) + if factory is None: + raise ValueError( + f"No LogicalTypeFactory registered for category {category!r}. " + f"Cannot register extension type {arrow_extension_name!r}." + ) + + # Reconstruct and register + logical_type = factory.reconstruct_from_arrow( + arrow_extension_name, storage_type, metadata_dict, converter=self + ) + self._logical_type_registry.register_logical_type(logical_type) + return logical_type.get_arrow_extension_type() +``` + +- [ ] **Step 4: Run tests** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py -v -k "register_storage_type" 2>&1 | tail -30 +``` +Expected: All pass. + +- [ ] **Step 5: Run full converter test suite** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py -v 2>&1 | tail -10 +``` + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/semantic_types/universal_converter.py tests/test_semantic_types/test_universal_converter.py +git commit -m "feat(universal_converter): add register_storage_type with bottom-up recursive traversal" +``` + +--- + +## Task 5: Add `python_to_storage`, `storage_to_python`, and registration pass-throughs; update converter dispatch + +**Files:** +- Modify: `src/orcapod/semantic_types/universal_converter.py` +- Modify: `tests/test_semantic_types/test_universal_converter.py` + +- [ ] **Step 1: Write failing tests** + +```python +# ── python_to_storage / storage_to_python / pass-through tests ────────────── + +def test_python_to_storage_for_registered_type(): + """python_to_storage uses the logical type's converter for registered types.""" + converter = _make_converter() + import pathlib + result = converter.python_to_storage(pathlib.Path("/tmp/bar"), pathlib.Path) + assert result == "/tmp/bar" + + +def test_storage_to_python_for_registered_type(): + converter = _make_converter() + import pathlib + result = converter.storage_to_python("/tmp/bar", pathlib.Path) + assert isinstance(result, pathlib.Path) + assert result == pathlib.Path("/tmp/bar") + + +def test_python_to_storage_for_int(): + converter = _make_converter() + assert converter.python_to_storage(42, int) == 42 + + +def test_register_logical_type_passthrough(): + from orcapod.extension_types.builtin_logical_types import LogicalPath + registry = LogicalTypeRegistry() + converter = UniversalTypeConverter(logical_type_registry=registry) + lt = LogicalPath() + converter.register_logical_type(lt) + assert registry.get_by_python_type(import_pathlib_path()) is lt + + +def import_pathlib_path(): + import pathlib; return pathlib.Path + + +def test_register_logical_type_factory_passthrough(): + import uuid as _u + import polars as pl + + class _Factory: + def supports_class(self, t): return False + def create_for_python_type(self, t, converter): pass + def reconstruct_from_arrow(self, name, storage, meta, converter): pass + + registry = LogicalTypeRegistry() + converter = UniversalTypeConverter(logical_type_registry=registry) + factory = _Factory() + converter.register_logical_type_factory(factory, category="test.cat") + assert registry._category_factories.get("test.cat") is factory +``` + +- [ ] **Step 2: Run to confirm failures** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py -v -k "python_to_storage or storage_to_python or passthrough" 2>&1 | tail -20 +``` + +- [ ] **Step 3: Add methods to `UniversalTypeConverter`** + +```python +def python_to_storage(self, value: Any, annotation: Any) -> Any: + """Convert a Python value to its Arrow storage representation. + + Thin wrapper over ``get_python_to_arrow_converter`` for use by + ``DataclassLogicalType`` and other logical types that delegate per-field + conversion back to the converter. + + Args: + value: A Python object. + annotation: The Python type annotation for ``value``. + + Returns: + A value in Arrow storage format. + """ + converter_fn = self.get_python_to_arrow_converter(annotation) + return converter_fn(value) + +def storage_to_python(self, storage_value: Any, annotation: Any) -> Any: + """Convert an Arrow storage value back to a Python object. + + Args: + storage_value: A scalar or element from an Arrow storage array. + annotation: The Python type annotation to convert back to. + + Returns: + A Python object of the type described by ``annotation``. + """ + arrow_type = self.python_type_to_arrow_type(annotation) + converter_fn = self.get_arrow_to_python_converter(arrow_type) + return converter_fn(storage_value) + +def register_logical_type(self, lt: "LogicalTypeProtocol") -> None: + """Register a ``LogicalTypeProtocol`` instance. + + Pass-through to the internal ``LogicalTypeRegistry``. + + Args: + lt: The logical type to register. + """ + if self._logical_type_registry is None: + raise ValueError("No LogicalTypeRegistry configured on this converter.") + self._logical_type_registry.register_logical_type(lt) + +def register_logical_type_factory( + self, + factory: "LogicalTypeFactoryProtocol", + *, + category: "str | None" = None, + python_bases: "Iterable[type]" = (), +) -> None: + """Register a ``LogicalTypeFactoryProtocol`` instance. + + Pass-through to the internal ``LogicalTypeRegistry``. + + Args: + factory: The factory to register. + category: If given, registers factory as the read-side handler for + Arrow extension types with this ``"category"`` metadata value. + python_bases: Zero or more Python base classes to register as write-side + dispatch keys for this factory. + """ + if self._logical_type_registry is None: + raise ValueError("No LogicalTypeRegistry configured on this converter.") + self._logical_type_registry.register_logical_type_factory( + factory, category=category, python_bases=python_bases + ) +``` + +Also add `Iterable` to the imports in `universal_converter.py`: +```python +from collections.abc import Callable, Iterable, Mapping +``` + +And add TYPE_CHECKING imports: +```python +if TYPE_CHECKING: + import pyarrow as pa + from orcapod.extension_types.registry import LogicalTypeRegistry + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol, LogicalTypeProtocol +``` + +- [ ] **Step 4: Update `_create_python_to_arrow_converter` to pass `converter=self`** + +In `_create_python_to_arrow_converter`, find this block: +```python +if self._logical_type_registry is not None and isinstance(python_type, type): + lt = self._logical_type_registry.get_by_python_type(python_type) + if lt is not None: + return lt.python_to_storage +``` + +Replace with: +```python +if self._logical_type_registry is not None and isinstance(python_type, type): + lt = self._logical_type_registry.get_by_python_type(python_type) + if lt is not None: + _lt = lt + _self = self + return lambda value: _lt.python_to_storage(value, _self) +``` + +- [ ] **Step 5: Update `_create_arrow_to_python_converter` to pass `converter=self`** + +In `_create_arrow_to_python_converter`, find: +```python +if isinstance(arrow_type, pa.ExtensionType) and self._logical_type_registry is not None: + lt = self._logical_type_registry.get_by_arrow_extension_name( + arrow_type.extension_name + ) + if lt is not None: + return lt.storage_to_python +``` + +Replace with: +```python +if isinstance(arrow_type, pa.ExtensionType) and self._logical_type_registry is not None: + lt = self._logical_type_registry.get_by_arrow_extension_name( + arrow_type.extension_name + ) + if lt is not None: + _lt = lt + _self = self + return lambda storage_value: _lt.storage_to_python(storage_value, _self) +``` + +- [ ] **Step 6: Run tests** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py tests/test_extension_types/test_builtin_logical_types.py -v 2>&1 | tail -20 +``` +Expected: All pass. + +- [ ] **Step 7: Commit** + +```bash +git add src/orcapod/semantic_types/universal_converter.py tests/test_semantic_types/test_universal_converter.py +git commit -m "feat(universal_converter): add python_to_storage, storage_to_python, and registration pass-throughs; wire converter=self into logical type dispatch" +``` + +--- + +## Task 6: Simplify `ensure_types_registered_for_schemas` + remove `ensure_*` from registry + +**Files:** +- Modify: `src/orcapod/semantic_types/universal_converter.py` +- Modify: `src/orcapod/extension_types/registry.py` +- Modify: `tests/test_extension_types/test_registry.py` + +- [ ] **Step 1: Update `ensure_types_registered_for_schemas` in `UniversalTypeConverter`** + +Replace the existing method: + +```python +def ensure_types_registered_for_schemas(self, *schemas: Schema) -> None: + """Ensure a LogicalType is registered for every annotation in schemas. + + Calls ``register_python_class`` for each annotation, which recursively + resolves nested types and synthesises via factory if needed. + When no ``LogicalTypeRegistry`` is configured, this is a no-op. + + Args: + *schemas: One or more ``Schema`` mappings (column name → Python type). + + Raises: + TypeError: If a leaf class has no registered ``LogicalType`` and + no registered factory covers it. + """ + if self._logical_type_registry is None: + return + for schema in schemas: + for annotation in schema.values(): + self.register_python_class(annotation) +``` + +- [ ] **Step 2: Run existing ensure_types tests to verify nothing breaks** + +```bash +uv run pytest tests/ -v -k "ensure_types" 2>&1 | tail -20 +``` +Expected: Pass. + +- [ ] **Step 3: Find and update registry tests that test `ensure_*` methods** + +Check which tests in `test_registry.py` test `ensure_logical_type_for_python_class` and `ensure_extension_type`: + +```bash +grep -n "ensure_logical_type\|ensure_extension_type" tests/test_extension_types/test_registry.py +``` + +- [ ] **Step 4: Remove `ensure_*` tests from `test_registry.py` and add converter pass-through tests** + +Remove any test functions that directly test `ensure_logical_type_for_python_class` or `ensure_extension_type` on the registry (they are removed from the public API). + +Add this test to `test_registry.py`: + +```python +def test_registry_does_not_expose_ensure_methods(): + """ensure_logical_type_for_python_class and ensure_extension_type are removed.""" + registry = LogicalTypeRegistry() + assert not hasattr(registry, "ensure_logical_type_for_python_class") + assert not hasattr(registry, "ensure_extension_type") +``` + +- [ ] **Step 5: Remove `ensure_logical_type_for_python_class` and `ensure_extension_type` from `registry.py`** + +In `src/orcapod/extension_types/registry.py`, delete the `ensure_extension_type` method (lines ~355-467) and the `ensure_logical_type_for_python_class` method (lines ~469-577). + +The public surface retained: `register_logical_type`, `register_logical_type_factory`, `get_by_python_type`, `get_by_arrow_extension_name`, `get_by_logical_name`. + +- [ ] **Step 6: Run tests** + +```bash +uv run pytest tests/test_extension_types/test_registry.py tests/test_semantic_types/ -v 2>&1 | tail -20 +``` +Expected: All pass (ensure_* tests replaced). + +- [ ] **Step 7: Commit** + +```bash +git add src/orcapod/semantic_types/universal_converter.py src/orcapod/extension_types/registry.py tests/test_extension_types/test_registry.py +git commit -m "refactor(registry): remove ensure_* methods; simplify ensure_types_registered_for_schemas to use register_python_class" +``` + +--- + +## Task 7: Create `DataclassLogicalType` in `extension_types/dataclass_handler.py` + +**Files:** +- Create: `src/orcapod/extension_types/dataclass_handler.py` +- Create: `tests/test_extension_types/test_dataclass_handler.py` + +- [ ] **Step 1: Write failing tests for `DataclassLogicalType`** + +Create `tests/test_extension_types/test_dataclass_handler.py`: + +```python +"""Tests for DataclassLogicalType and DataclassHandlerFactory.""" + +from __future__ import annotations + +import dataclasses +import uuid as _uuid_module +from typing import Any + +import pyarrow as pa +import pytest + + +# ── Helpers ───────────────────────────────────────────────────────────────── + +class _StubConverter: + """Minimal converter stub for DataclassLogicalType tests.""" + + def python_to_storage(self, value, annotation): + if annotation is str: + return str(value) + if annotation is int: + return int(value) + return value + + def storage_to_python(self, storage_value, annotation): + if annotation is str: + return str(storage_value) + if annotation is int: + return int(storage_value) + return storage_value + + def register_python_class(self, annotation): + if annotation is str: + return pa.large_string() + if annotation is int: + return pa.int64() + raise ValueError(f"No mapping for {annotation}") + + +# ── DataclassLogicalType tests ─────────────────────────────────────────────── + +def test_dataclass_logical_type_is_importable(): + from orcapod.extension_types.dataclass_handler import DataclassLogicalType + assert DataclassLogicalType is not None + + +def test_dataclass_logical_type_protocol_conformance(): + from orcapod.extension_types.dataclass_handler import DataclassLogicalType + from orcapod.extension_types.protocols import LogicalTypeProtocol + + @dataclasses.dataclass + class _MyDC: + name: str + count: int + + storage = pa.struct([pa.field("name", pa.large_string()), pa.field("count", pa.int64())]) + field_annotations = [("name", str), ("count", int)] + lt = DataclassLogicalType( + logical_name="tests.MyDC", + python_type=_MyDC, + storage_type=storage, + field_annotations=field_annotations, + ) + assert isinstance(lt, LogicalTypeProtocol) + + +def test_dataclass_logical_type_python_to_storage(): + from orcapod.extension_types.dataclass_handler import DataclassLogicalType + + @dataclasses.dataclass + class _Point: + x: int + y: int + + storage = pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]) + lt = DataclassLogicalType("tests.Point", _Point, storage, [("x", int), ("y", int)]) + converter = _StubConverter() + + result = lt.python_to_storage(_Point(x=3, y=7), converter) + assert result == {"x": 3, "y": 7} + + +def test_dataclass_logical_type_storage_to_python(): + from orcapod.extension_types.dataclass_handler import DataclassLogicalType + + @dataclasses.dataclass + class _Point: + x: int + y: int + + storage = pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]) + lt = DataclassLogicalType("tests.Point", _Point, storage, [("x", int), ("y", int)]) + converter = _StubConverter() + + result = lt.storage_to_python({"x": 3, "y": 7}, converter) + assert isinstance(result, _Point) + assert result.x == 3 + assert result.y == 7 + + +def test_dataclass_logical_type_logical_type_name(): + from orcapod.extension_types.dataclass_handler import DataclassLogicalType + + @dataclasses.dataclass + class _Foo: + val: str + + storage = pa.struct([pa.field("val", pa.large_string())]) + lt = DataclassLogicalType("mymod.Foo", _Foo, storage, [("val", str)]) + assert lt.logical_type_name == "mymod.Foo" + + +def test_dataclass_logical_type_python_type(): + from orcapod.extension_types.dataclass_handler import DataclassLogicalType + + @dataclasses.dataclass + class _Bar: + val: str + + storage = pa.struct([pa.field("val", pa.large_string())]) + lt = DataclassLogicalType("mymod.Bar", _Bar, storage, [("val", str)]) + assert lt.python_type is _Bar +``` + +- [ ] **Step 2: Run to confirm failures** + +```bash +uv run pytest tests/test_extension_types/test_dataclass_handler.py -v 2>&1 | tail -20 +``` +Expected: ImportError — `dataclass_handler` does not exist. + +- [ ] **Step 3: Create `src/orcapod/extension_types/dataclass_handler.py`** + +```python +"""DataclassLogicalType and DataclassHandlerFactory. + +Provides the ``DataclassLogicalType`` logical type implementation and the +``DataclassHandlerFactory`` that synthesises and reconstructs ``DataclassLogicalType`` +instances for Python dataclasses. + +Write path (``create_for_python_type``): + Iterates dataclass fields, delegates field Arrow-type resolution to the converter + via ``register_python_class``, and returns a ``DataclassLogicalType`` backed by + a ``pa.struct`` extension type. + +Read path (``reconstruct_from_arrow``): + Imports the dataclass by fully-qualified class name, resolves field annotations + against the (already bottom-up resolved) storage type, and returns a + ``DataclassLogicalType``. + +Category tag: ``"orcapod.dataclass"`` +""" + +from __future__ import annotations + +import dataclasses +import importlib +import json +import logging +from typing import TYPE_CHECKING, Any + +from orcapod.extension_types.registry import make_arrow_extension_type, make_polars_extension_type +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import polars as pl + import pyarrow as pa + from orcapod.extension_types.protocols import TypeConverterProtocol +else: + pa = LazyModule("pyarrow") + pl = LazyModule("polars") + +logger = logging.getLogger(__name__) + +#: Category tag embedded in Arrow extension metadata. Used as the factory dispatch key. +DATACLASS_CATEGORY = "orcapod.dataclass" + + +class DataclassLogicalType: + """Logical type binding a Python dataclass to its Arrow extension type representation. + + Stores the dataclass's fully-qualified class name as the Arrow extension name + and a ``pa.struct`` of the dataclass fields as the storage type. + + No Arrow-type reasoning lives here — all field-type resolution is owned by the + converter and completed before this object is constructed. + + Args: + logical_name: Fully-qualified class name (e.g. ``"mymodule.sub.MyData"``). + Used as both the logical type name and the Arrow extension name. + python_type: The Python dataclass ``type`` object. + storage_type: The Arrow ``pa.StructType`` for the dataclass fields. + field_annotations: Ordered list of ``(field_name, python_annotation)`` pairs + matching the fields in ``storage_type``. + + Example: + >>> lt = DataclassLogicalType( + ... "mymod.Point", Point, + ... pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]), + ... [("x", int), ("y", int)], + ... ) + >>> lt.python_to_storage(Point(1, 2), converter) + {"x": 1, "y": 2} + """ + + def __init__( + self, + logical_name: str, + python_type: type, + storage_type: "pa.StructType", + field_annotations: list[tuple[str, Any]], + ) -> None: + self._logical_name = logical_name + self._python_type = python_type + self._storage_type = storage_type + self._field_annotations = field_annotations + + _metadata = json.dumps({"category": DATACLASS_CATEGORY}).encode("utf-8") + self._arrow_ext_class = make_arrow_extension_type( + logical_name, storage_type, metadata=_metadata + ) + self._arrow_ext: "pa.ExtensionType | None" = None + self._polars_ext_class = make_polars_extension_type(logical_name, storage_type) + self._polars_ext: "pl.BaseExtension | None" = None + + @property + def logical_type_name(self) -> str: + """Fully-qualified class name used as the logical type identifier.""" + return self._logical_name + + @property + def python_type(self) -> type: + """The Python dataclass type this logical type represents.""" + return self._python_type + + def get_arrow_extension_type(self) -> "pa.ExtensionType": + """Return the Arrow extension type for this dataclass. + + Returns: + A cached ``pa.ExtensionType`` instance with ``extension_name`` equal to + the fully-qualified class name and ``storage_type`` equal to the struct + of the dataclass fields. + """ + if self._arrow_ext is None: + self._arrow_ext = self._arrow_ext_class() + return self._arrow_ext + + def get_polars_extension_type(self) -> "pl.BaseExtension": + """Return the Polars extension type for this dataclass. + + Returns: + A cached ``pl.BaseExtension`` instance. + """ + if self._polars_ext is None: + self._polars_ext = self._polars_ext_class() + return self._polars_ext + + def python_to_storage(self, value: Any, converter: "TypeConverterProtocol") -> dict[str, Any]: + """Convert a dataclass instance to an Arrow-compatible struct dict. + + Iterates ``_field_annotations`` and delegates each field's conversion to + ``converter.python_to_storage``. + + Args: + value: A dataclass instance of type ``python_type``. + converter: The active converter for per-field delegation. + + Returns: + A dict mapping field names to their Arrow storage values. + """ + return { + name: converter.python_to_storage(getattr(value, name), annotation) + for name, annotation in self._field_annotations + } + + def storage_to_python(self, storage_value: Any, converter: "TypeConverterProtocol") -> Any: + """Reconstruct a dataclass instance from an Arrow struct dict. + + Args: + storage_value: A dict mapping field names to Arrow storage values. + converter: The active converter for per-field delegation. + + Returns: + A dataclass instance of type ``python_type``. + """ + kwargs = { + name: converter.storage_to_python(storage_value[name], annotation) + for name, annotation in self._field_annotations + } + return self._python_type(**kwargs) + + +class DataclassHandlerFactory: + """Stateless factory that synthesises and reconstructs ``DataclassLogicalType`` instances. + + **Write path** (``create_for_python_type``): derives Arrow struct type from the + dataclass fields by delegating to ``converter.register_python_class`` per field. + + **Read path** (``reconstruct_from_arrow``): imports the dataclass by FQCN, matches + fields against the already-resolved ``storage_type``, and returns a + ``DataclassLogicalType``. + + Category tag: ``"orcapod.dataclass"`` + Register with:: + + converter.register_logical_type_factory( + DataclassHandlerFactory(), + category="orcapod.dataclass", + python_bases=[object], + ) + + Example: + >>> factory = DataclassHandlerFactory() + >>> factory.supports_class(MyDataclass) + True + >>> factory.supports_class(str) + False + """ + + def supports_class(self, python_type: type) -> bool: + """Return True if ``python_type`` is a dataclass. + + Args: + python_type: Any Python type. + + Returns: + True if ``dataclasses.is_dataclass(python_type)`` is True. + """ + return dataclasses.is_dataclass(python_type) and isinstance(python_type, type) + + def create_for_python_type( + self, + python_type: type, + converter: "TypeConverterProtocol", + ) -> DataclassLogicalType: + """Synthesise a ``DataclassLogicalType`` for a Python dataclass (write path). + + Derives the FQCN, obtains type hints, and resolves each field's Arrow type + via ``converter.register_python_class``. Rejects local / unnamed classes. + + Args: + python_type: A Python dataclass type. + converter: The active converter for field-type resolution. + + Returns: + A ``DataclassLogicalType`` ready for registration. + + Raises: + ValueError: If ``python_type`` is a local class (no stable FQCN) or + has a ``__qualname__`` that contains ``""``. + """ + import typing + + fqcn = f"{python_type.__module__}.{python_type.__qualname__}" + if "" in fqcn or not python_type.__module__ or python_type.__module__ == "__main__": + pass # allow __main__ classes but reject proper locals + if "" in fqcn: + raise ValueError( + f"Cannot register local class {python_type!r} as a DataclassLogicalType — " + f"local classes have no stable fully-qualified class name and cannot be " + f"reconstructed on read. Define the dataclass at module level." + ) + + try: + hints = typing.get_type_hints(python_type) + except Exception as exc: + raise ValueError( + f"Cannot get type hints for {python_type!r}: {exc}" + ) from exc + + arrow_fields = [] + field_annotations = [] + for field in dataclasses.fields(python_type): + if not field.init: + continue + annotation = hints.get(field.name, Any) + arrow_type = converter.register_python_class(annotation) + arrow_fields.append(pa.field(field.name, arrow_type)) + field_annotations.append((field.name, annotation)) + + storage_type = pa.struct(arrow_fields) + logger.debug("DataclassHandlerFactory: synthesised %r for %r", fqcn, python_type) + return DataclassLogicalType(fqcn, python_type, storage_type, field_annotations) + + def reconstruct_from_arrow( + self, + arrow_extension_name: str, + storage_type: "pa.DataType", + metadata: dict[str, Any], + converter: "TypeConverterProtocol", + ) -> DataclassLogicalType: + """Reconstruct a ``DataclassLogicalType`` from Arrow schema metadata (read path). + + Imports the dataclass from its FQCN (``arrow_extension_name``), then matches + the dataclass field annotations against the fields in ``storage_type``. + ``storage_type`` is already bottom-up resolved by ``register_storage_type`` + before this method is called. + + Args: + arrow_extension_name: FQCN of the dataclass (Arrow extension name). + storage_type: Already-resolved ``pa.StructType`` for the dataclass fields. + metadata: Full parsed metadata JSON dict (always contains ``"category"``). + converter: The active converter (not needed here but required by protocol). + + Returns: + A ``DataclassLogicalType`` ready for registration. + + Raises: + ImportError: If the class cannot be imported from ``arrow_extension_name``. + ValueError: If ``storage_type`` is not a struct type. + """ + import typing + + if not pa.types.is_struct(storage_type): + raise ValueError( + f"DataclassHandlerFactory.reconstruct_from_arrow: expected a struct " + f"storage type for {arrow_extension_name!r}, got {storage_type!r}." + ) + + # Import class from FQCN using longest-prefix module walk + cls = _import_from_fqcn(arrow_extension_name) + + try: + hints = typing.get_type_hints(cls) + except Exception as exc: + raise ValueError( + f"Cannot get type hints for {cls!r}: {exc}" + ) from exc + + field_annotations = [] + for field in dataclasses.fields(cls): + if not field.init: + continue + annotation = hints.get(field.name, Any) + field_annotations.append((field.name, annotation)) + + logger.debug( + "DataclassHandlerFactory: reconstructed %r from Arrow", arrow_extension_name + ) + return DataclassLogicalType( + arrow_extension_name, cls, storage_type, field_annotations + ) + + +def _import_from_fqcn(fqcn: str) -> type: + """Import a class from its fully-qualified class name. + + Tries module prefixes from longest to shortest. For example, for + ``"mypackage.sub.MyClass"``, tries ``importlib.import_module("mypackage.sub")`` + then ``getattr(module, "MyClass")``. + + Args: + fqcn: Fully-qualified class name, e.g. ``"mypackage.sub.MyClass"``. + + Returns: + The imported class. + + Raises: + ImportError: If no valid module+attribute split can be found. + """ + parts = fqcn.rsplit(".", 1) + if len(parts) != 2: + raise ImportError(f"Cannot import from FQCN {fqcn!r}: no module separator found.") + + module_path, class_name = parts + try: + module = importlib.import_module(module_path) + cls = getattr(module, class_name) + if not dataclasses.is_dataclass(cls) or not isinstance(cls, type): + raise ImportError( + f"{class_name!r} in {module_path!r} is not a dataclass type." + ) + return cls + except (ImportError, AttributeError, ModuleNotFoundError) as exc: + raise ImportError( + f"Cannot import dataclass from FQCN {fqcn!r}: {exc}" + ) from exc +``` + +- [ ] **Step 4: Run dataclass logical type tests** + +```bash +uv run pytest tests/test_extension_types/test_dataclass_handler.py -v -k "DataclassLogicalType or logical_type" 2>&1 | tail -30 +``` +Expected: All DataclassLogicalType tests pass. + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/extension_types/dataclass_handler.py tests/test_extension_types/test_dataclass_handler.py +git commit -m "feat(dataclass_handler): implement DataclassLogicalType" +``` + +--- + +## Task 8: `DataclassHandlerFactory` write path tests + verification + +**Files:** +- Modify: `tests/test_extension_types/test_dataclass_handler.py` +- Modify: `src/orcapod/extension_types/dataclass_handler.py` (fixes only) + +- [ ] **Step 1: Add factory write-path tests** + +```python +# Add to tests/test_extension_types/test_dataclass_handler.py + +def _make_full_converter(): + """Make a UniversalTypeConverter with builtin types + DataclassHandlerFactory.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath, LogicalUUID, LogicalUPath + from orcapod.extension_types.registry import LogicalTypeRegistry + from orcapod.extension_types.dataclass_handler import DataclassHandlerFactory, DATACLASS_CATEGORY + from orcapod.semantic_types.universal_converter import UniversalTypeConverter + + registry = LogicalTypeRegistry(logical_types=[LogicalPath(), LogicalUUID(), LogicalUPath()]) + factory = DataclassHandlerFactory() + registry.register_logical_type_factory(factory, category=DATACLASS_CATEGORY, python_bases=[object]) + return UniversalTypeConverter(logical_type_registry=registry) + + +def test_factory_supports_class_dataclass(): + from orcapod.extension_types.dataclass_handler import DataclassHandlerFactory + + @dataclasses.dataclass + class _Dummy: + x: int + + factory = DataclassHandlerFactory() + assert factory.supports_class(_Dummy) is True + + +def test_factory_supports_class_non_dataclass(): + from orcapod.extension_types.dataclass_handler import DataclassHandlerFactory + + factory = DataclassHandlerFactory() + assert factory.supports_class(str) is False + assert factory.supports_class(int) is False + + +def test_factory_create_flat_dataclass(): + from orcapod.extension_types.dataclass_handler import DataclassHandlerFactory, DataclassLogicalType + + @dataclasses.dataclass + class _Flat: + name: str + count: int + + factory = DataclassHandlerFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_Flat, converter=converter) + + assert isinstance(lt, DataclassLogicalType) + storage = lt.get_arrow_extension_type().storage_type + assert pa.types.is_struct(storage) + assert storage.field("name").type == pa.large_string() + assert storage.field("count").type == pa.int64() + + +def test_factory_create_dataclass_with_uuid_field(): + """UUID field → orcapod.uuid extension type in storage struct.""" + from orcapod.extension_types.dataclass_handler import DataclassHandlerFactory + + @dataclasses.dataclass + class _WithUUID: + id: _uuid_module.UUID + label: str + + factory = DataclassHandlerFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_WithUUID, converter=converter) + + storage = lt.get_arrow_extension_type().storage_type + id_field_type = storage.field("id").type + assert isinstance(id_field_type, pa.ExtensionType) + assert id_field_type.extension_name == "orcapod.uuid" + + +def test_factory_create_dataclass_with_list_field(): + from orcapod.extension_types.dataclass_handler import DataclassHandlerFactory + + @dataclasses.dataclass + class _WithList: + tags: list[str] + count: int + + factory = DataclassHandlerFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_WithList, converter=converter) + + storage = lt.get_arrow_extension_type().storage_type + assert pa.types.is_large_list(storage.field("tags").type) + assert storage.field("tags").type.value_type == pa.large_string() + + +def test_factory_create_dataclass_with_dict_field(): + from orcapod.extension_types.dataclass_handler import DataclassHandlerFactory + + @dataclasses.dataclass + class _WithDict: + meta: dict[str, int] + + factory = DataclassHandlerFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_WithDict, converter=converter) + + storage = lt.get_arrow_extension_type().storage_type + meta_type = storage.field("meta").type + assert pa.types.is_large_list(meta_type) + assert pa.types.is_struct(meta_type.value_type) + field_names = {meta_type.value_type.field(i).name for i in range(meta_type.value_type.num_fields)} + assert field_names == {"key", "value"} + + +def test_factory_rejects_local_class(): + from orcapod.extension_types.dataclass_handler import DataclassHandlerFactory + + def _make_local(): + @dataclasses.dataclass + class _Local: + x: int + return _Local + + LocalClass = _make_local() + factory = DataclassHandlerFactory() + converter = _make_full_converter() + with pytest.raises(ValueError, match="local"): + factory.create_for_python_type(LocalClass, converter=converter) + + +def test_register_python_class_dispatches_to_dataclass_factory(): + """register_python_class on a dataclass triggers DataclassHandlerFactory.""" + converter = _make_full_converter() + + @dataclasses.dataclass + class _MyPoint: + x: int + y: int + + # This is a local class — use a module-level one via register_python_class + # For this test, simulate by directly pre-importing: + # We can't use a local class here due to the FQCN check. + # So we test with the UUID field only as a proxy. + result = converter.register_python_class(_uuid_module.UUID) + assert isinstance(result, pa.ExtensionType) + assert result.extension_name == "orcapod.uuid" +``` + +- [ ] **Step 2: Run factory write-path tests** + +```bash +uv run pytest tests/test_extension_types/test_dataclass_handler.py -v 2>&1 | tail -30 +``` +Expected: All pass. + +- [ ] **Step 3: Commit** + +```bash +git add tests/test_extension_types/test_dataclass_handler.py src/orcapod/extension_types/dataclass_handler.py +git commit -m "test(dataclass_handler): add DataclassHandlerFactory write-path tests" +``` + +--- + +## Task 9: `DataclassHandlerFactory` read path + Arrow round-trip + +**Files:** +- Modify: `tests/test_extension_types/test_dataclass_handler.py` + +- [ ] **Step 1: Add read-path and round-trip tests** + +```python +# Add to tests/test_extension_types/test_dataclass_handler.py + +# ── Module-level dataclass for round-trip tests ────────────────────────────── + +@dataclasses.dataclass +class _RoundTripPoint: + """Module-level dataclass for round-trip testing.""" + x: int + y: int + + +@dataclasses.dataclass +class _RoundTripRecord: + """Module-level dataclass with a UUID field.""" + record_id: _uuid_module.UUID + label: str + + +# ── Read-path tests ─────────────────────────────────────────────────────────── + +def test_factory_reconstruct_from_arrow(): + """reconstruct_from_arrow rebuilds the logical type from the Arrow struct.""" + from orcapod.extension_types.dataclass_handler import DataclassHandlerFactory, DataclassLogicalType + + storage = pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]) + metadata = {"category": "orcapod.dataclass"} + fqcn = f"{_RoundTripPoint.__module__}.{_RoundTripPoint.__qualname__}" + + factory = DataclassHandlerFactory() + converter = _make_full_converter() + lt = factory.reconstruct_from_arrow(fqcn, storage, metadata, converter=converter) + + assert isinstance(lt, DataclassLogicalType) + assert lt.python_type is _RoundTripPoint + assert lt.logical_type_name == fqcn + + +def test_factory_reconstruct_from_arrow_invalid_fqcn(): + """ImportError if the FQCN cannot be resolved.""" + from orcapod.extension_types.dataclass_handler import DataclassHandlerFactory + + storage = pa.struct([pa.field("x", pa.int64())]) + factory = DataclassHandlerFactory() + converter = _make_full_converter() + + with pytest.raises(ImportError): + factory.reconstruct_from_arrow( + "nonexistent.module.NoSuchClass", storage, {"category": "orcapod.dataclass"}, converter + ) + + +def test_dataclass_python_to_storage_round_trip(): + """python_to_storage → storage_to_python returns an equivalent dataclass.""" + converter = _make_full_converter() + + # Register _RoundTripPoint via register_python_class + # It's module-level so FQCN is stable + from orcapod.extension_types.dataclass_handler import DataclassHandlerFactory, DATACLASS_CATEGORY + factory = DataclassHandlerFactory() + lt = factory.create_for_python_type(_RoundTripPoint, converter=converter) + converter.register_logical_type(lt) + + point = _RoundTripPoint(x=10, y=20) + storage_value = lt.python_to_storage(point, converter) + assert storage_value == {"x": 10, "y": 20} + + reconstructed = lt.storage_to_python(storage_value, converter) + assert isinstance(reconstructed, _RoundTripPoint) + assert reconstructed.x == 10 + assert reconstructed.y == 20 + + +def test_dataclass_with_uuid_round_trip(): + """Round-trip a dataclass with a UUID field through python_to_storage / storage_to_python.""" + from orcapod.extension_types.dataclass_handler import DataclassHandlerFactory + + converter = _make_full_converter() + factory = DataclassHandlerFactory() + lt = factory.create_for_python_type(_RoundTripRecord, converter=converter) + converter.register_logical_type(lt) + + u = _uuid_module.UUID("12345678-1234-5678-1234-567812345678") + record = _RoundTripRecord(record_id=u, label="hello") + + storage_value = lt.python_to_storage(record, converter) + assert storage_value["label"] == "hello" + # UUID stored as bytes + assert storage_value["record_id"] == u.bytes + + reconstructed = lt.storage_to_python(storage_value, converter) + assert isinstance(reconstructed, _RoundTripRecord) + assert reconstructed.record_id == u + assert reconstructed.label == "hello" +``` + +- [ ] **Step 2: Run read-path and round-trip tests** + +```bash +uv run pytest tests/test_extension_types/test_dataclass_handler.py -v 2>&1 | tail -30 +``` +Expected: All pass. + +- [ ] **Step 3: Commit** + +```bash +git add tests/test_extension_types/test_dataclass_handler.py +git commit -m "test(dataclass_handler): add DataclassHandlerFactory read-path and Arrow round-trip tests" +``` + +--- + +## Task 10: DataContext cleanup + context wiring + +**Files:** +- Modify: `src/orcapod/contexts/core.py` +- Modify: `src/orcapod/contexts/__init__.py` +- Modify: `src/orcapod/contexts/registry.py` +- Modify: `src/orcapod/contexts/data/v0.1.json` +- Modify: `src/orcapod/contexts/data/schemas/context_schema.json` +- Modify: `src/orcapod/extension_types/__init__.py` +- Modify: `tests/test_core/function_pod/test_write_side_registration.py` + +- [ ] **Step 1: Remove `logical_type_registry` from `DataContext`** + +In `src/orcapod/contexts/core.py`, remove the `logical_type_registry` field: + +```python +"""Core data structures and exceptions for the OrcaPod context system.""" + +from dataclasses import dataclass + +from orcapod.hashing.semantic_hashing.type_handler_registry import TypeHandlerRegistry +from orcapod.protocols.hashing_protocols import ( + ArrowHasherProtocol, + SemanticHasherProtocol, +) +from orcapod.protocols.semantic_types_protocols import TypeConverterProtocol + + +@dataclass +class DataContext: + """Data context containing all versioned components needed for data interpretation. + + Attributes: + context_key: Unique identifier (e.g., "std:v0.1:default") + version: Version string (e.g., "v0.1") + description: Human-readable description + type_converter: Type converter for Python ↔ Arrow conversion and + registration. This is the single public API for all type operations. + arrow_hasher: Arrow table hasher for this context + semantic_hasher: General semantic hasher for this context + type_handler_registry: Registry of TypeHandlerProtocol instances + """ + + context_key: str + version: str + description: str + type_converter: TypeConverterProtocol + arrow_hasher: ArrowHasherProtocol + semantic_hasher: SemanticHasherProtocol + type_handler_registry: TypeHandlerRegistry + + +class ContextValidationError(Exception): + """Raised when context validation fails.""" + pass + + +class ContextResolutionError(Exception): + """Raised when context cannot be resolved.""" + pass +``` + +- [ ] **Step 2: Remove `get_default_logical_type_registry` from `contexts/__init__.py`** + +In `src/orcapod/contexts/__init__.py`: +1. Remove the `from orcapod.extension_types.registry import LogicalTypeRegistry` import +2. Delete the `get_default_logical_type_registry` function +3. Remove `get_default_logical_type_registry` from `__all__` + +- [ ] **Step 3: Update `contexts/registry.py`** + +In `_create_context_from_spec`, remove `logical_type_registry=ref_lut["logical_type_registry"]` from `DataContext(...)` constructor call. Also remove `"logical_type_registry"` from the `required_fields` list: + +```python +required_fields = [ + "context_key", + "version", + "semantic_registry", + "type_converter", + "arrow_hasher", + "semantic_hasher", + "type_handler_registry", + # "logical_type_registry" — removed; registry is internal to type_converter +] +``` + +And update `DataContext(...)` construction: +```python +return DataContext( + context_key=context_key, + version=version, + description=description, + type_converter=ref_lut["type_converter"], + arrow_hasher=ref_lut["arrow_hasher"], + semantic_hasher=ref_lut["semantic_hasher"], + type_handler_registry=ref_lut["type_handler_registry"], + # logical_type_registry removed +) +``` + +- [ ] **Step 4: Update `contexts/data/v0.1.json`** + +Move `logical_type_registry` construction inside `type_converter._config`. Remove `semantic_registry` ref from `type_converter._config`: + +```json +"type_converter": { + "_class": "orcapod.semantic_types.universal_converter.UniversalTypeConverter", + "_config": { + "logical_type_registry": { + "_class": "orcapod.extension_types.registry.LogicalTypeRegistry", + "_config": { + "logical_types": [ + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalPath", + "_config": {} + }, + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalUPath", + "_config": {} + }, + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalUUID", + "_config": {} + } + ] + } + } + } +}, +``` + +Also remove the top-level `"logical_type_registry"` key from the JSON file entirely. + +Keep `semantic_registry` at the top level (used by `arrow_hasher`). It's no longer passed to `type_converter`. + +- [ ] **Step 5: Update `contexts/data/schemas/context_schema.json`** + +Remove `"logical_type_registry"` from `"required"` array and from `"properties"`. + +- [ ] **Step 6: Update `extension_types/__init__.py`** docstring to remove the `DataContext.logical_type_registry` access path reference. + +- [ ] **Step 7: Update `test_write_side_registration.py`** + +Update `_make_test_context` to not pass `logical_type_registry`: + +```python +def _make_test_context(registry: LogicalTypeRegistry) -> DataContext: + """Create a DataContext with a fresh converter bound to the given registry.""" + base_ctx = get_default_context() + fresh_converter = UniversalTypeConverter( + logical_type_registry=registry, + ) + return DataContext( + context_key="test", + version="test", + description="test", + type_converter=fresh_converter, + arrow_hasher=base_ctx.arrow_hasher, + semantic_hasher=base_ctx.semantic_hasher, + type_handler_registry=base_ctx.type_handler_registry, + # logical_type_registry removed from DataContext + ) +``` + +Also update the factory stub to use new protocol signatures: +```python +class _Factory: + def supports_class(self, python_type): # new method + return True + def reconstruct_from_arrow(self, name, storage, meta, converter): + return _make_logical_type(object) + def create_for_python_type(self, python_type, converter): # converter param + call_log.append(python_type) + return _make_logical_type(python_type) +``` + +And update `_make_logical_type` builtin logical type stubs to accept converter param: +```python +class _LT: + ... + def python_to_storage(self, v, converter=None): return str(v) + def storage_to_python(self, v, converter=None): return v +``` + +- [ ] **Step 8: Run tests related to contexts and write-side registration** + +```bash +uv run pytest tests/test_core/function_pod/test_write_side_registration.py -v 2>&1 | tail -30 +``` + +```bash +uv run pytest -v -k "context" 2>&1 | tail -20 +``` + +- [ ] **Step 9: Commit** + +```bash +git add src/orcapod/contexts/ tests/test_core/function_pod/test_write_side_registration.py +git commit -m "refactor(contexts): remove logical_type_registry from DataContext; move registry construction inside type_converter config" +``` + +--- + +## Task 11: Update `database_hooks.py` and `ExtensionAwareDatabase` + +**Files:** +- Modify: `src/orcapod/extension_types/database_hooks.py` +- Modify: `src/orcapod/databases/extension_aware_database.py` +- Modify: `tests/test_extension_types/test_database_hooks.py` + +- [ ] **Step 1: Update `register_discovered_extensions` in `database_hooks.py`** + +```python +"""Schema-walking utilities for extension type auto-registration and post-load casting.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from orcapod.extension_types.schema_walker import walk_schema + +if TYPE_CHECKING: + import pyarrow as pa + from orcapod.semantic_types.universal_converter import UniversalTypeConverter + +logger = logging.getLogger(__name__) + + +def register_discovered_extensions( + converter: "UniversalTypeConverter | None", + schema: "pa.Schema", +) -> None: + """Register any extension types found in ``schema`` that are not yet known. + + Walks ``schema`` recursively via ``walk_schema`` to discover all Arrow extension + types at any nesting depth (both in-memory and field-metadata channels). + For each discovered type, delegates to ``converter._ensure_extension_type_info``. + + Args: + converter: The ``UniversalTypeConverter`` to use for registration. + If ``None``, this call is a no-op. + schema: The Arrow schema to inspect. + + Raises: + ValueError: Propagated from the converter if an extension type's metadata + has no registered factory or is malformed. + """ + if converter is None: + logger.debug("register_discovered_extensions: no converter provided, skipping") + return + + found = walk_schema(schema) + if not found: + logger.debug("register_discovered_extensions: no extension types in schema") + return + + logger.debug( + "register_discovered_extensions: found %d extension type(s): %s", + len(found), + [info.extension_name for info in found], + ) + for info in found: + # Bottom-up resolve the storage type first, then register the extension + resolved_storage = converter.register_storage_type(info.storage_type) + converter._ensure_extension_type_info( + info.extension_name, + info.extension_metadata, + resolved_storage, + ) + + +def apply_extension_types( + table: "pa.Table", + registry: "LogicalTypeRegistry", # keep registry param for now +) -> "pa.Table": + # (body unchanged — kept exactly as before) + ... +``` + +Keep the `apply_extension_types` and its helpers (`_apply_field`, etc.) exactly as they are — only `register_discovered_extensions` changes. + +Add the old `apply_extension_types` import back: +```python +from orcapod.extension_types.registry import LogicalTypeRegistry +``` + +- [ ] **Step 2: Update `ExtensionAwareDatabase`** + +```python +"""ExtensionAwareDatabase — wrapper that handles extension type registration.""" +from __future__ import annotations + +from collections.abc import Collection, Mapping +from typing import TYPE_CHECKING, Any + +from orcapod.extension_types.database_hooks import ( + apply_extension_types, + register_discovered_extensions, +) +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol + +if TYPE_CHECKING: + import pyarrow as pa + from orcapod.semantic_types.universal_converter import UniversalTypeConverter + + +class ExtensionAwareDatabase: + """``ArrowDatabaseProtocol`` wrapper that auto-registers and applies extension types. + + Args: + db: Any ``ArrowDatabaseProtocol`` backend. + converter: The ``UniversalTypeConverter`` to use for extension type + registration and lookup. Callers typically supply + ``data_context.type_converter``. + """ + + def __init__( + self, + db: ArrowDatabaseProtocol, + converter: "UniversalTypeConverter", + ) -> None: + self._db = db + self._converter = converter + + def _process(self, table: "pa.Table | None") -> "pa.Table | None": + """Register extension types and re-wrap columns, or return None unchanged.""" + if table is None: + return None + register_discovered_extensions(self._converter, table.schema) + # apply_extension_types still needs the registry for column re-wrapping + registry = self._converter._logical_type_registry + if registry is not None: + return apply_extension_types(table, registry) + return table + + # All read/write methods delegate exactly as before, replacing self._registry + # usage with self._converter where needed in `at()`: + + def at(self, *path_components: str) -> "ExtensionAwareDatabase": + """Return a scoped view, preserving the extension-aware wrapper.""" + return ExtensionAwareDatabase( + self._db.at(*path_components), + converter=self._converter, + ) + + # ... (all other read/write methods are unchanged from before, just + # delegating self._process(self._db.method(...))) +``` + +Keep all the `get_record_by_id`, `get_all_records`, `add_record`, etc. methods unchanged except that `at()` now passes `converter=self._converter`. + +- [ ] **Step 3: Update call site where `ExtensionAwareDatabase` is constructed** + +Search for all places that construct `ExtensionAwareDatabase`: + +```bash +grep -r "ExtensionAwareDatabase" /home/kurouto/kurouto-jobs/7694626f-534d-48f5-b51f-4bb9c699d932/orcapod-python/src --include="*.py" -l +``` + +For each construction site, change `registry=data_context.logical_type_registry` to `converter=data_context.type_converter`. + +- [ ] **Step 4: Update `test_database_hooks.py`** + +The tests that use `register_discovered_extensions(registry, schema)` need to use `converter`: + +For each test: +1. Create a `UniversalTypeConverter` with the appropriate registry +2. Call `register_discovered_extensions(converter, schema)` instead of `register_discovered_extensions(registry, schema)` + +```python +# Example update pattern in test_database_hooks.py: + +# Before: +# register_discovered_extensions(registry, schema) + +# After: +from orcapod.semantic_types.universal_converter import UniversalTypeConverter +converter = UniversalTypeConverter(logical_type_registry=registry) +register_discovered_extensions(converter, schema) +``` + +- [ ] **Step 5: Run database hook tests** + +```bash +uv run pytest tests/test_extension_types/test_database_hooks.py -v 2>&1 | tail -30 +``` + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/extension_types/database_hooks.py src/orcapod/databases/extension_aware_database.py tests/test_extension_types/test_database_hooks.py +git commit -m "refactor(database_hooks): register_discovered_extensions and ExtensionAwareDatabase now take converter instead of registry" +``` + +--- + +## Task 12: Remove `semantic_registry` from `UniversalTypeConverter`; delete `dataclass_encoding.py`; make `type_utils` private + +**Files:** +- Modify: `src/orcapod/semantic_types/universal_converter.py` +- Delete: `src/orcapod/semantic_types/dataclass_encoding.py` +- Modify: `src/orcapod/extension_types/type_utils.py` +- Modify: `tests/test_semantic_types/test_universal_converter.py` + +- [ ] **Step 1: Remove `semantic_registry` param and usages from `UniversalTypeConverter`** + +In `__init__`, remove `semantic_registry` parameter and `self.semantic_registry = semantic_registry`. + +In `_convert_python_to_arrow`, remove: +```python +# Remove this block: +if self.semantic_registry: + converter = self.semantic_registry.get_converter_for_python_type(python_type) + if converter: + return converter.arrow_struct_type +``` + +In `_convert_arrow_to_python`, remove: +```python +# Remove these blocks: +if self.semantic_registry: + python_type = self.semantic_registry.get_python_type_for_semantic_struct_signature(arrow_type) + if python_type: + return python_type +``` + +In `_create_python_to_arrow_converter`, remove: +```python +# Remove: +if self.semantic_registry: + converter = self.semantic_registry.get_converter_for_python_type(python_type) + if converter: + return converter.python_to_struct_dict +``` + +In `_create_arrow_to_python_converter`, remove: +```python +# Remove: +if self.semantic_registry and pa.types.is_struct(arrow_type): + registered_python_type = ( + self.semantic_registry.get_python_type_for_semantic_struct_signature(arrow_type) + ) + if registered_python_type: + converter = self.semantic_registry.get_converter_for_python_type(registered_python_type) + if converter: + return converter.struct_dict_to_python +``` + +Remove the `from orcapod.semantic_types.semantic_registry import SemanticTypeRegistry` import. + +- [ ] **Step 2: Remove `dataclass_encoding` imports and old dataclass path from converter** + +Remove all imports from `dataclass_encoding`: +```python +# Remove: +from orcapod.semantic_types.dataclass_encoding import ( + DATACLASS_TYPE_FIELD, + _get_type_hints_safe, + dataclass_to_arrow_struct_type, + dataclass_to_struct_dict, + has_dataclass_type_sentinel, + struct_dict_to_dataclass, +) +``` + +In `_convert_python_to_arrow`, remove the dataclass path: +```python +# Remove: +if dataclasses.is_dataclass(python_type) and isinstance(python_type, type): + return dataclass_to_arrow_struct_type(python_type, self) +``` + +In `_convert_arrow_to_python`, remove the dataclass sentinel path: +```python +# Remove the has_dataclass_type_sentinel block (lines referencing has_dataclass_type_sentinel, +# DATACLASS_TYPE_FIELD, struct_dict_to_dataclass, etc.) +``` + +In `_create_python_to_arrow_converter`, remove: +```python +# Remove: +if dataclasses.is_dataclass(python_type) and isinstance(python_type, type): + hints = _get_type_hints_safe(python_type) + field_converters = { + f.name: self.get_python_to_arrow_converter(hints[f.name]) + for f in dataclasses.fields(python_type) + if f.init + } + return lambda obj: dataclass_to_struct_dict(obj, field_converters) +``` + +In `_create_arrow_to_python_converter`, remove: +```python +# Remove the has_dataclass_type_sentinel block +``` + +Remove `import dataclasses` if it's now unused in the converter (check if still needed for the `_create_python_to_arrow_converter` logic after removal). + +- [ ] **Step 3: Delete `dataclass_encoding.py`** + +```bash +rm /home/kurouto/kurouto-jobs/7694626f-534d-48f5-b51f-4bb9c699d932/orcapod-python/src/orcapod/semantic_types/dataclass_encoding.py +git rm src/orcapod/semantic_types/dataclass_encoding.py +``` + +- [ ] **Step 4: Update `type_utils.py` to make `extract_leaf_classes` private** + +```python +# In src/orcapod/extension_types/type_utils.py: +# Rename extract_leaf_classes → _extract_leaf_classes +# Keep the old name as a shim if needed for other callers, or just rename. +``` + +Search for callers: +```bash +grep -r "extract_leaf_classes" /home/kurouto/kurouto-jobs/7694626f-534d-48f5-b51f-4bb9c699d932/orcapod-python/src --include="*.py" +``` + +The only caller was `ensure_types_registered_for_schemas` which we've already replaced with `register_python_class`. Rename the function: + +```python +def _extract_leaf_classes(annotation: Any) -> Iterator[type]: + # (body unchanged) +``` + +Update the module docstring to reflect it's now private. + +- [ ] **Step 5: Update any tests that import `extract_leaf_classes`** + +```bash +grep -r "extract_leaf_classes" /home/kurouto/kurouto-jobs/7694626f-534d-48f5-b51f-4bb9c699d932/orcapod-python/tests --include="*.py" +``` + +Update those tests to use `_extract_leaf_classes` (or remove if the function is no longer tested as part of the public API). + +- [ ] **Step 6: Remove test for `dataclass_encoding.py`** + +Since `dataclass_encoding.py` is deleted, the test file `tests/test_semantic_types/test_dataclass_encoding.py` will fail on import. Remove or archive it: + +```bash +git rm tests/test_semantic_types/test_dataclass_encoding.py +``` + +- [ ] **Step 7: Update `test_universal_converter.py` to not use `semantic_registry`** + +Find all places in `test_universal_converter.py` that pass `semantic_registry=...` to `UniversalTypeConverter(...)` and remove those calls. The tests should pass `logical_type_registry=...` instead (or no argument, using the default context). + +Also update the module-level `python_type_to_arrow_type`, `arrow_type_to_python_type`, `get_conversion_functions` module functions — they call `data_context.type_converter` which no longer uses semantic_registry for type dispatch. Path/UUID types should now go through the logical_type_registry. + +- [ ] **Step 8: Run full test suite** + +```bash +uv run pytest tests/test_semantic_types/ tests/test_extension_types/ -v 2>&1 | tail -40 +``` + +Fix any remaining failures. + +- [ ] **Step 9: Commit** + +```bash +git add -A +git commit -m "refactor(universal_converter): remove semantic_registry usage and dataclass_encoding imports; delete dataclass_encoding.py; make extract_leaf_classes private" +``` + +--- + +## Task 13: Full test suite verification + `extension_types/__init__.py` update + +**Files:** +- Modify: `src/orcapod/extension_types/__init__.py` +- Verify: entire test suite + +- [ ] **Step 1: Add `DataclassHandlerFactory` and `DataclassLogicalType` to `extension_types/__init__.py`** + +```python +from .dataclass_handler import DataclassHandlerFactory, DataclassLogicalType, DATACLASS_CATEGORY + +__all__ = [ + "LogicalTypeProtocol", + "LogicalTypeFactoryProtocol", + "TypeConverterProtocol", + "LogicalTypeRegistry", + "make_arrow_extension_type", + "make_polars_extension_type", + "ExtensionTypeInfo", + "walk_schema", + "walk_field", + "register_discovered_extensions", + "apply_extension_types", + "DataclassLogicalType", + "DataclassHandlerFactory", + "DATACLASS_CATEGORY", +] +``` + +Update the module docstring to remove the `DataContext.logical_type_registry` access path. + +- [ ] **Step 2: Run the full test suite** + +```bash +uv run pytest tests/ -x 2>&1 | tail -50 +``` + +Fix all failures. Common issues: +- Tests constructing `DataContext` with `logical_type_registry=` → remove that arg +- Tests calling `data_context.logical_type_registry` → use `data_context.type_converter._logical_type_registry` or refactor to use converter methods +- Tests calling `get_default_logical_type_registry()` → use `get_default_context().type_converter._logical_type_registry` or use the converter's registration methods +- Tests calling `factory.create_for_python_type(t)` without `converter=` → add `converter=None` or a stub + +- [ ] **Step 3: Run full test suite and confirm it passes** + +```bash +uv run pytest tests/ 2>&1 | tail -20 +``` +Expected: All tests pass. + +- [ ] **Step 4: Final commit** + +```bash +git add src/orcapod/extension_types/__init__.py +git commit -m "feat(extension_types): export DataclassHandlerFactory, DataclassLogicalType, DATACLASS_CATEGORY" +``` + +--- + +## Self-Review + +### Spec Coverage Check + +| Spec section | Covered by task | +|---|---| +| `TypeConverterProtocol` added to `extension_types/protocols.py` | Task 1 | +| `LogicalTypeFactoryProtocol`: add `supports_class`, `converter` param | Task 1 | +| `LogicalTypeProtocol`: add `converter` param | Task 1 | +| Built-in types: add `converter` param (accept, ignore) | Task 2 | +| `register_python_class` on converter | Task 3 | +| `register_storage_type` on converter | Task 4 | +| `python_to_storage` / `storage_to_python` on converter | Task 5 | +| Registration pass-throughs | Task 5 | +| Update converter dispatch to pass `converter=self` | Task 5 | +| Simplify `ensure_types_registered_for_schemas` | Task 6 | +| Remove `ensure_*` from registry | Task 6 | +| `DataclassLogicalType` | Task 7 | +| `DataclassHandlerFactory` write path | Task 8 | +| `DataclassHandlerFactory` read path | Task 9 | +| `DataContext.logical_type_registry` removed | Task 10 | +| `get_default_logical_type_registry` removed | Task 10 | +| `v0.1.json` and `context_schema.json` updated | Task 10 | +| `register_discovered_extensions` takes converter | Task 11 | +| `ExtensionAwareDatabase` takes converter | Task 11 | +| Remove `semantic_registry` from converter | Task 12 | +| Delete `dataclass_encoding.py` | Task 12 | +| `extract_leaf_classes` made private | Task 12 | + +All spec requirements are covered. ✓ + +### Known Deviations from Spec + +1. **`register_discovered_extensions`**: The spec proposes simplifying to `for field in schema: converter.register_storage_type(field.type)`. The plan retains `walk_schema` to preserve support for the field-metadata channel (Parquet cold-start where `field.type` is a plain storage type, not a `pa.ExtensionType`). The spec's simplified version only handles in-memory extension types. + +2. **`apply_extension_types`**: Still takes a `LogicalTypeRegistry` argument. `ExtensionAwareDatabase` accesses it via `converter._logical_type_registry`. This is an internal implementation detail. diff --git a/superpowers/plans/2026-06-17-plt-1720-register-python-class-storage-type-cleanup.md b/superpowers/plans/2026-06-17-plt-1720-register-python-class-storage-type-cleanup.md new file mode 100644 index 00000000..91e76732 --- /dev/null +++ b/superpowers/plans/2026-06-17-plt-1720-register-python-class-storage-type-cleanup.md @@ -0,0 +1,831 @@ +# PLT-1720: register_python_class storage-type cleanup Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use sensei:subagent-driven-development (recommended) or sensei:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Make `register_python_class` and `register_storage_type` both return storage-safe Arrow types (extension type allowed at top level, no extension types nested inside struct/list fields), delete `_strip_ext_to_storage`, and fix `reconstruct_from_arrow` to register nested types so Parquet round-trips for nested dataclasses work in a fresh process. + +**Architecture:** Three coordinated changes: (1) `_register_python_class_impl` container branches strip any extension type returned by recursive calls before embedding in list/dict; (2) `register_storage_type` strips extension types from struct/list fields when rebuilding the type; (3) `DataclassLogicalTypeFactory.create_for_python_type` replaces the recursive `_strip_ext_to_storage` call with a one-liner, and `reconstruct_from_arrow` adds `converter.register_python_class(annotation)` per field to trigger nested registration. + +**Tech Stack:** Python 3.12+, PyArrow ≥ 20, `uv run pytest` + +--- + +## File map + +| File | Change | +|---|---| +| `src/orcapod/extension_types/protocols.py` | Docstring updates only | +| `src/orcapod/semantic_types/universal_converter.py` | `_register_python_class_impl` container branches; `register_storage_type` struct/list stripping | +| `src/orcapod/extension_types/dataclass_logical_type_factory.py` | Delete `_strip_ext_to_storage`; update `create_for_python_type`; update `reconstruct_from_arrow` | +| `DESIGN_ISSUES.md` | Mark ET1 in progress / update workaround note | +| `tests/test_semantic_types/test_universal_converter.py` | Fix `test_register_storage_type_nested_struct_with_extension` | +| `tests/test_extension_types/test_dataclass_logical_type_factory.py` | New tests: `test_reconstruct_from_arrow_registers_nested_types`, `test_nested_dataclass_parquet_roundtrip` | + +--- + +## Task 1: Update docstrings in protocols.py + +**Files:** +- Modify: `src/orcapod/extension_types/protocols.py:27-33` + +- [ ] **Step 1: Update `register_python_class` docstring** + +Replace lines 27–29: +```python + def register_python_class(self, annotation: Any) -> "pa.DataType": + """Traverse a Python annotation and return its Arrow type, registering as needed.""" + ... +``` +With: +```python + def register_python_class(self, annotation: Any) -> "pa.DataType": + """Traverse a Python annotation, register any logical types found, and return + the storage-safe Arrow type. + + The returned type may be a ``pa.ExtensionType`` at the top level for registered + classes (e.g. ``UUID`` → ``orcapod.uuid`` extension type), but struct fields and + list value types at any depth are always plain (non-extension) Arrow types. + + Args: + annotation: A Python type or generic alias (e.g. ``list[str]``, + ``Optional[uuid.UUID]``, a dataclass type). + + Returns: + A storage-safe ``pa.DataType``. May be ``pa.ExtensionType`` at the top level; + never contains nested extension types in struct/list fields. + """ + ... +``` + +- [ ] **Step 2: Update `register_storage_type` docstring** + +Replace lines 31–33: +```python + def register_storage_type(self, arrow_type: "pa.DataType") -> "pa.DataType": + """Traverse an Arrow type bottom-up, registering extension types, and return resolved type.""" + ... +``` +With: +```python + def register_storage_type(self, arrow_type: "pa.DataType") -> "pa.DataType": + """Traverse an Arrow type bottom-up, registering extension types, and return a + storage-safe type. + + The returned type may be a ``pa.ExtensionType`` at the top level, but struct fields + and list value types at any depth are always plain (non-extension) Arrow types. + This invariant makes the return value safe to use as a struct field or list element + type without further stripping. + + Args: + arrow_type: An Arrow type to traverse and register. + + Returns: + A storage-safe ``pa.DataType``. + """ + ... +``` + +- [ ] **Step 3: Run existing protocol tests to confirm no breakage** + +```bash +uv run pytest tests/test_extension_types/test_protocols.py -v +``` +Expected: all PASS + +- [ ] **Step 4: Commit** + +```bash +git add src/orcapod/extension_types/protocols.py +git commit -m "docs(extension-types): update register_python_class and register_storage_type docstrings for storage-safe contract" +``` + +--- + +## Task 2: Fix `register_storage_type` — strip extension types from struct/list fields + +**Files:** +- Modify: `src/orcapod/semantic_types/universal_converter.py:441-468` +- Test: `tests/test_semantic_types/test_universal_converter.py` + +The current `register_storage_type` builds new struct/list types with the recursed field types, but does **not** strip an extension type before embedding it in a struct field or list value. Under the storage-safe contract it must strip. + +- [ ] **Step 1: Write the failing test first** + +In `tests/test_semantic_types/test_universal_converter.py`, locate `test_register_storage_type_nested_struct_with_extension` (around line 931). The test currently asserts the extension type is **preserved** in the struct field. Under the new contract it must be **stripped**. Change the last two assertions: + +```python +def test_register_storage_type_nested_struct_with_extension(): + """Extension type nested inside a struct field is stripped to storage type (ET1).""" + import json + import uuid as _u + + ext_name = f"test.nested.{_u.uuid4().hex[:8]}" + category = "test.nested" + metadata = json.dumps({"category": category}).encode() + ArrowExt = make_arrow_extension_type(ext_name, pa.large_string(), metadata=metadata) + PolarsExt = make_polars_extension_type(ext_name, pa.large_string()) + + class _LT: + logical_type_name = ext_name + python_type = str + def get_arrow_extension_type(self): return ArrowExt() + def get_polars_extension_type(self): return PolarsExt() + def python_to_storage(self, v, c=None): return str(v) + def storage_to_python(self, v, c=None): return v + + class _Factory: + def supports_class(self, t): return False + def create_for_python_type(self, t, converter): pass + def reconstruct_from_arrow(self, name, storage_type, meta, converter): + return _LT() + + registry = _make_registry_with_builtins() + registry.register_logical_type_factory(_Factory(), category=category) + converter = _make_converter(registry) + + ext_instance = ArrowExt() + struct_with_ext = pa.struct([pa.field("id", pa.int64()), pa.field("tag", ext_instance)]) + result = converter.register_storage_type(struct_with_ext) + + assert pa.types.is_struct(result) + assert result.field("id").type == pa.int64() + # Storage-safe: extension type inside struct field is stripped to its storage type + assert result.field("tag").type == pa.large_string() + assert not isinstance(result.field("tag").type, pa.ExtensionType) + # Side effect: the extension type IS registered (check via registry) + assert converter._logical_type_registry.get_by_arrow_extension_name(ext_name) is not None +``` + +- [ ] **Step 2: Run the test to verify it fails** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py::test_register_storage_type_nested_struct_with_extension -v +``` +Expected: FAIL — test currently asserts `isinstance(result.field("tag").type, pa.ExtensionType)`. + +- [ ] **Step 3: Fix `register_storage_type` in `universal_converter.py`** + +Locate the struct branch (around line 443). Replace the struct and list branches: + +**Old struct branch (lines ~443–451):** +```python + # Struct type — recurse into each field, preserving field-level metadata + if pa.types.is_struct(arrow_type): + resolved_fields = [] + for i in range(arrow_type.num_fields): + field = arrow_type.field(i) + resolved_type = self.register_storage_type(field.type) + resolved_fields.append( + pa.field(field.name, resolved_type, nullable=field.nullable, metadata=field.metadata) + ) + return pa.struct(resolved_fields) +``` + +**New struct branch:** +```python + # Struct type — recurse into each field, preserving field-level metadata. + # Strip any extension type from field types before embedding (ET1: Arrow/Polars + # cannot construct arrays whose struct fields are pa.ExtensionType nodes). + if pa.types.is_struct(arrow_type): + resolved_fields = [] + for i in range(arrow_type.num_fields): + field = arrow_type.field(i) + resolved_type = self.register_storage_type(field.type) + if isinstance(resolved_type, pa.ExtensionType): + resolved_type = resolved_type.storage_type # strip: ET1 + resolved_fields.append( + pa.field(field.name, resolved_type, nullable=field.nullable, metadata=field.metadata) + ) + return pa.struct(resolved_fields) +``` + +**Old large_list branch (lines ~453–458):** +```python + # Large list type — preserve value field metadata (used by ARROW:extension:* channel) + if pa.types.is_large_list(arrow_type): + vf = arrow_type.value_field + resolved_value = self.register_storage_type(vf.type) + return pa.large_list( + pa.field(vf.name, resolved_value, nullable=vf.nullable, metadata=vf.metadata) + ) +``` + +**New large_list branch:** +```python + # Large list type — preserve value field metadata (used by ARROW:extension:* channel). + # Strip any extension type from the value type before embedding (ET1). + if pa.types.is_large_list(arrow_type): + vf = arrow_type.value_field + resolved_value = self.register_storage_type(vf.type) + if isinstance(resolved_value, pa.ExtensionType): + resolved_value = resolved_value.storage_type # strip: ET1 + return pa.large_list( + pa.field(vf.name, resolved_value, nullable=vf.nullable, metadata=vf.metadata) + ) +``` + +**Old list branch (lines ~461–466):** +```python + # List type + if pa.types.is_list(arrow_type): + vf = arrow_type.value_field + resolved_value = self.register_storage_type(vf.type) + return pa.list_( + pa.field(vf.name, resolved_value, nullable=vf.nullable, metadata=vf.metadata) + ) +``` + +**New list branch:** +```python + # List type — strip any extension type from the value type (ET1). + if pa.types.is_list(arrow_type): + vf = arrow_type.value_field + resolved_value = self.register_storage_type(vf.type) + if isinstance(resolved_value, pa.ExtensionType): + resolved_value = resolved_value.storage_type # strip: ET1 + return pa.list_( + pa.field(vf.name, resolved_value, nullable=vf.nullable, metadata=vf.metadata) + ) +``` + +- [ ] **Step 4: Run the test to verify it passes** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py::test_register_storage_type_nested_struct_with_extension -v +``` +Expected: PASS + +- [ ] **Step 5: Run the full `register_storage_type` suite** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py -k "register_storage_type" -v +``` +Expected: all PASS + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/semantic_types/universal_converter.py tests/test_semantic_types/test_universal_converter.py +git commit -m "fix(universal-converter): register_storage_type strips extension types from struct/list fields (ET1 storage-safe invariant)" +``` + +--- + +## Task 3: Fix `_register_python_class_impl` container branches — strip before embedding in list/dict + +**Files:** +- Modify: `src/orcapod/semantic_types/universal_converter.py:298-326` +- Test: `tests/test_semantic_types/test_universal_converter.py` + +Currently the list/set/dict branches call `self.register_python_class(...)` and embed the result directly in `pa.large_list(...)` or the dict struct. Since `register_python_class` may now return an extension type (e.g. `register_python_class(UUID)` → `orcapod.uuid`), the container branches must strip before embedding to maintain the storage-safe guarantee. + +Note: the registry-hit and factory-dispatch return sites (`return lt.get_arrow_extension_type()`) are **already correct** — they return the extension type directly (top-level extension is allowed). No change needed there. + +- [ ] **Step 1: Verify existing container tests pass before touching anything** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py -k "register_python_class_list or register_python_class_dict or register_python_class_set" -v +``` +Expected: PASS + +- [ ] **Step 2: Write a new failing test for `list[UUID]` error behaviour** + +Add at the end of the `register_python_class` block in `tests/test_semantic_types/test_universal_converter.py`: + +```python +def test_register_python_class_list_of_uuid_raises(): + """list[UUID] raises ValueError: UUID is a logical type and cannot be preserved + inside a list value field (ET2 in DESIGN_ISSUES.md). Tracked in PLT-1732.""" + converter = _make_converter() + with pytest.raises(ValueError, match="PLT-1732"): + converter.register_python_class(list[_uuid_module.UUID]) + + +def test_register_python_class_dict_str_uuid_raises(): + """dict[str, UUID] raises ValueError: UUID is a logical type and cannot be preserved + inside a struct field (ET1/ET2 in DESIGN_ISSUES.md). Tracked in PLT-1732.""" + converter = _make_converter() + with pytest.raises(ValueError, match="PLT-1732"): + converter.register_python_class(dict[str, _uuid_module.UUID]) +``` + +- [ ] **Step 3: Run the new tests to verify they fail** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py::test_register_python_class_list_of_uuid_raises tests/test_semantic_types/test_universal_converter.py::test_register_python_class_dict_str_uuid_raises -v +``` +Expected: FAIL — the list/dict branches currently embed the extension type without raising. + +- [ ] **Step 4: Fix the container branches in `_register_python_class_impl`** + +Locate the list, set, and dict branches (lines ~297–325). Apply stripping after each recursive `register_python_class` call before embedding in a container: + +**Old list branch (lines ~297–304):** +```python + # list[T] → pa.large_list(T) + if origin is list: + if not args: + raise ValueError( + "Unparameterized 'list' is not supported. Use 'list[T]' with a concrete " + "element type (e.g. list[int], list[str])." + ) + return pa.large_list(self.register_python_class(args[0])) +``` + +**New list branch:** +```python + # list[T] → pa.large_list(T). Strip extension type from element (ET1: extension + # types cannot be nested inside list value types). + if origin is list: + if not args: + raise ValueError( + "Unparameterized 'list' is not supported. Use 'list[T]' with a concrete " + "element type (e.g. list[int], list[str])." + ) + inner = self.register_python_class(args[0]) + if isinstance(inner, pa.ExtensionType): + inner = inner.storage_type # strip: ET1 + return pa.large_list(inner) +``` + +**Old set branch (lines ~306–313):** +```python + # set[T] → pa.large_list(T) + if origin is set: + if not args: + raise ValueError( + "Unparameterized 'set' is not supported. Use 'set[T]' with a concrete " + "element type (e.g. set[int], set[str])." + ) + return pa.large_list(self.register_python_class(args[0])) +``` + +**New set branch:** +```python + # set[T] → pa.large_list(T). Strip extension type from element (ET1). + if origin is set: + if not args: + raise ValueError( + "Unparameterized 'set' is not supported. Use 'set[T]' with a concrete " + "element type (e.g. set[int], set[str])." + ) + inner = self.register_python_class(args[0]) + if isinstance(inner, pa.ExtensionType): + inner = inner.storage_type # strip: ET1 + return pa.large_list(inner) +``` + +**Old dict branch (lines ~315–325):** +```python + # dict[K, V] → pa.large_list(struct{key: K, value: V}) + if origin is dict: + if len(args) < 2: + raise ValueError( + "Unparameterized 'dict' is not supported. Use 'dict[K, V]' with concrete " + "key and value types (e.g. dict[str, int])." + ) + key_arrow = self.register_python_class(args[0]) + val_arrow = self.register_python_class(args[1]) + return pa.large_list( + pa.struct([pa.field("key", key_arrow), pa.field("value", val_arrow)]) + ) +``` + +**New dict branch:** +```python + # dict[K, V] → pa.large_list(struct{key: K, value: V}). + # Strip extension types from key and value before embedding in the struct (ET1). + if origin is dict: + if len(args) < 2: + raise ValueError( + "Unparameterized 'dict' is not supported. Use 'dict[K, V]' with concrete " + "key and value types (e.g. dict[str, int])." + ) + key_arrow = self.register_python_class(args[0]) + if isinstance(key_arrow, pa.ExtensionType): + key_arrow = key_arrow.storage_type # strip: ET1 + val_arrow = self.register_python_class(args[1]) + if isinstance(val_arrow, pa.ExtensionType): + val_arrow = val_arrow.storage_type # strip: ET1 + return pa.large_list( + pa.struct([pa.field("key", key_arrow), pa.field("value", val_arrow)]) + ) +``` + +- [ ] **Step 5: Run the new tests to verify they pass** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py::test_register_python_class_list_of_uuid_raises tests/test_semantic_types/test_universal_converter.py::test_register_python_class_dict_str_uuid_raises -v +``` +Expected: PASS + +- [ ] **Step 6: Run the full `register_python_class` suite** + +```bash +uv run pytest tests/test_semantic_types/test_universal_converter.py -k "register_python_class" -v +``` +Expected: all PASS + +- [ ] **Step 7: Commit** + +```bash +git add src/orcapod/semantic_types/universal_converter.py tests/test_semantic_types/test_universal_converter.py +git commit -m "fix(universal-converter): strip extension types from list/dict container element types in register_python_class (ET1)" +``` + +--- + +## Task 4: Delete `_strip_ext_to_storage` and update `create_for_python_type` + +**Files:** +- Modify: `src/orcapod/extension_types/dataclass_logical_type_factory.py:45-315` +- Test: `tests/test_extension_types/test_dataclass_logical_type_factory.py` + +`_strip_ext_to_storage` (lines 45–90) is now redundant: `register_python_class` already returns a storage-safe type (no nested extension types). The `create_for_python_type` method should replace the recursive `_strip_ext_to_storage(arrow_type)` call with a one-liner strip. + +- [ ] **Step 1: Verify the dataclass factory write-path tests pass before touching anything** + +```bash +uv run pytest tests/test_extension_types/test_dataclass_logical_type_factory.py -v +``` +Expected: all PASS + +- [ ] **Step 2: Delete `_strip_ext_to_storage` and update `create_for_python_type`** + +In `src/orcapod/extension_types/dataclass_logical_type_factory.py`: + +**Delete** the entire `_strip_ext_to_storage` function (lines 45–90, inclusive of docstring). + +**Old block in `create_for_python_type` (lines ~310–315):** +```python + annotation = hints.get(field.name, Any) + arrow_type = converter.register_python_class(annotation) + # Strip extension types from struct field types: pa.array cannot build a + # struct array when a field type is a pa.ExtensionType (see ET1 in + # DESIGN_ISSUES.md). Value conversion is annotation-driven so stripping is safe. + stripped_type = _strip_ext_to_storage(arrow_type) + arrow_fields.append(pa.field(field.name, stripped_type)) +``` + +**New block:** +```python + annotation = hints.get(field.name, Any) + arrow_type = converter.register_python_class(annotation) + # register_python_class returns a storage-safe type: may be extension at the + # top level, but struct fields are always plain. Strip the top-level extension + # type here before inserting into the struct (ET1; see DESIGN_ISSUES.md). + if isinstance(arrow_type, pa.ExtensionType): + arrow_type = arrow_type.storage_type + arrow_fields.append(pa.field(field.name, arrow_type)) +``` + +Also update the comment in `DataclassLogicalType.__init__` (lines ~138–141) that references `_strip_ext_to_storage`: + +**Old:** +```python + # ``storage_type`` is already stripped of nested extension types by + # ``DataclassLogicalTypeFactory.create_for_python_type`` (see ET1 in + # DESIGN_ISSUES.md). ``make_polars_extension_type`` and + # ``pa.array`` both require plain storage types inside structs. +``` + +**New:** +```python + # ``storage_type`` must not contain nested extension types (ET1 in DESIGN_ISSUES.md). + # ``DataclassLogicalTypeFactory.create_for_python_type`` and ``reconstruct_from_arrow`` + # both guarantee this by stripping any top-level extension type from each field's + # Arrow type before inserting it into the struct. +``` + +- [ ] **Step 3: Run the dataclass factory write-path tests to verify they still pass** + +```bash +uv run pytest tests/test_extension_types/test_dataclass_logical_type_factory.py -v +``` +Expected: all PASS + +- [ ] **Step 4: Commit** + +```bash +git add src/orcapod/extension_types/dataclass_logical_type_factory.py +git commit -m "refactor(dataclass-factory): delete _strip_ext_to_storage, replace with one-liner in create_for_python_type" +``` + +--- + +## Task 5: Fix `reconstruct_from_arrow` — register nested types (read-path completeness fix) + +**Files:** +- Modify: `src/orcapod/extension_types/dataclass_logical_type_factory.py:367-372` +- Test: `tests/test_extension_types/test_dataclass_logical_type_factory.py` + +`reconstruct_from_arrow` currently builds `field_annotations` but never calls `converter.register_python_class` for each annotation. This means nested dataclass types (e.g. `Inner` inside `Outer`) are never registered on the read path, causing `ValueError("Unsupported Python type: Inner.")` in a fresh process. + +- [ ] **Step 1: Write the failing test** + +Add `test_reconstruct_from_arrow_registers_nested_types` to `tests/test_extension_types/test_dataclass_logical_type_factory.py`. This test requires module-level dataclasses. Add them after the existing module-level dataclass definitions (around line 177), before the "DataclassLogicalTypeFactory write-path tests" section: + +```python +@dataclasses.dataclass +class _InnerForRegistrationTest: + """Module-level inner dataclass for registration completeness test.""" + value: int + + +@dataclasses.dataclass +class _OuterForRegistrationTest: + """Module-level outer dataclass for registration completeness test.""" + inner: _InnerForRegistrationTest + label: str +``` + +Then add the test: + +```python +def test_reconstruct_from_arrow_registers_nested_types(): + """reconstruct_from_arrow for Outer must register Inner as a side effect.""" + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory + + # Build the storage type for _OuterForRegistrationTest manually (as it would come + # from Parquet): outer struct with an inner struct field (Inner is stored as a struct, + # NOT as an extension type inside the struct field — that's the ET1 constraint). + inner_storage = pa.struct([pa.field("value", pa.int64())]) + outer_storage = pa.struct([ + pa.field("inner", inner_storage), + pa.field("label", pa.large_string()), + ]) + outer_fqcn = f"{_OuterForRegistrationTest.__module__}.{_OuterForRegistrationTest.__qualname__}" + inner_fqcn = f"{_InnerForRegistrationTest.__module__}.{_InnerForRegistrationTest.__qualname__}" + + factory = DataclassLogicalTypeFactory() + converter = _make_full_converter() + + # Inner is NOT pre-registered + assert converter._logical_type_registry.get_by_python_type(_InnerForRegistrationTest) is None + + # reconstruct_from_arrow for Outer should trigger registration of Inner as a side effect + lt = factory.reconstruct_from_arrow(outer_fqcn, outer_storage, {"category": "orcapod.dataclass"}, converter) + + # Inner must now be registered + assert converter._logical_type_registry.get_by_python_type(_InnerForRegistrationTest) is not None +``` + +- [ ] **Step 2: Run the test to verify it fails** + +```bash +uv run pytest "tests/test_extension_types/test_dataclass_logical_type_factory.py::test_reconstruct_from_arrow_registers_nested_types" -v +``` +Expected: FAIL — `_InnerForRegistrationTest` is not registered after `reconstruct_from_arrow`. + +- [ ] **Step 3: Fix `reconstruct_from_arrow` in `dataclass_logical_type_factory.py`** + +Locate the field-iteration loop inside `reconstruct_from_arrow` (around line 367): + +**Old:** +```python + field_annotations = [] + for field in dataclasses.fields(cls): + if not field.init: + continue + annotation = hints.get(field.name, Any) + field_annotations.append((field.name, annotation)) +``` + +**New:** +```python + field_annotations = [] + for field in dataclasses.fields(cls): + if not field.init: + continue + annotation = hints.get(field.name, Any) + # Register any logical type the field annotation maps to (registration + # completeness invariant: all nested logical types must be registered when + # the outer type is registered). The return value is discarded; only the + # side effect of registration matters here. + converter.register_python_class(annotation) + field_annotations.append((field.name, annotation)) +``` + +- [ ] **Step 4: Run the test to verify it passes** + +```bash +uv run pytest "tests/test_extension_types/test_dataclass_logical_type_factory.py::test_reconstruct_from_arrow_registers_nested_types" -v +``` +Expected: PASS + +- [ ] **Step 5: Run the full dataclass factory test suite** + +```bash +uv run pytest tests/test_extension_types/test_dataclass_logical_type_factory.py -v +``` +Expected: all PASS + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/extension_types/dataclass_logical_type_factory.py tests/test_extension_types/test_dataclass_logical_type_factory.py +git commit -m "fix(dataclass-factory): reconstruct_from_arrow registers nested types (registration completeness invariant)" +``` + +--- + +## Task 6: Add Parquet round-trip test for nested dataclasses + +**Files:** +- Test: `tests/test_extension_types/test_dataclass_logical_type_factory.py` + +This test exercises the full fresh-process read path: write a nested dataclass to Parquet, read it back in a converter that has never seen the inner or outer type, call `register_discovered_extensions` + `apply_extension_types`, then convert back to Python. This is the end-to-end regression test for the bug fixed in Task 5. + +The two module-level dataclasses needed (`_InnerForRegistrationTest`, `_OuterForRegistrationTest`) were already added in Task 5. + +- [ ] **Step 1: Write the test** + +Add `test_nested_dataclass_parquet_roundtrip` to `tests/test_extension_types/test_dataclass_logical_type_factory.py`: + +```python +def test_nested_dataclass_parquet_roundtrip(tmp_path): + """Fresh-process Parquet round-trip for a two-level nested dataclass. + + Verifies that register_discovered_extensions triggers the chain: + register_arrow_extension("Outer") → reconstruct_from_arrow + → register_python_class(Inner) → registers Inner + so that storage_to_python can reconstruct the full nested object. + """ + import pyarrow.parquet as pq + from orcapod.extension_types.database_hooks import register_discovered_extensions, apply_extension_types + + # ── Write path ─────────────────────────────────────────────────────────── + write_converter = _make_full_converter() + + inner = _InnerForRegistrationTest(value=42) + outer = _OuterForRegistrationTest(inner=inner, label="hello") + + # Register Outer (which also registers Inner via create_for_python_type) + write_converter.register_python_class(_OuterForRegistrationTest) + + # Serialize to Arrow using python_schema_to_arrow_schema + python_dicts_to_arrow_table + outer_fqcn = f"{_OuterForRegistrationTest.__module__}.{_OuterForRegistrationTest.__qualname__}" + arrow_schema = write_converter.python_schema_to_arrow_schema({"item": _OuterForRegistrationTest}) + rows = [{"item": write_converter.python_to_storage(outer, _OuterForRegistrationTest)}] + table = write_converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + + parquet_path = tmp_path / "nested.parquet" + pq.write_table(table, parquet_path) + + # ── Read path (fresh converter — neither Inner nor Outer pre-registered) ── + read_converter = _make_full_converter() + read_table = pq.read_table(parquet_path) + + # register_discovered_extensions should trigger: Outer → reconstruct_from_arrow + # → register_python_class(Inner) → registers Inner + register_discovered_extensions(read_converter, read_table.schema) + read_table = apply_extension_types(read_table, read_converter._logical_type_registry) + + # Both types must now be registered + assert read_converter._logical_type_registry.get_by_python_type(_OuterForRegistrationTest) is not None + assert read_converter._logical_type_registry.get_by_python_type(_InnerForRegistrationTest) is not None + + # Convert back to Python + rows_out = read_converter.arrow_table_to_python_dicts(read_table) + assert len(rows_out) == 1 + reconstructed = rows_out[0]["item"] + assert isinstance(reconstructed, _OuterForRegistrationTest) + assert isinstance(reconstructed.inner, _InnerForRegistrationTest) + assert reconstructed.inner.value == 42 + assert reconstructed.label == "hello" +``` + +- [ ] **Step 2: Run the test to verify it fails before the fix is in place** + +(This test should already pass since Task 5 fixed `reconstruct_from_arrow`. If running Tasks in order, it will pass. Run it now to confirm.) + +```bash +uv run pytest "tests/test_extension_types/test_dataclass_logical_type_factory.py::test_nested_dataclass_parquet_roundtrip" -v +``` +Expected: PASS (Task 5 already made this possible). + +- [ ] **Step 3: Run the full dataclass factory test suite to confirm no regressions** + +```bash +uv run pytest tests/test_extension_types/test_dataclass_logical_type_factory.py -v +``` +Expected: all PASS + +- [ ] **Step 4: Commit** + +```bash +git add tests/test_extension_types/test_dataclass_logical_type_factory.py +git commit -m "test(dataclass-factory): add Parquet round-trip test for nested dataclasses" +``` + +--- + +## Task 7: Update `DESIGN_ISSUES.md` — mark ET1 workaround updated + +**Files:** +- Modify: `DESIGN_ISSUES.md` (ET1 entry, around line 1003) + +- [ ] **Step 1: Update ET1** + +Find the ET1 entry. The **Workaround** section currently references `dataclass_handler._strip_ext_to_storage()`. Update it to reflect that `_strip_ext_to_storage` is gone, replaced by the storage-safe contract on `register_python_class` and `register_storage_type`. + +Replace the **Workaround** paragraph in ET1: + +**Old:** +``` +**Workaround:** `dataclass_handler._strip_ext_to_storage()` recursively replaces all +`pa.ExtensionType` nodes with their plain storage types. This stripping is applied in +`DataclassHandlerFactory.create_for_python_type` when building the struct's field types — +so the stored Arrow schema (and thus the struct passed to `make_polars_extension_type` and +`pa.Table.from_pylist`) never contains nested extension types. The consequence is that the +schema for a dataclass extension column reports downgraded inner field types (e.g. +`large_binary` instead of `orcapod.uuid`). This is invisible through the normal conversion +path (all value conversion flows through `converter.storage_to_python`, which is +annotation-driven), but would mislead any code that directly introspects the raw Arrow +or Polars schema of a dataclass extension column's storage fields. + +**Also affects `pa.Table.from_pylist`:** the same restriction applies to PyArrow's +`pa.Table.from_pylist` (and `pa.array`) — neither can build an array from a struct type +whose fields are `pa.ExtensionType` nodes, for the same underlying reason. The stripping +in `create_for_python_type` fixes both issues simultaneously. +``` + +**New:** +``` +**Workaround:** `register_python_class` and `register_storage_type` both uphold a +*storage-safe* invariant: the returned type may be a `pa.ExtensionType` at the top level, +but struct fields and list value types at any depth are always plain (non-extension) types. +`DataclassLogicalTypeFactory.create_for_python_type` strips the top-level extension type +with a one-liner (`if isinstance(arrow_type, pa.ExtensionType): arrow_type = arrow_type.storage_type`) +before inserting it into the struct, so the struct passed to `make_polars_extension_type` +and `pa.Table.from_pylist` never contains nested extension types. The private +`_strip_ext_to_storage` recursive helper was removed in PLT-1720; the stripping is now +trivially correct because the storage-safe invariant guarantees `.storage_type` is always +already clean. + +**Also affects `pa.Table.from_pylist`:** the same restriction applies to PyArrow's +`pa.Table.from_pylist` (and `pa.array`) — neither can build an array from a struct type +whose fields are `pa.ExtensionType` nodes, for the same underlying reason. The stripping +in `create_for_python_type` fixes both issues simultaneously. +``` + +- [ ] **Step 2: Run the full test suite** + +```bash +uv run pytest tests/ -x -q +``` +Expected: all PASS + +- [ ] **Step 3: Commit** + +```bash +git add DESIGN_ISSUES.md +git commit -m "docs(design-issues): update ET1 workaround note to reflect removal of _strip_ext_to_storage (PLT-1720)" +``` + +--- + +## Task 8: Final verification and push + +- [ ] **Step 1: Run the complete test suite** + +```bash +uv run pytest tests/ -q +``` +Expected: all PASS, no failures, no errors + +- [ ] **Step 2: Verify the branch is on the right base** + +```bash +git log --oneline extension-type-system..HEAD +``` +Expected: 7 commits (Tasks 1–7), all on top of `extension-type-system`. + +- [ ] **Step 3: Push the branch** + +```bash +git push -u origin eywalker/plt-1720-cleanup-register_python_class-should-return-plain-storage +``` + +--- + +## Self-review checklist + +**Spec coverage:** +- ✅ `register_python_class` container branches strip extension types (Task 3) +- ✅ `register_storage_type` strips extension types from struct/list fields (Task 2) +- ✅ `_strip_ext_to_storage` deleted (Task 4) +- ✅ `create_for_python_type` uses one-liner strip (Task 4) +- ✅ `reconstruct_from_arrow` calls `register_python_class` per field (Task 5) +- ✅ Protocol docstrings updated (Task 1) +- ✅ `DESIGN_ISSUES.md` ET1 updated (Task 7) +- ✅ `test_register_storage_type_nested_struct_with_extension` updated (Task 2) +- ✅ `test_register_python_class_list_of_uuid_raises` added (Task 3) +- ✅ `test_reconstruct_from_arrow_registers_nested_types` added (Task 5) +- ✅ `test_nested_dataclass_parquet_roundtrip` added (Task 6) +- ✅ `database_hooks.py` unchanged (no task needed — already uses `register_storage_type` return value) +- ✅ Existing `register_python_class` tests (`_registry_hit_path`, `_uuid_registry_hit`, `_factory_dispatch`) — these already assert `isinstance(result, pa.ExtensionType)`, which is still correct under the storage-safe contract. No updates needed. + +**Type consistency:** All references to `register_python_class`, `register_storage_type`, `_strip_ext_to_storage`, `create_for_python_type`, and `reconstruct_from_arrow` use the same names as in the source files. + +**No placeholders:** Every step has explicit code or commands. diff --git a/superpowers/plans/2026-06-17-plt-1731-pydantic-logical-type-factory.md b/superpowers/plans/2026-06-17-plt-1731-pydantic-logical-type-factory.md new file mode 100644 index 00000000..27ff7e60 --- /dev/null +++ b/superpowers/plans/2026-06-17-plt-1731-pydantic-logical-type-factory.md @@ -0,0 +1,1380 @@ +# PLT-1731 Pydantic Logical Type Factory Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use sensei:subagent-driven-development (recommended) or sensei:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Implement `PydanticLogicalType` and `PydanticLogicalTypeFactory` for pydantic v2 `BaseModel` subclasses, following the same thin-leaf factory pattern as `DataclassLogicalTypeFactory`. + +**Architecture:** New `pydantic_logical_type_factory.py` mirrors `dataclass_logical_type_factory.py` with no cross-dependency between them. The FQCN walk loop shared by both factories is extracted into `type_utils._walk_fqcn`. Write path delegates all field-type resolution to `converter.register_python_class`; read path delegates to `converter.register_python_class` for the registration completeness invariant. Pydantic is an optional dependency — the factory is importable and `supports_class` returns `False` when pydantic is not installed. + +**Tech Stack:** Python 3.12, PyArrow, Polars, pydantic v2 (`model_fields`, `BaseModel`), `typing.get_type_hints` + +--- + +## File Map + +| File | Action | What changes | +|---|---|---| +| `pyproject.toml` | Modify | Add `pydantic = ["pydantic>=2.0"]` optional extra; add to `all` | +| `src/orcapod/extension_types/type_utils.py` | Modify | Add `_walk_fqcn` shared FQCN walk helper; update module docstring | +| `src/orcapod/extension_types/dataclass_logical_type_factory.py` | Modify | `_import_from_fqcn` delegates to `type_utils._walk_fqcn` | +| `src/orcapod/extension_types/pydantic_logical_type_factory.py` | **Create** | `PYDANTIC_CATEGORY`, `PydanticLogicalType`, `PydanticLogicalTypeFactory`, `_import_pydantic_model_from_fqcn` | +| `src/orcapod/extension_types/__init__.py` | Modify | Export `PYDANTIC_CATEGORY`, `PydanticLogicalType`, `PydanticLogicalTypeFactory` | +| `tests/test_extension_types/test_pydantic_logical_type_factory.py` | **Create** | Full test suite | +| `tests/test_extension_types/test_type_utils.py` | Modify | Add tests for `_walk_fqcn` | + +--- + +## Task 1: Add `pydantic` optional dependency + +**Files:** +- Modify: `pyproject.toml` + +- [ ] **Step 1: Add pydantic to optional extras** + +In `pyproject.toml`, find the `[project.optional-dependencies]` section. Add the `pydantic` entry and update `all` to include it: + +```toml +[project.optional-dependencies] +redis = ["redis>=6.2.0"] +ray = ["ray[default]==2.48.0", "ipywidgets>=8.1.7"] +postgresql = ["psycopg[binary]>=3.0"] +spiraldb = [ + "pyspiral>=0.11.0", +] +pydantic = ["pydantic>=2.0"] +all = ["orcapod[redis]", "orcapod[ray]", "orcapod[postgresql]", "orcapod[spiraldb]", "orcapod[pydantic]"] +``` + +- [ ] **Step 2: Install pydantic** + +```bash +uv sync --extra pydantic +``` + +Expected: pydantic installs without errors. + +- [ ] **Step 3: Verify pydantic is available** + +```bash +uv run python -c "import pydantic; print(pydantic.__version__)" +``` + +Expected: prints a version string starting with `2.`. + +- [ ] **Step 4: Commit** + +```bash +git add pyproject.toml +git commit -m "chore(deps): add pydantic>=2.0 as optional dependency" +``` + +--- + +## Task 2: Factor `_walk_fqcn` into `type_utils.py` + +**Files:** +- Modify: `src/orcapod/extension_types/type_utils.py` +- Modify: `src/orcapod/extension_types/dataclass_logical_type_factory.py` +- Modify: `tests/test_extension_types/test_type_utils.py` + +- [ ] **Step 1: Write failing tests for `_walk_fqcn`** + +Add to `tests/test_extension_types/test_type_utils.py`: + +```python +import dataclasses +import pytest + + +# ── _walk_fqcn tests ───────────────────────────────────────────────────────── + +def test_walk_fqcn_resolves_module_level_class(): + """_walk_fqcn resolves a top-level class from its FQCN.""" + from orcapod.extension_types.type_utils import _walk_fqcn + import pathlib + obj = _walk_fqcn("pathlib.Path") + assert obj is pathlib.Path + + +def test_walk_fqcn_resolves_nested_attribute(): + """_walk_fqcn walks nested attribute chains (e.g. module.Outer.Inner).""" + from orcapod.extension_types.type_utils import _walk_fqcn + import os.path + # os.path.join is a function reachable via attribute walk + obj = _walk_fqcn("os.path.join") + assert obj is os.path.join + + +def test_walk_fqcn_raises_import_error_on_bad_module(): + """_walk_fqcn raises ImportError when no module prefix can be imported.""" + from orcapod.extension_types.type_utils import _walk_fqcn + with pytest.raises(ImportError): + _walk_fqcn("nonexistent.module.NoSuchClass") + + +def test_walk_fqcn_raises_import_error_on_missing_attr(): + """_walk_fqcn raises ImportError when module exists but attribute does not.""" + from orcapod.extension_types.type_utils import _walk_fqcn + with pytest.raises(ImportError): + _walk_fqcn("pathlib.NoSuchClass") + + +def test_walk_fqcn_raises_import_error_on_single_part(): + """_walk_fqcn raises ImportError when FQCN has no module separator.""" + from orcapod.extension_types.type_utils import _walk_fqcn + with pytest.raises(ImportError): + _walk_fqcn("justname") +``` + +- [ ] **Step 2: Run tests to verify they fail** + +```bash +uv run pytest tests/test_extension_types/test_type_utils.py -k "walk_fqcn" -v +``` + +Expected: all 5 tests FAIL with `ImportError: cannot import name '_walk_fqcn'`. + +- [ ] **Step 3: Add `_walk_fqcn` to `type_utils.py`** + +Replace the full content of `src/orcapod/extension_types/type_utils.py` with: + +```python +"""Utility helpers for Python type annotation inspection and FQCN import. + +Used by the write-side registration trigger to extract leaf Python classes from +complex generic annotations like ``list[dict[A, list[B]]]``, and by logical type +factories to import classes from fully-qualified class names. +""" + +from __future__ import annotations + +import importlib +import typing +from typing import Any, Iterator + + +def _extract_leaf_classes(annotation: Any) -> Iterator[type]: + """Recursively yield all concrete leaf Python classes from a type annotation. + + Unwraps generic aliases (``list[T]``, ``dict[K, V]``, ``Optional[T]``, + ``Union[A, B]``, ``A | B``, etc.) using ``typing.get_origin`` and + ``typing.get_args`` and yields every non-generic leaf found. ``NoneType`` + that appears as a generic argument (from ``Optional`` and + ``Union[..., None]`` / ``T | None``) is skipped — callers see only the + concrete types. When ``type(None)`` is passed directly as the annotation, + it is yielded as-is. + + Non-type, non-generic values (e.g. unresolved string annotations) are + silently skipped. + + Args: + annotation: A Python type or generic alias to inspect. + + Yields: + Concrete Python ``type`` objects found at leaf positions. + + Examples: + >>> list(_extract_leaf_classes(list[int])) + [] + >>> set(_extract_leaf_classes(dict[str, list[MyClass]])) + {, } + """ + origin = typing.get_origin(annotation) + + if origin is None: + # Not a generic alias. Yield only if it is a plain type. + if isinstance(annotation, type): + yield annotation + return + + # Generic alias — recurse into every type argument, skipping NoneType. + for arg in typing.get_args(annotation): + if arg is type(None): + continue + yield from _extract_leaf_classes(arg) + + +def _walk_fqcn(fqcn: str) -> Any: + """Walk a fully-qualified class name and return the resolved object. + + Tries module prefixes from longest to shortest, then walks the remaining + parts as attribute accesses. For example: + + - ``"mypackage.sub.MyClass"`` → import ``mypackage.sub``, then + ``getattr(module, "MyClass")``. + - ``"mypackage.sub.Outer.Inner"`` → import ``mypackage.sub``, then + ``getattr(module, "Outer")``, then ``getattr(Outer, "Inner")``. + + Does **not** validate the type of the resolved object — callers are + responsible for checking that the result is the expected kind of object + (e.g. a dataclass, a ``BaseModel`` subclass). + + Args: + fqcn: Fully-qualified name, e.g. ``"mypackage.sub.MyClass"``. + + Returns: + The resolved Python object. + + Raises: + ImportError: If no valid module+attribute split can be found. + """ + parts = fqcn.split(".") + if len(parts) < 2: + raise ImportError(f"Cannot import from FQCN {fqcn!r}: no module separator found.") + + for i in range(len(parts) - 1, 0, -1): + module_path = ".".join(parts[:i]) + attr_parts = parts[i:] + try: + module = importlib.import_module(module_path) + except (ImportError, ModuleNotFoundError): + continue + obj: Any = module + try: + for attr in attr_parts: + obj = getattr(obj, attr) + except AttributeError: + continue + return obj + + raise ImportError( + f"Cannot import from FQCN {fqcn!r}: no valid module+attribute path found." + ) +``` + +- [ ] **Step 4: Run `_walk_fqcn` tests to verify they pass** + +```bash +uv run pytest tests/test_extension_types/test_type_utils.py -k "walk_fqcn" -v +``` + +Expected: all 5 tests PASS. + +- [ ] **Step 5: Update `_import_from_fqcn` in `dataclass_logical_type_factory.py` to delegate to `_walk_fqcn`** + +Replace the `_import_from_fqcn` function at the bottom of +`src/orcapod/extension_types/dataclass_logical_type_factory.py` with: + +```python +def _import_from_fqcn(fqcn: str) -> type: + """Import a dataclass from its fully-qualified class name. + + Delegates the module-prefix walk to ``type_utils._walk_fqcn``, then + validates the resolved object is a dataclass type. + + Args: + fqcn: Fully-qualified class name, e.g. ``"mypackage.sub.MyClass"``. + + Returns: + The imported dataclass type. + + Raises: + ImportError: If no valid module+attribute split can be found, or if the + resolved object is not a dataclass type. + """ + from orcapod.extension_types.type_utils import _walk_fqcn + + obj: Any = _walk_fqcn(fqcn) + if not dataclasses.is_dataclass(obj) or not isinstance(obj, type): + raise ImportError( + f"{fqcn!r} does not resolve to a dataclass type." + ) + return obj +``` + +Also remove the `import importlib` line at the top of the file since it is no longer used directly. + +- [ ] **Step 6: Run existing dataclass factory tests to verify no regression** + +```bash +uv run pytest tests/test_extension_types/test_dataclass_logical_type_factory.py -v +``` + +Expected: all tests PASS. + +- [ ] **Step 7: Commit** + +```bash +git add src/orcapod/extension_types/type_utils.py \ + src/orcapod/extension_types/dataclass_logical_type_factory.py \ + tests/test_extension_types/test_type_utils.py +git commit -m "refactor(type-utils): extract _walk_fqcn shared FQCN helper; delegate from _import_from_fqcn" +``` + +--- + +## Task 3: `PydanticLogicalType` + +**Files:** +- Create: `src/orcapod/extension_types/pydantic_logical_type_factory.py` +- Create: `tests/test_extension_types/test_pydantic_logical_type_factory.py` + +- [ ] **Step 1: Write failing tests for `PydanticLogicalType`** + +Create `tests/test_extension_types/test_pydantic_logical_type_factory.py`: + +```python +"""Tests for PydanticLogicalType and PydanticLogicalTypeFactory.""" + +from __future__ import annotations + +import uuid as _uuid_module +from typing import Any + +import pyarrow as pa +import pytest +from pydantic import BaseModel, PrivateAttr + + +# ── Helpers ────────────────────────────────────────────────────────────────── + +class _StubConverter: + """Minimal converter stub for PydanticLogicalType tests.""" + + def python_to_storage(self, value, annotation): + if annotation is str: + return str(value) + if annotation is int: + return int(value) + return value + + def storage_to_python(self, storage_value, annotation): + if annotation is str: + return str(storage_value) + if annotation is int: + return int(storage_value) + return storage_value + + def register_python_class(self, annotation): + if annotation is str: + return pa.large_string() + if annotation is int: + return pa.int64() + raise ValueError(f"No mapping for {annotation}") + + +# ── PydanticLogicalType tests ──────────────────────────────────────────────── + +def test_pydantic_logical_type_is_importable(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + assert PydanticLogicalType is not None + + +def test_pydantic_logical_type_protocol_conformance(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + from orcapod.extension_types.protocols import LogicalTypeProtocol + + class _MyModel(BaseModel): + name: str + count: int + + storage = pa.struct([pa.field("name", pa.large_string()), pa.field("count", pa.int64())]) + lt = PydanticLogicalType( + logical_name="tests._MyModel", + python_type=_MyModel, + storage_type=storage, + field_annotations=[("name", str), ("count", int)], + ) + assert isinstance(lt, LogicalTypeProtocol) + + +def test_pydantic_logical_type_python_to_storage(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + + class _Point(BaseModel): + x: int + y: int + + storage = pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]) + lt = PydanticLogicalType("tests._Point", _Point, storage, [("x", int), ("y", int)]) + converter = _StubConverter() + + result = lt.python_to_storage(_Point(x=3, y=7), converter) + assert result == {"x": 3, "y": 7} + + +def test_pydantic_logical_type_storage_to_python(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + + class _Point(BaseModel): + x: int + y: int + + storage = pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]) + lt = PydanticLogicalType("tests._Point2", _Point, storage, [("x", int), ("y", int)]) + converter = _StubConverter() + + result = lt.storage_to_python({"x": 3, "y": 7}, converter) + assert isinstance(result, _Point) + assert result.x == 3 + assert result.y == 7 + + +def test_pydantic_logical_type_logical_type_name(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + + class _Foo(BaseModel): + val: str + + storage = pa.struct([pa.field("val", pa.large_string())]) + lt = PydanticLogicalType("mymod.Foo", _Foo, storage, [("val", str)]) + assert lt.logical_type_name == "mymod.Foo" + + +def test_pydantic_logical_type_python_type(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + + class _Bar(BaseModel): + val: str + + storage = pa.struct([pa.field("val", pa.large_string())]) + lt = PydanticLogicalType("mymod.Bar", _Bar, storage, [("val", str)]) + assert lt.python_type is _Bar + + +def test_python_to_storage_raises_when_converter_none(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + + class _DC(BaseModel): + x: int + + storage = pa.struct([pa.field("x", pa.int64())]) + lt = PydanticLogicalType("mymod._DC", _DC, storage, [("x", int)]) + with pytest.raises(ValueError, match="converter"): + lt.python_to_storage(_DC(x=1), None) + + +def test_storage_to_python_raises_when_converter_none(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + + class _DC2(BaseModel): + x: int + + storage = pa.struct([pa.field("x", pa.int64())]) + lt = PydanticLogicalType("mymod._DC2", _DC2, storage, [("x", int)]) + with pytest.raises(ValueError, match="converter"): + lt.storage_to_python({"x": 1}, None) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +```bash +uv run pytest tests/test_extension_types/test_pydantic_logical_type_factory.py -v +``` + +Expected: all tests FAIL with `ModuleNotFoundError: No module named 'orcapod.extension_types.pydantic_logical_type_factory'`. + +- [ ] **Step 3: Create `pydantic_logical_type_factory.py` with `PydanticLogicalType`** + +Create `src/orcapod/extension_types/pydantic_logical_type_factory.py`: + +```python +"""PydanticLogicalType and PydanticLogicalTypeFactory. + +Provides the ``PydanticLogicalType`` logical type implementation and the +``PydanticLogicalTypeFactory`` that synthesises and reconstructs +``PydanticLogicalType`` instances for pydantic v2 ``BaseModel`` subclasses. + +Write path (``create_for_python_type``): + Iterates model fields via ``model_fields`` (pydantic v2 API), delegates + field Arrow-type resolution to the converter via ``register_python_class``, + and returns a ``PydanticLogicalType`` backed by a ``pa.struct`` extension + type. + +Read path (``reconstruct_from_arrow``): + Imports the model by fully-qualified class name, resolves field annotations + against the (already bottom-up resolved) storage type, and returns a + ``PydanticLogicalType``. + +Category tag: ``"orcapod.pydantic"`` +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING, Any + +from orcapod.extension_types.registry import make_arrow_extension_type, make_polars_extension_type +from orcapod.extension_types.type_utils import _walk_fqcn +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import polars as pl + import pyarrow as pa + from orcapod.extension_types.protocols import TypeConverterProtocol +else: + pa = LazyModule("pyarrow") + pl = LazyModule("polars") + +logger = logging.getLogger(__name__) + +#: Category tag embedded in Arrow extension metadata. Used as the factory dispatch key. +PYDANTIC_CATEGORY = "orcapod.pydantic" + + +class PydanticLogicalType: + """Logical type binding a pydantic ``BaseModel`` subclass to its Arrow extension type. + + Stores the model's fully-qualified class name as the Arrow extension name + and a ``pa.struct`` of the model fields as the storage type. + + No Arrow-type reasoning lives here — all field-type resolution is owned by + the converter and completed before this object is constructed. + + Args: + logical_name: Fully-qualified class name (e.g. ``"mymodule.sub.MyModel"``). + Used as both the logical type name and the Arrow extension name. + python_type: The pydantic ``BaseModel`` subclass. + storage_type: The Arrow ``pa.StructType`` for the model fields. + field_annotations: Ordered list of ``(field_name, python_annotation)`` + pairs matching the fields in ``storage_type``. + + Example: + >>> lt = PydanticLogicalType( + ... "mymod.Point", Point, + ... pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]), + ... [("x", int), ("y", int)], + ... ) + >>> lt.python_to_storage(Point(x=1, y=2), converter) + {"x": 1, "y": 2} + """ + + def __init__( + self, + logical_name: str, + python_type: type, + storage_type: pa.StructType, + field_annotations: list[tuple[str, Any]], + ) -> None: + self._logical_name = logical_name + self._python_type = python_type + self._storage_type = storage_type + self._field_annotations = field_annotations + + _metadata = json.dumps({"category": PYDANTIC_CATEGORY}).encode("utf-8") + self._arrow_ext_class = make_arrow_extension_type( + logical_name, storage_type, metadata=_metadata + ) + self._arrow_ext: pa.ExtensionType | None = None + # ``storage_type`` must not contain nested extension types (ET1 in DESIGN_ISSUES.md). + # On the write path, ``PydanticLogicalTypeFactory.create_for_python_type`` strips any + # top-level extension type from each field's Arrow type before inserting it into the + # struct. On the read path, ``reconstruct_from_arrow`` receives a ``storage_type`` + # already guaranteed storage-safe by ``register_storage_type``. + self._polars_ext_class = make_polars_extension_type(logical_name, storage_type) + self._polars_ext: pl.BaseExtension | None = None + + @property + def logical_type_name(self) -> str: + """Fully-qualified class name used as the logical type identifier.""" + return self._logical_name + + @property + def python_type(self) -> type: + """The pydantic ``BaseModel`` subclass this logical type represents.""" + return self._python_type + + def get_arrow_extension_type(self) -> pa.ExtensionType: + """Return the Arrow extension type for this model. + + Returns: + A cached ``pa.ExtensionType`` instance with ``extension_name`` equal to + the fully-qualified class name and ``storage_type`` equal to the struct + of the model fields. + """ + if self._arrow_ext is None: + self._arrow_ext = self._arrow_ext_class() + return self._arrow_ext + + def get_polars_extension_type(self) -> pl.BaseExtension: + """Return the Polars extension type for this model. + + Returns: + A cached ``pl.BaseExtension`` instance. + """ + if self._polars_ext is None: + self._polars_ext = self._polars_ext_class() + return self._polars_ext + + def python_to_storage(self, value: Any, converter: TypeConverterProtocol | None) -> dict[str, Any]: + """Convert a pydantic model instance to an Arrow-compatible struct dict. + + Iterates ``_field_annotations`` and delegates each field's conversion to + ``converter.python_to_storage``. + + Args: + value: A pydantic model instance of type ``python_type``. + converter: The active converter for per-field delegation. Must not be ``None``. + + Returns: + A dict mapping field names to their Arrow storage values. + + Raises: + ValueError: If ``converter`` is ``None``. + """ + if converter is None: + raise ValueError( + "PydanticLogicalType.python_to_storage requires a converter — " + "pass a TypeConverterProtocol instance for field-level conversion." + ) + return { + name: converter.python_to_storage(getattr(value, name), annotation) + for name, annotation in self._field_annotations + } + + def storage_to_python(self, storage_value: Any, converter: TypeConverterProtocol | None) -> Any: + """Reconstruct a pydantic model instance from an Arrow struct dict. + + Args: + storage_value: A dict mapping field names to Arrow storage values. + converter: The active converter for per-field delegation. Must not be ``None``. + + Returns: + A pydantic model instance of type ``python_type``. Pydantic validation + runs on construction, ensuring the model is always in a valid state. + + Raises: + ValueError: If ``converter`` is ``None``. + """ + if converter is None: + raise ValueError( + "PydanticLogicalType.storage_to_python requires a converter — " + "pass a TypeConverterProtocol instance for field-level conversion." + ) + kwargs = { + name: converter.storage_to_python(storage_value[name], annotation) + for name, annotation in self._field_annotations + } + return self._python_type(**kwargs) +``` + +- [ ] **Step 4: Run `PydanticLogicalType` tests to verify they pass** + +```bash +uv run pytest tests/test_extension_types/test_pydantic_logical_type_factory.py \ + -k "not factory" -v +``` + +Expected: all 8 `PydanticLogicalType` tests PASS. + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/extension_types/pydantic_logical_type_factory.py \ + tests/test_extension_types/test_pydantic_logical_type_factory.py +git commit -m "feat(pydantic-factory): add PydanticLogicalType" +``` + +--- + +## Task 4: `PydanticLogicalTypeFactory` — write path + +**Files:** +- Modify: `src/orcapod/extension_types/pydantic_logical_type_factory.py` +- Modify: `tests/test_extension_types/test_pydantic_logical_type_factory.py` + +- [ ] **Step 1: Add module-level models and write-path tests to the test file** + +Append to `tests/test_extension_types/test_pydantic_logical_type_factory.py`: + +```python +# ── Module-level models for factory tests ──────────────────────────────────── +# Must be at module scope (not inside functions) so FQCN reconstruction works. + +class _FlatModel(BaseModel): + name: str + count: int + + +class _ModelWithUUID(BaseModel): + id: _uuid_module.UUID + label: str + + +class _ModelWithList(BaseModel): + tags: list[str] + count: int + + +class _ModelWithDict(BaseModel): + meta: dict[str, int] + + +class _InnerModel(BaseModel): + value: int + + +class _OuterModel(BaseModel): + inner: _InnerModel + label: str + + +class _ModelWithPrivateAttr(BaseModel): + name: str + _cache: str = PrivateAttr(default="") + + +# ── Factory helper ──────────────────────────────────────────────────────────── + +def _make_full_converter(): + """Make a UniversalTypeConverter with builtin types + PydanticLogicalTypeFactory.""" + from pydantic import BaseModel as _BaseModel + from orcapod.extension_types.builtin_logical_types import LogicalPath, LogicalUUID, LogicalUPath + from orcapod.extension_types.registry import LogicalTypeRegistry + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory, PYDANTIC_CATEGORY + from orcapod.semantic_types.universal_converter import UniversalTypeConverter + + registry = LogicalTypeRegistry(logical_types=[LogicalPath(), LogicalUUID(), LogicalUPath()]) + factory = PydanticLogicalTypeFactory() + registry.register_logical_type_factory(factory, category=PYDANTIC_CATEGORY, python_bases=[_BaseModel]) + return UniversalTypeConverter(logical_type_registry=registry) + + +# ── PydanticLogicalTypeFactory write-path tests ─────────────────────────────── + +def test_factory_supports_class_pydantic_model(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + factory = PydanticLogicalTypeFactory() + assert factory.supports_class(_FlatModel) is True + + +def test_factory_supports_class_non_pydantic(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + import dataclasses + + @dataclasses.dataclass + class _DC: + x: int + + factory = PydanticLogicalTypeFactory() + assert factory.supports_class(str) is False + assert factory.supports_class(int) is False + assert factory.supports_class(_DC) is False + + +def test_factory_create_flat_model(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory, PydanticLogicalType + + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_FlatModel, converter=converter) + + assert isinstance(lt, PydanticLogicalType) + storage = lt.get_arrow_extension_type().storage_type + assert pa.types.is_struct(storage) + assert storage.field("name").type == pa.large_string() + assert storage.field("count").type == pa.int64() + + +def test_factory_create_model_with_uuid_field(): + """UUID field → plain storage type (large_binary) in the struct, not extension type (ET1).""" + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_ModelWithUUID, converter=converter) + + storage = lt.get_arrow_extension_type().storage_type + id_field_type = storage.field("id").type + assert id_field_type == pa.large_binary() + assert not isinstance(id_field_type, pa.ExtensionType) + + +def test_factory_create_model_with_list_field(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_ModelWithList, converter=converter) + + storage = lt.get_arrow_extension_type().storage_type + assert pa.types.is_large_list(storage.field("tags").type) + assert storage.field("tags").type.value_type == pa.large_string() + + +def test_factory_create_model_with_dict_field(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_ModelWithDict, converter=converter) + + storage = lt.get_arrow_extension_type().storage_type + meta_type = storage.field("meta").type + assert pa.types.is_large_list(meta_type) + assert pa.types.is_struct(meta_type.value_type) + field_names = {meta_type.value_type.field(i).name for i in range(meta_type.value_type.num_fields)} + assert field_names == {"key", "value"} + + +def test_factory_rejects_local_class(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + def _make_local(): + class _Local(BaseModel): + x: int + return _Local + + LocalModel = _make_local() + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + with pytest.raises(ValueError, match="local"): + factory.create_for_python_type(LocalModel, converter=converter) + + +def test_private_fields_not_stored(): + """Private attributes (PrivateAttr) must not appear in the Arrow struct.""" + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_ModelWithPrivateAttr, converter=converter) + + storage = lt.get_arrow_extension_type().storage_type + field_names = {storage.field(i).name for i in range(storage.num_fields)} + assert "name" in field_names + assert "_cache" not in field_names + assert storage.num_fields == 1 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +```bash +uv run pytest tests/test_extension_types/test_pydantic_logical_type_factory.py \ + -k "factory" -v 2>&1 | head -30 +``` + +Expected: all factory tests FAIL with `ImportError: cannot import name 'PydanticLogicalTypeFactory'`. + +- [ ] **Step 3: Add `PydanticLogicalTypeFactory` and `_import_pydantic_model_from_fqcn` to the module** + +Append to `src/orcapod/extension_types/pydantic_logical_type_factory.py`: + +```python + +class PydanticLogicalTypeFactory: + """Stateless factory that synthesises and reconstructs ``PydanticLogicalType`` instances. + + **Write path** (``create_for_python_type``): derives Arrow struct type from the + model fields by delegating to ``converter.register_python_class`` per field. + Only fields in ``model_fields`` are stored — computed fields and private + attributes are excluded. + + **Read path** (``reconstruct_from_arrow``): imports the model by FQCN, matches + fields against the already-resolved ``storage_type``, and returns a + ``PydanticLogicalType``. + + Category tag: ``"orcapod.pydantic"`` + + Register with:: + + from pydantic import BaseModel + converter.register_logical_type_factory( + PydanticLogicalTypeFactory(), + category="orcapod.pydantic", + python_bases=[BaseModel], + ) + + Example: + >>> factory = PydanticLogicalTypeFactory() + >>> factory.supports_class(MyModel) + True + >>> factory.supports_class(str) + False + """ + + def supports_class(self, python_type: type) -> bool: + """Return True if ``python_type`` is a pydantic ``BaseModel`` subclass. + + Args: + python_type: Any Python type. + + Returns: + True if pydantic is installed and ``python_type`` is a ``BaseModel`` + subclass. False if pydantic is not installed. + """ + try: + from pydantic import BaseModel + except ImportError: + return False + return isinstance(python_type, type) and issubclass(python_type, BaseModel) + + def create_for_python_type( + self, + python_type: type, + converter: TypeConverterProtocol, + ) -> PydanticLogicalType: + """Synthesise a ``PydanticLogicalType`` for a pydantic model (write path). + + Derives the FQCN, obtains type hints, and resolves each field's Arrow type + via ``converter.register_python_class``. Only fields present in + ``model_fields`` are stored — computed fields and private attributes are + excluded. Rejects local / unnamed classes. + + Args: + python_type: A pydantic ``BaseModel`` subclass. + converter: The active converter for field-type resolution. + + Returns: + A ``PydanticLogicalType`` ready for registration. + + Raises: + ValueError: If ``python_type`` is a local class (``__qualname__`` contains + ``""``). + """ + import typing + + fqcn = f"{python_type.__module__}.{python_type.__qualname__}" + if "" in fqcn: + raise ValueError( + f"Cannot register local class {python_type!r} as a PydanticLogicalType — " + f"local classes have no stable fully-qualified class name and cannot be " + f"reconstructed on read. Define the model at module level." + ) + + try: + hints = typing.get_type_hints(python_type) + except Exception as exc: + raise ValueError( + f"Cannot get type hints for {python_type!r}: {exc}" + ) from exc + + arrow_fields = [] + field_annotations = [] + for field_name in python_type.model_fields: + annotation = hints.get(field_name, Any) + arrow_type = converter.register_python_class(annotation) + # Strip top-level extension type before inserting into the struct (ET1; + # see DESIGN_ISSUES.md): Arrow cannot represent extension types inside + # struct field types. + if isinstance(arrow_type, pa.ExtensionType): + arrow_type = arrow_type.storage_type + arrow_fields.append(pa.field(field_name, arrow_type)) + field_annotations.append((field_name, annotation)) + + storage_type = pa.struct(arrow_fields) + logger.debug("PydanticLogicalTypeFactory: synthesised %r for %r", fqcn, python_type) + return PydanticLogicalType(fqcn, python_type, storage_type, field_annotations) + + def reconstruct_from_arrow( + self, + arrow_extension_name: str, + storage_type: pa.DataType, + metadata: dict[str, Any], + converter: TypeConverterProtocol, + ) -> PydanticLogicalType: + """Reconstruct a ``PydanticLogicalType`` from Arrow schema metadata (read path). + + Imports the model from its FQCN (``arrow_extension_name``), then matches + the model field annotations against the fields in ``storage_type``. + ``storage_type`` is already bottom-up resolved by ``register_storage_type`` + before this method is called. + + Args: + arrow_extension_name: FQCN of the pydantic model (Arrow extension name). + storage_type: Already-resolved ``pa.StructType`` for the model fields. + metadata: Full parsed metadata JSON dict (always contains ``"category"``). + converter: The active converter (used for registration completeness invariant). + + Returns: + A ``PydanticLogicalType`` ready for registration. + + Raises: + ImportError: If the class cannot be imported from ``arrow_extension_name``. + ValueError: If ``storage_type`` is not a struct type. + """ + import typing + + if not pa.types.is_struct(storage_type): + raise ValueError( + f"PydanticLogicalTypeFactory.reconstruct_from_arrow: expected a struct " + f"storage type for {arrow_extension_name!r}, got {storage_type!r}." + ) + + cls = _import_pydantic_model_from_fqcn(arrow_extension_name) + + try: + hints = typing.get_type_hints(cls) + except Exception as exc: + raise ValueError( + f"Cannot get type hints for {cls!r}: {exc}" + ) from exc + + field_annotations = [] + for field_name in cls.model_fields: + annotation = hints.get(field_name, Any) + # Register any logical type the field annotation maps to (registration + # completeness invariant: all nested logical types must be registered when + # the outer type is registered). The return value is discarded. + converter.register_python_class(annotation) + field_annotations.append((field_name, annotation)) + + logger.debug( + "PydanticLogicalTypeFactory: reconstructed %r from Arrow", arrow_extension_name + ) + return PydanticLogicalType( + arrow_extension_name, cls, storage_type, field_annotations + ) + + +def _import_pydantic_model_from_fqcn(fqcn: str) -> type: + """Import a pydantic ``BaseModel`` subclass from its fully-qualified class name. + + Delegates the module-prefix walk to ``type_utils._walk_fqcn``, then + validates the resolved object is a ``BaseModel`` subclass. + + Args: + fqcn: Fully-qualified class name, e.g. ``"mypackage.sub.MyModel"``. + + Returns: + The imported ``BaseModel`` subclass. + + Raises: + ImportError: If no valid module+attribute split can be found, or if the + resolved object is not a ``BaseModel`` subclass. + """ + from pydantic import BaseModel + + obj: Any = _walk_fqcn(fqcn) + if not (isinstance(obj, type) and issubclass(obj, BaseModel)): + raise ImportError( + f"{fqcn!r} does not resolve to a pydantic BaseModel subclass." + ) + return obj +``` + +- [ ] **Step 4: Run write-path tests to verify they pass** + +```bash +uv run pytest tests/test_extension_types/test_pydantic_logical_type_factory.py \ + -k "factory" -v +``` + +Expected: all write-path factory tests PASS (reconstruct tests will still fail — that's fine for now). + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/extension_types/pydantic_logical_type_factory.py \ + tests/test_extension_types/test_pydantic_logical_type_factory.py +git commit -m "feat(pydantic-factory): add PydanticLogicalTypeFactory write path" +``` + +--- + +## Task 5: Read path, round-trip tests, and Parquet integration + +**Files:** +- Modify: `tests/test_extension_types/test_pydantic_logical_type_factory.py` + +The `reconstruct_from_arrow` implementation is already in place from Task 4. This task adds the remaining tests that exercise the read path, value round-trips, and Parquet end-to-end. + +- [ ] **Step 1: Add module-level models for read-path and round-trip tests** + +Append to `tests/test_extension_types/test_pydantic_logical_type_factory.py` (after the write-path tests): + +```python +# ── Module-level models for read-path and round-trip tests ─────────────────── + +class _RoundTripPoint(BaseModel): + x: int + y: int + + +class _RoundTripRecord(BaseModel): + record_id: _uuid_module.UUID + label: str +``` + +- [ ] **Step 2: Add read-path and round-trip tests** + +Append to `tests/test_extension_types/test_pydantic_logical_type_factory.py`: + +```python +# ── PydanticLogicalTypeFactory read-path tests ──────────────────────────────── + +def test_factory_reconstruct_from_arrow(): + """reconstruct_from_arrow rebuilds the logical type from the Arrow struct.""" + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory, PydanticLogicalType + + storage = pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]) + metadata = {"category": "orcapod.pydantic"} + fqcn = f"{_RoundTripPoint.__module__}.{_RoundTripPoint.__qualname__}" + + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.reconstruct_from_arrow(fqcn, storage, metadata, converter=converter) + + assert isinstance(lt, PydanticLogicalType) + assert lt.python_type is _RoundTripPoint + assert lt.logical_type_name == fqcn + + +def test_factory_reconstruct_from_arrow_invalid_fqcn(): + """ImportError if the FQCN cannot be resolved.""" + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + storage = pa.struct([pa.field("x", pa.int64())]) + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + + with pytest.raises(ImportError): + factory.reconstruct_from_arrow( + "nonexistent.module.NoSuchModel", storage, {"category": "orcapod.pydantic"}, converter + ) + + +def test_reconstruct_from_arrow_registers_nested_types(): + """reconstruct_from_arrow for Outer must register Inner as a side effect.""" + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + inner_storage = pa.struct([pa.field("value", pa.int64())]) + outer_storage = pa.struct([ + pa.field("inner", inner_storage), + pa.field("label", pa.large_string()), + ]) + outer_fqcn = f"{_OuterModel.__module__}.{_OuterModel.__qualname__}" + + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + + # Inner is NOT pre-registered + assert converter._logical_type_registry.get_by_python_type(_InnerModel) is None + + factory.reconstruct_from_arrow(outer_fqcn, outer_storage, {"category": "orcapod.pydantic"}, converter) + + # Inner must now be registered as a side effect + assert converter._logical_type_registry.get_by_python_type(_InnerModel) is not None + + +# ── Value round-trip tests ──────────────────────────────────────────────────── + +def test_pydantic_python_to_storage_round_trip(): + """python_to_storage → storage_to_python returns an equivalent model.""" + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + converter = _make_full_converter() + factory = PydanticLogicalTypeFactory() + lt = factory.create_for_python_type(_RoundTripPoint, converter=converter) + converter.register_logical_type(lt) + + point = _RoundTripPoint(x=10, y=20) + storage_value = lt.python_to_storage(point, converter) + assert storage_value == {"x": 10, "y": 20} + + reconstructed = lt.storage_to_python(storage_value, converter) + assert isinstance(reconstructed, _RoundTripPoint) + assert reconstructed.x == 10 + assert reconstructed.y == 20 + + +def test_pydantic_with_uuid_round_trip(): + """Round-trip a pydantic model with a UUID field.""" + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + converter = _make_full_converter() + factory = PydanticLogicalTypeFactory() + lt = factory.create_for_python_type(_RoundTripRecord, converter=converter) + converter.register_logical_type(lt) + + u = _uuid_module.UUID("12345678-1234-5678-1234-567812345678") + record = _RoundTripRecord(record_id=u, label="hello") + + storage_value = lt.python_to_storage(record, converter) + assert storage_value["label"] == "hello" + assert storage_value["record_id"] == u.bytes + + reconstructed = lt.storage_to_python(storage_value, converter) + assert isinstance(reconstructed, _RoundTripRecord) + assert reconstructed.record_id == u + assert reconstructed.label == "hello" + + +# ── Parquet integration test ────────────────────────────────────────────────── + +def test_nested_pydantic_model_parquet_roundtrip(tmp_path): + """Fresh-process Parquet round-trip for a two-level nested pydantic model. + + Verifies that register_discovered_extensions triggers the chain: + register_arrow_extension("Outer") -> reconstruct_from_arrow + -> register_python_class(Inner) -> registers Inner + so that storage_to_python can reconstruct the full nested object. + """ + import pyarrow.parquet as pq + from orcapod.extension_types.database_hooks import register_discovered_extensions, apply_extension_types + + # ── Write path ─────────────────────────────────────────────────────────── + write_converter = _make_full_converter() + + inner = _InnerModel(value=42) + outer = _OuterModel(inner=inner, label="hello") + + write_converter.register_python_class(_OuterModel) + + arrow_schema = write_converter.python_schema_to_arrow_schema({"item": _OuterModel}) + rows = [{"item": outer}] + table = write_converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + + parquet_path = tmp_path / "nested_pydantic.parquet" + pq.write_table(table, parquet_path) + + # ── Read path (fresh converter — neither Inner nor Outer pre-registered) ── + read_converter = _make_full_converter() + read_table = pq.read_table(parquet_path) + + register_discovered_extensions(read_converter, read_table.schema) + read_table = apply_extension_types(read_table, read_converter._logical_type_registry) + + assert read_converter._logical_type_registry.get_by_python_type(_OuterModel) is not None + assert read_converter._logical_type_registry.get_by_python_type(_InnerModel) is not None + + rows_out = read_converter.arrow_table_to_python_dicts(read_table) + assert len(rows_out) == 1 + reconstructed = rows_out[0]["item"] + assert isinstance(reconstructed, _OuterModel) + assert isinstance(reconstructed.inner, _InnerModel) + assert reconstructed.inner.value == 42 + assert reconstructed.label == "hello" +``` + +- [ ] **Step 3: Run all tests for the new factory** + +```bash +uv run pytest tests/test_extension_types/test_pydantic_logical_type_factory.py -v +``` + +Expected: all tests PASS. + +- [ ] **Step 4: Commit** + +```bash +git add tests/test_extension_types/test_pydantic_logical_type_factory.py +git commit -m "test(pydantic-factory): add read-path, round-trip, and Parquet integration tests" +``` + +--- + +## Task 6: Export from `__init__.py` and full test suite + +**Files:** +- Modify: `src/orcapod/extension_types/__init__.py` + +- [ ] **Step 1: Add exports to `__init__.py`** + +In `src/orcapod/extension_types/__init__.py`, add the pydantic import and update `__all__`: + +```python +from .pydantic_logical_type_factory import PYDANTIC_CATEGORY, PydanticLogicalType, PydanticLogicalTypeFactory +``` + +Add to `__all__`: + +```python + # PLT-1731 + "PYDANTIC_CATEGORY", + "PydanticLogicalType", + "PydanticLogicalTypeFactory", +``` + +The full updated `__init__.py` should be: + +```python +"""Arrow/Polars extension type system for orcapod. + +This subpackage provides the registry and protocol for logical types that map +between Python objects and their Arrow/Polars extension type representation. + +Built-in registrations (``LogicalPath``, ``LogicalUPath``, ``LogicalUUID``) are +wired into ``DataContext`` via ``contexts/data/v0.1.json``. Use +``get_default_context().type_converter.register_python_class()`` to register new +types, ``register_logical_type_factory()`` to add factories, and +``apply_extension_types()`` to re-wrap Arrow tables with their registered extension types. + +``DataclassLogicalTypeFactory`` provides automatic registration for Python dataclasses: +register it with a ``LogicalTypeRegistry`` and any dataclass used in a ``FunctionPod`` +will be auto-registered on pod declaration. + +``PydanticLogicalTypeFactory`` provides automatic registration for pydantic v2 +``BaseModel`` subclasses: register it with a ``LogicalTypeRegistry`` using +``python_bases=[BaseModel]`` and any model used in a ``FunctionPod`` will be +auto-registered on pod declaration. Requires the ``pydantic`` optional extra. +""" + +from __future__ import annotations + +from .protocols import LogicalTypeProtocol, LogicalTypeFactoryProtocol +from .registry import LogicalTypeRegistry, make_arrow_extension_type, make_polars_extension_type +from .schema_walker import ExtensionTypeInfo, walk_field, walk_schema +from .database_hooks import apply_extension_types, register_discovered_extensions +from .dataclass_logical_type_factory import DATACLASS_CATEGORY, DataclassLogicalType, DataclassLogicalTypeFactory +from .pydantic_logical_type_factory import PYDANTIC_CATEGORY, PydanticLogicalType, PydanticLogicalTypeFactory + +__all__ = [ + "LogicalTypeProtocol", + "LogicalTypeFactoryProtocol", + "LogicalTypeRegistry", + "make_arrow_extension_type", + "make_polars_extension_type", + # PLT-1654 + "ExtensionTypeInfo", + "walk_schema", + "walk_field", + # PLT-1655 + "register_discovered_extensions", + "apply_extension_types", + # PLT-1705 + "DATACLASS_CATEGORY", + "DataclassLogicalType", + "DataclassLogicalTypeFactory", + # PLT-1731 + "PYDANTIC_CATEGORY", + "PydanticLogicalType", + "PydanticLogicalTypeFactory", +] +``` + +- [ ] **Step 2: Verify the exports are importable** + +```bash +uv run python -c " +from orcapod.extension_types import ( + PYDANTIC_CATEGORY, PydanticLogicalType, PydanticLogicalTypeFactory +) +print('PYDANTIC_CATEGORY:', PYDANTIC_CATEGORY) +print('PydanticLogicalType:', PydanticLogicalType) +print('PydanticLogicalTypeFactory:', PydanticLogicalTypeFactory) +" +``` + +Expected output: +``` +PYDANTIC_CATEGORY: orcapod.pydantic +PydanticLogicalType: +PydanticLogicalTypeFactory: +``` + +- [ ] **Step 3: Run the full extension_types test suite** + +```bash +uv run pytest tests/test_extension_types/ -v +``` + +Expected: all tests PASS with no regressions. + +- [ ] **Step 4: Run the full test suite** + +```bash +uv run pytest tests/ -x -q +``` + +Expected: all tests PASS. + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/extension_types/__init__.py +git commit -m "feat(pydantic-factory): export PydanticLogicalType and PydanticLogicalTypeFactory from extension_types" +``` diff --git a/superpowers/plans/2026-06-18-plt-1701-wire-factories-into-default-registry.md b/superpowers/plans/2026-06-18-plt-1701-wire-factories-into-default-registry.md new file mode 100644 index 00000000..334cb4e3 --- /dev/null +++ b/superpowers/plans/2026-06-18-plt-1701-wire-factories-into-default-registry.md @@ -0,0 +1,571 @@ +# PLT-1701: Wire Factories into Default LogicalTypeRegistry Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use sensei:subagent-driven-development (recommended) or sensei:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Wire `DataclassLogicalTypeFactory` and `PydanticLogicalTypeFactory` into the default `LogicalTypeRegistry` so dataclass- and pydantic-annotated pod fields are handled automatically with zero user-side setup. + +**Architecture:** Four targeted changes: (1) promote pydantic to a required dep, (2) harden `PydanticLogicalTypeFactory.supports_class` by dropping the `try/except ImportError` guard, (3) add a `factories` constructor parameter to `LogicalTypeRegistry` that calls `register_logical_type_factory` for each entry, (4) wire both factories into `v0.1.json`. A new test file verifies registry construction and default-context end-to-end behaviour. + +**Tech Stack:** Python 3.11+, PyArrow, pydantic v2, `uv` for dependency management. + +--- + +## File Map + +| File | Action | What changes | +|---|---|---| +| `pyproject.toml` | Modify | Move pydantic from optional to required dependency | +| `src/orcapod/extension_types/pydantic_logical_type_factory.py` | Modify | `supports_class`: drop `try/except ImportError`, import pydantic directly | +| `src/orcapod/extension_types/registry.py` | Modify | `LogicalTypeRegistry.__init__`: add `factories` parameter | +| `src/orcapod/contexts/data/v0.1.json` | Modify | Add `factories` list under `logical_type_registry._config` | +| `tests/test_extension_types/test_default_context_factories.py` | Create | Registry unit tests + default-context integration tests | + +--- + +## Task 0: Create and check out the feature branch + +**Files:** (none — git only) + +- [ ] **Step 1: Create and check out the branch from `extension-type-system`** + +```bash +git checkout extension-type-system +git checkout -b eywalker/plt-1701-wire-dataclasshandlerfactory-into-the-default +git branch --show-current +``` + +Expected: prints `eywalker/plt-1701-wire-dataclasshandlerfactory-into-the-default`. + +--- + +## Task 1: Promote pydantic to a required dependency + +**Files:** +- Modify: `pyproject.toml` + +- [ ] **Step 1: Move pydantic into `[project.dependencies]`** + +In `pyproject.toml`, add `"pydantic>=2.0"` to `[project.dependencies]` and remove the `pydantic` entry from `[project.optional-dependencies]` (keep the `all` extra but remove `"orcapod[pydantic]"` from it): + +```toml +[project] +dependencies = [ + "xxhash", + "networkx", + "typing_extensions", + "matplotlib>=3.10.3", + "pandas>=2.2.3", + "pyyaml>=6.0.2", + "pyarrow>=20.0.0", + "polars>=1.36.0", + "beartype>=0.21.0", + "deltalake>=1.0.2", + "graphviz>=0.21", + "gitpython>=3.1.45", + "universal-pathlib>=0.3.8", + "starfix>=0.2.0", + "pygraphviz>=1.14", + "tzdata>=2024.1", + "uuid-utils>=0.11.1", + "s3fs>=2025.12.0", + "pymongo>=4.15.5", + "basedpyright>=1.38.1", + "pydantic>=2.0", +] + +[project.optional-dependencies] +redis = ["redis>=6.2.0"] +ray = ["ray[default]==2.48.0", "ipywidgets>=8.1.7"] +postgresql = ["psycopg[binary]>=3.0"] +spiraldb = [ + "pyspiral>=0.11.0", +] +all = ["orcapod[redis]", "orcapod[ray]", "orcapod[postgresql]", "orcapod[spiraldb]"] +``` + +- [ ] **Step 2: Re-sync the environment** + +```bash +uv sync +``` + +Expected: pydantic is resolved as a required dep. No errors. + +- [ ] **Step 3: Verify pydantic is available** + +```bash +uv run python -c "import pydantic; print(pydantic.__version__)" +``` + +Expected: prints a version string starting with `2.`. + +- [ ] **Step 4: Run the existing pydantic factory tests to confirm nothing broke** + +```bash +uv run pytest tests/test_extension_types/test_pydantic_logical_type_factory.py -v +``` + +Expected: all tests pass. + +- [ ] **Step 5: Commit** + +```bash +git add pyproject.toml +git commit -m "chore(deps): promote pydantic to required dependency" +``` + +--- + +## Task 2: Harden `PydanticLogicalTypeFactory.supports_class` + +**Files:** +- Modify: `src/orcapod/extension_types/pydantic_logical_type_factory.py:211-225` + +The current `supports_class` wraps its pydantic import in a `try/except ImportError` that silently returns `False` when pydantic is absent. Now that pydantic is required, this guard is dead code and should be removed. The behaviour when pydantic IS installed is identical — no new failing test is needed; the existing `test_pydantic_logical_type_factory.py` suite covers it. + +- [ ] **Step 1: Update `supports_class` in `pydantic_logical_type_factory.py`** + +Replace the current `supports_class` method (lines ~211–225): + +```python +def supports_class(self, python_type: type) -> bool: + """Return True if ``python_type`` is a pydantic ``BaseModel`` subclass. + + Args: + python_type: Any Python type. + + Returns: + True if ``python_type`` is a ``BaseModel`` subclass. + """ + from pydantic import BaseModel + return isinstance(python_type, type) and issubclass(python_type, BaseModel) +``` + +- [ ] **Step 2: Verify pydantic tests still pass** + +```bash +uv run pytest tests/test_extension_types/test_pydantic_logical_type_factory.py -v +``` + +Expected: all tests pass. + +- [ ] **Step 3: Commit** + +```bash +git add src/orcapod/extension_types/pydantic_logical_type_factory.py +git commit -m "fix(pydantic-factory): drop try/except in supports_class — pydantic is now required" +``` + +--- + +## Task 3: Add `factories` parameter to `LogicalTypeRegistry.__init__` + +**Files:** +- Modify: `src/orcapod/extension_types/registry.py:205-213` + +- [ ] **Step 1: Write the failing tests** + +Create `tests/test_extension_types/test_default_context_factories.py` with just the registry unit tests: + +```python +"""Tests for LogicalTypeRegistry factories parameter and default context factory wiring.""" + +from __future__ import annotations + +import dataclasses + +import pytest + +from orcapod.extension_types.dataclass_logical_type_factory import ( + DataclassLogicalTypeFactory, + DATACLASS_CATEGORY, +) +from orcapod.extension_types.pydantic_logical_type_factory import ( + PydanticLogicalTypeFactory, + PYDANTIC_CATEGORY, +) +from orcapod.extension_types.registry import LogicalTypeRegistry + + +# ── Module-level dataclasses (local classes cannot be registered) ──────────── + +@dataclasses.dataclass +class _SimplePoint: + x: int + y: int + + +# ── Registry constructor unit tests ───────────────────────────────────────── + +def test_registry_factories_param_registers_category(): + """factories param registers the factory under the given category.""" + factory = DataclassLogicalTypeFactory() + registry = LogicalTypeRegistry( + factories=[{"factory": factory, "category": DATACLASS_CATEGORY, "python_bases": [object]}] + ) + assert registry._category_factories.get(DATACLASS_CATEGORY) is factory + + +def test_registry_factories_param_registers_python_base(): + """factories param registers the factory under each python_base.""" + factory = DataclassLogicalTypeFactory() + registry = LogicalTypeRegistry( + factories=[{"factory": factory, "category": DATACLASS_CATEGORY, "python_bases": [object]}] + ) + assert registry._python_class_factories.get(object) is factory + + +def test_registry_factories_param_empty_list_is_noop(): + """factories=[] constructs successfully with no registered factories.""" + registry = LogicalTypeRegistry(factories=[]) + assert registry._category_factories == {} + assert registry._python_class_factories == {} + + +def test_registry_factories_param_none_is_noop(): + """factories=None (default) constructs successfully.""" + registry = LogicalTypeRegistry(factories=None) + assert registry._category_factories == {} +``` + +- [ ] **Step 2: Run the tests to confirm they fail** + +```bash +uv run pytest tests/test_extension_types/test_default_context_factories.py::test_registry_factories_param_registers_category -v +``` + +Expected: `FAILED` — `LogicalTypeRegistry.__init__` does not yet accept `factories`. + +- [ ] **Step 3: Update `LogicalTypeRegistry.__init__` in `registry.py`** + +Replace the current `__init__` signature and body (lines ~205–212): + +```python +def __init__( + self, + logical_types: list[LogicalTypeProtocol] | None = None, + factories: list[dict] | None = None, +) -> None: + self._by_logical_name: dict[str, LogicalTypeProtocol] = {} + self._by_arrow_name: dict[str, LogicalTypeProtocol] = {} + self._by_python_type: dict[type, LogicalTypeProtocol] = {} + self._category_factories: dict[str, LogicalTypeFactoryProtocol] = {} + self._python_class_factories: dict[type, LogicalTypeFactoryProtocol] = {} + for lt in (logical_types or []): + self.register_logical_type(lt) + for entry in (factories or []): + self.register_logical_type_factory( + entry["factory"], + category=entry.get("category"), + python_bases=entry.get("python_bases", []), + ) +``` + +- [ ] **Step 4: Run the registry unit tests** + +```bash +uv run pytest tests/test_extension_types/test_default_context_factories.py::test_registry_factories_param_registers_category tests/test_extension_types/test_default_context_factories.py::test_registry_factories_param_registers_python_base tests/test_extension_types/test_default_context_factories.py::test_registry_factories_param_empty_list_is_noop tests/test_extension_types/test_default_context_factories.py::test_registry_factories_param_none_is_noop -v +``` + +Expected: all 4 tests pass. + +- [ ] **Step 5: Run the existing registry tests to confirm no regressions** + +```bash +uv run pytest tests/test_extension_types/test_registry.py -v +``` + +Expected: all tests pass. + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/extension_types/registry.py tests/test_extension_types/test_default_context_factories.py +git commit -m "feat(registry): add factories parameter to LogicalTypeRegistry.__init__" +``` + +--- + +## Task 4: Wire both factories into `v0.1.json` + +**Files:** +- Modify: `src/orcapod/contexts/data/v0.1.json` + +- [ ] **Step 1: Write the failing default-context tests** + +Append these tests to `tests/test_extension_types/test_default_context_factories.py`: + +```python +# ── Default context integration tests ──────────────────────────────────────── +# +# All tests use create_registry().get_context() — NOT get_default_context() — +# to avoid cross-test contamination via the global singleton cache. + +from orcapod.contexts import create_registry + + +def test_default_context_has_dataclass_factory(): + """Default context registers DataclassLogicalTypeFactory under orcapod.dataclass.""" + ctx = create_registry().get_context() + registry = ctx.type_converter._logical_type_registry + factory = registry._category_factories.get(DATACLASS_CATEGORY) + assert isinstance(factory, DataclassLogicalTypeFactory) + + +def test_default_context_has_pydantic_factory(): + """Default context registers PydanticLogicalTypeFactory under orcapod.pydantic.""" + ctx = create_registry().get_context() + registry = ctx.type_converter._logical_type_registry + factory = registry._category_factories.get(PYDANTIC_CATEGORY) + assert isinstance(factory, PydanticLogicalTypeFactory) +``` + +- [ ] **Step 2: Run those two tests to confirm they fail** + +```bash +uv run pytest tests/test_extension_types/test_default_context_factories.py::test_default_context_has_dataclass_factory tests/test_extension_types/test_default_context_factories.py::test_default_context_has_pydantic_factory -v +``` + +Expected: both `FAILED` — factories not yet in `v0.1.json`. + +- [ ] **Step 3: Add the `factories` list to `v0.1.json`** + +In `src/orcapod/contexts/data/v0.1.json`, find the `logical_type_registry` object spec +(under `type_converter._config`) and add `"factories"` alongside `"logical_types"`: + +```json +"logical_type_registry": { + "_class": "orcapod.extension_types.registry.LogicalTypeRegistry", + "_config": { + "logical_types": [ + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalPath", + "_config": {} + }, + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalUPath", + "_config": {} + }, + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalUUID", + "_config": {} + } + ], + "factories": [ + { + "factory": { + "_class": "orcapod.extension_types.dataclass_logical_type_factory.DataclassLogicalTypeFactory", + "_config": {} + }, + "category": "orcapod.dataclass", + "python_bases": [{"_type": "builtins.object"}] + }, + { + "factory": { + "_class": "orcapod.extension_types.pydantic_logical_type_factory.PydanticLogicalTypeFactory", + "_config": {} + }, + "category": "orcapod.pydantic", + "python_bases": [{"_type": "pydantic.BaseModel"}] + } + ] + } +} +``` + +`{"_type": "builtins.object"}` resolves to the `object` class via `parse_objectspec`. +`{"_type": "pydantic.BaseModel"}` resolves to `pydantic.BaseModel` the same way — no +instance is created, the class itself is passed as a `python_bases` entry. + +- [ ] **Step 4: Run the default-context factory tests** + +```bash +uv run pytest tests/test_extension_types/test_default_context_factories.py::test_default_context_has_dataclass_factory tests/test_extension_types/test_default_context_factories.py::test_default_context_has_pydantic_factory -v +``` + +Expected: both pass. + +- [ ] **Step 5: Verify the existing context tests still pass** + +```bash +uv run pytest test-objective/unit/test_contexts.py -v +``` + +Expected: all tests pass. + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/contexts/data/v0.1.json tests/test_extension_types/test_default_context_factories.py +git commit -m "feat(contexts): wire DataclassLogicalTypeFactory and PydanticLogicalTypeFactory into v0.1 default context" +``` + +--- + +## Task 5: Add end-to-end integration tests via the default context + +**Files:** +- Modify: `tests/test_extension_types/test_default_context_factories.py` + +These tests prove that a user can define a dataclass or pydantic model and use it immediately as a pod field type via the default context — no manual factory registration. + +- [ ] **Step 1: Add module-level pydantic model to the test file** + +At the top of `tests/test_extension_types/test_default_context_factories.py`, after the existing module-level dataclass, add: + +```python +from pydantic import BaseModel + + +class _SimpleModel(BaseModel): + name: str + score: float +``` + +- [ ] **Step 2: Add the auto-registration tests** + +Append to `tests/test_extension_types/test_default_context_factories.py`: + +```python +import pyarrow as pa +from orcapod.extension_types.database_hooks import apply_extension_types, register_discovered_extensions + + +def test_default_context_dataclass_auto_registered_on_use(): + """register_python_class on a dataclass works zero-setup via the default context.""" + converter = create_registry().get_context().type_converter + arrow_type = converter.register_python_class(_SimplePoint) + assert isinstance(arrow_type, pa.ExtensionType) + fqcn = f"{_SimplePoint.__module__}.{_SimplePoint.__qualname__}" + assert arrow_type.extension_name == fqcn + + +def test_default_context_pydantic_auto_registered_on_use(): + """register_python_class on a pydantic model works zero-setup via the default context.""" + converter = create_registry().get_context().type_converter + arrow_type = converter.register_python_class(_SimpleModel) + assert isinstance(arrow_type, pa.ExtensionType) + fqcn = f"{_SimpleModel.__module__}.{_SimpleModel.__qualname__}" + assert arrow_type.extension_name == fqcn +``` + +- [ ] **Step 3: Add the Parquet round-trip tests** + +Append to `tests/test_extension_types/test_default_context_factories.py`: + +```python +import pyarrow.parquet as pq + + +def test_default_context_dataclass_parquet_roundtrip(tmp_path): + """Dataclass round-trips through Parquet with no manual factory registration.""" + # Write path — fresh context, no manual factory setup + write_converter = create_registry().get_context().type_converter + arrow_schema = write_converter.python_schema_to_arrow_schema({"point": _SimplePoint}) + rows = [{"point": _SimplePoint(x=3, y=7)}] + table = write_converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + + parquet_path = tmp_path / "point.parquet" + pq.write_table(table, parquet_path) + + # Read path — another fresh context, no manual factory setup + read_converter = create_registry().get_context().type_converter + read_table = pq.read_table(parquet_path) + register_discovered_extensions(read_converter, read_table.schema) + read_table = apply_extension_types(read_table, read_converter._logical_type_registry) + + rows_out = read_converter.arrow_table_to_python_dicts(read_table) + assert len(rows_out) == 1 + result = rows_out[0]["point"] + assert isinstance(result, _SimplePoint) + assert result.x == 3 + assert result.y == 7 + + +def test_default_context_pydantic_parquet_roundtrip(tmp_path): + """Pydantic model round-trips through Parquet with no manual factory registration.""" + # Write path — fresh context, no manual factory setup + write_converter = create_registry().get_context().type_converter + arrow_schema = write_converter.python_schema_to_arrow_schema({"model": _SimpleModel}) + rows = [{"model": _SimpleModel(name="alice", score=9.5)}] + table = write_converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + + parquet_path = tmp_path / "model.parquet" + pq.write_table(table, parquet_path) + + # Read path — another fresh context, no manual factory setup + read_converter = create_registry().get_context().type_converter + read_table = pq.read_table(parquet_path) + register_discovered_extensions(read_converter, read_table.schema) + read_table = apply_extension_types(read_table, read_converter._logical_type_registry) + + rows_out = read_converter.arrow_table_to_python_dicts(read_table) + assert len(rows_out) == 1 + result = rows_out[0]["model"] + assert isinstance(result, _SimpleModel) + assert result.name == "alice" + assert result.score == 9.5 +``` + +- [ ] **Step 4: Run all tests in the new test file** + +```bash +uv run pytest tests/test_extension_types/test_default_context_factories.py -v +``` + +Expected: all tests pass. + +- [ ] **Step 5: Run the full extension_types test suite to check for regressions** + +```bash +uv run pytest tests/test_extension_types/ -v +``` + +Expected: all tests pass (the existing xfail on `test_list_of_nested_dataclass_parquet_roundtrip` still xfails as expected). + +- [ ] **Step 6: Commit** + +```bash +git add tests/test_extension_types/test_default_context_factories.py +git commit -m "test(registry): add default context factory registration and Parquet round-trip tests" +``` + +--- + +## Task 6: Final verification and PR + +- [ ] **Step 1: Run the complete test suite** + +```bash +uv run pytest tests/ test-objective/ -v +``` + +Expected: all tests pass. No new failures. + +- [ ] **Step 2: Create the PR** + +```bash +gh pr create \ + --base extension-type-system \ + --title "feat(registry): wire DataclassLogicalTypeFactory and PydanticLogicalTypeFactory into default context (PLT-1701)" \ + --body "$(cat <<'EOF' +## Summary + +- Promotes pydantic to a required dependency (was optional extra) +- Adds `factories` parameter to `LogicalTypeRegistry.__init__` — accepts a list of dicts with `factory`, `category`, and `python_bases` keys; each entry is registered via `register_logical_type_factory` at construction time +- Drops `try/except ImportError` guard in `PydanticLogicalTypeFactory.supports_class` — pydantic is now always available +- Wires `DataclassLogicalTypeFactory` and `PydanticLogicalTypeFactory` into `v0.1.json` under `logical_type_registry._config.factories`; uses `{"_type": "..."}` object-specs for `python_bases` so `parse_objectspec` resolves them to actual type objects +- Adds integration tests verifying zero-setup dataclass/pydantic auto-registration and Parquet round-trips via the default context + +## Test plan + +- [ ] `uv run pytest tests/test_extension_types/ -v` — all pass +- [ ] `uv run pytest test-objective/unit/test_contexts.py -v` — all pass +- [ ] `uv run pytest tests/ test-objective/ -v` — full suite passes + +Closes PLT-1701 +EOF +)" +``` + +Expected: PR URL printed. Verify it targets `extension-type-system`, not `main`. diff --git a/superpowers/plans/2026-06-23-plt-1659-extension-type-roundtrip-integration-tests.md b/superpowers/plans/2026-06-23-plt-1659-extension-type-roundtrip-integration-tests.md new file mode 100644 index 00000000..0028fb2c --- /dev/null +++ b/superpowers/plans/2026-06-23-plt-1659-extension-type-roundtrip-integration-tests.md @@ -0,0 +1,827 @@ +# PLT-1659: Extension Type Round-Trip Integration Tests — Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use sensei:subagent-driven-development (recommended) or sensei:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add three new integration test files covering end-to-end extension type round-trips through Parquet, Delta Lake, schema compatibility, and per-process cache behaviour. + +**Architecture:** Three focused test files plus one source change and one docs update. Test files: `test_roundtrips.py` (write/read through Parquet and Delta backends), `test_schema_compatibility.py` (Arrow-level identity + Python-type-level compatibility), `test_cache_behavior.py` (registry cache populated and skipped on second read). SQLite backend is excluded from value round-trip tests because `SQLiteConnector` does not preserve `ARROW:extension:*` field metadata; that pattern is already covered by `test_extension_aware_database.py`. Source change: `ConnectorArrowDatabase.add_records()` gets a `ValueError` guard that rejects extension-typed columns (both in-memory `pa.ExtensionType` and metadata-only fields) as an interim safety measure while PLT-1795 is pending. + +**Tech Stack:** pytest, pyarrow, pyarrow.parquet, deltalake, polars, orcapod extension type APIs (`create_registry`, `UniversalTypeConverter`, `DataclassLogicalTypeFactory`), `unittest.mock.patch.object`. + +--- + +## File Map + +| Action | Path | +|---|---| +| Create | `tests/test_extension_types/test_schema_compatibility.py` | +| Create | `tests/test_extension_types/test_cache_behavior.py` | +| Create | `tests/test_extension_types/test_roundtrips.py` | +| Modify | `src/orcapod/databases/connector_arrow_database.py` — add `ValueError` guard in `add_records()` | +| Modify | `DESIGN_ISSUES.md` — add CA1 entry documenting SQL metadata loss and interim guard | + +--- + +## Task 1: Create and check out the feature branch + +**Files:** none (git only) + +- [ ] **Step 1: Verify you are on `extension-type-system`** + +```bash +git branch --show-current +``` + +Expected output: `extension-type-system` + +- [ ] **Step 2: Create and check out the feature branch** + +```bash +git checkout -b eywalker/plt-1659-integration-tests-end-to-end-semantic-type-round-trips +git branch --show-current +``` + +Expected output: `eywalker/plt-1659-integration-tests-end-to-end-semantic-type-round-trips` + +--- + +## Task 2: `test_schema_compatibility.py` + +**Files:** +- Create: `tests/test_extension_types/test_schema_compatibility.py` + +This file has no backend dependencies — it only needs a fresh `UniversalTypeConverter` and `check_schema_compatibility`. + +- [ ] **Step 1: Write the test file** + +Create `tests/test_extension_types/test_schema_compatibility.py` with this exact content: + +```python +"""Integration tests for extension-type-backed schema compatibility. + +Two complementary angles: + +Arrow-level identity + ``converter.python_schema_to_arrow_schema`` assigns each dataclass a unique + Arrow extension name derived from its fully-qualified class name. Two + dataclasses with identical struct shapes but different class names therefore + produce *different* extension names — the core identity guarantee of the + extension type system. + +Python-type-level compatibility + ``check_schema_compatibility`` from ``schema_utils`` uses beartype + ``is_subhint`` to compare Python type annotations. Same class → compatible; + different class with the same struct shape → incompatible. This is the + property that prevents silent data corruption when two unrelated dataclasses + happen to share the same fields. +""" +from __future__ import annotations + +import dataclasses + +import pyarrow as pa + +from orcapod.contexts import create_registry +from orcapod.types import Schema +from orcapod.utils.schema_utils import check_schema_compatibility + + +# Module-level dataclasses — DataclassLogicalTypeFactory rejects local classes +# because they have no stable fully-qualified class name for reconstruction. + +@dataclasses.dataclass +class _PointA: + x: int + y: int + + +@dataclasses.dataclass +class _PointB: + """Same struct shape as _PointA but a different class name.""" + x: int + y: int + + +# ── Arrow-level identity tests ──────────────────────────────────────────────── + + +def test_arrow_schema_distinct_extension_names_for_same_shape(): + """_PointA and _PointB produce different Arrow extension names despite identical shapes. + + This is the core identity guarantee: struct shape alone does not determine + type identity in the extension type system. + """ + converter_a = create_registry().get_context().type_converter + converter_b = create_registry().get_context().type_converter + + type_a = converter_a.register_python_class(_PointA) + type_b = converter_b.register_python_class(_PointB) + + assert isinstance(type_a, pa.ExtensionType) + assert isinstance(type_b, pa.ExtensionType) + + fqcn_a = f"{_PointA.__module__}.{_PointA.__qualname__}" + fqcn_b = f"{_PointB.__module__}.{_PointB.__qualname__}" + assert type_a.extension_name == fqcn_a + assert type_b.extension_name == fqcn_b + assert type_a.extension_name != type_b.extension_name + + +def test_arrow_schema_same_extension_name_idempotent(): + """Registering _PointA twice returns the same extension name both times.""" + converter = create_registry().get_context().type_converter + + type_first = converter.register_python_class(_PointA) + type_second = converter.register_python_class(_PointA) + + assert isinstance(type_first, pa.ExtensionType) + assert isinstance(type_second, pa.ExtensionType) + assert type_first.extension_name == type_second.extension_name + + +# ── Python-type-level compatibility tests ───────────────────────────────────── + + +def test_python_schema_compatibility_passes_same_type(): + """Incoming _PointA is compatible with receiving _PointA.""" + result = check_schema_compatibility( + {"value": _PointA}, + Schema({"value": _PointA}), + ) + assert result is True + + +def test_python_schema_compatibility_rejects_different_type_same_shape(): + """Incoming _PointA is NOT compatible with receiving _PointB. + + Both dataclasses share the same struct shape {x: int, y: int}, but they + are different Python types. The old shape-based system would have accepted + this silently; the extension type system correctly rejects it. + """ + result = check_schema_compatibility( + {"value": _PointA}, + Schema({"value": _PointB}), + ) + assert result is False +``` + +- [ ] **Step 2: Run the tests and verify they pass** + +```bash +uv run pytest tests/test_extension_types/test_schema_compatibility.py -v +``` + +Expected: all 4 tests pass. + +- [ ] **Step 3: Commit** + +```bash +git add tests/test_extension_types/test_schema_compatibility.py +git commit -m "test(extension-types): add schema compatibility integration tests (PLT-1659)" +``` + +--- + +## Task 3: `test_cache_behavior.py` + +**Files:** +- Create: `tests/test_extension_types/test_cache_behavior.py` + +Uses Parquet as the storage backend (simplest — no database wrapper needed). The second test patches `DataclassLogicalTypeFactory.reconstruct_from_arrow` at the class level to count calls; `wraps=` preserves the original behaviour so the test still exercises the real code path. + +- [ ] **Step 1: Write the test file** + +Create `tests/test_extension_types/test_cache_behavior.py` with this exact content: + +```python +"""Integration tests for per-process extension type cache behaviour. + +The ``LogicalTypeRegistry`` stores registered types in an in-memory dict keyed +by Arrow extension name. ``register_discovered_extensions`` skips the factory +call (``reconstruct_from_arrow``) when the extension name is already present in +the registry — this is the "cache hit" path. + +Two tests: + +1. ``test_cache_populated_after_first_read`` — verifies the type is absent from + a fresh converter's registry before reading a Parquet file, and present after. + +2. ``test_factory_not_called_on_second_read`` — verifies that ``reconstruct_from_arrow`` + is called exactly once (first read) and zero additional times on the second + read of the same file. +""" +from __future__ import annotations + +import dataclasses +from unittest.mock import patch + +import pyarrow as pa +import pyarrow.parquet as pq + +from orcapod.contexts import create_registry +from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory + + +# Module-level dataclass — local classes cannot be reconstructed from FQCN. + +@dataclasses.dataclass +class _CachePoint: + x: int + y: int + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + + +def _fresh_converter(): + """Return a fresh UniversalTypeConverter from a new registry instance. + + Uses ``create_registry()`` instead of ``get_default_context()`` to avoid + cross-test contamination through the global singleton cache. + """ + return create_registry().get_context().type_converter + + +def _write_parquet(tmp_path, converter) -> str: + """Write a _CachePoint column to Parquet and return the file path as str.""" + converter.register_python_class(_CachePoint) + arrow_schema = converter.python_schema_to_arrow_schema({"point": _CachePoint}) + rows = [{"point": _CachePoint(x=1, y=2)}] + table = converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + parquet_path = tmp_path / "cache_test.parquet" + pq.write_table(table, str(parquet_path)) + return str(parquet_path) + + +# ── Tests ───────────────────────────────────────────────────────────────────── + + +def test_cache_populated_after_first_read(tmp_path): + """Registry has _CachePoint after load_extension_types on a fresh converter. + + Before reading: the fresh converter's registry does not know about _CachePoint. + After reading: register_discovered_extensions triggers reconstruct_from_arrow + which registers _CachePoint, populating the cache. + """ + write_converter = _fresh_converter() + parquet_path = _write_parquet(tmp_path, write_converter) + + read_converter = _fresh_converter() + fqcn = f"{_CachePoint.__module__}.{_CachePoint.__qualname__}" + + # Before read: not registered + assert read_converter._logical_type_registry.get_by_arrow_extension_name(fqcn) is None + + read_converter.load_extension_types(pq.read_table(parquet_path)) + + # After read: registered (cache populated) + assert read_converter._logical_type_registry.get_by_arrow_extension_name(fqcn) is not None + + +def test_factory_not_called_on_second_read(tmp_path): + """reconstruct_from_arrow called once on first read, zero times on second read. + + On first read, register_discovered_extensions finds _CachePoint's extension + name in the schema, dispatches to the factory (call count = 1), and stores + the result in the registry. + + On second read, register_discovered_extensions finds the extension name already + in the registry and short-circuits — the factory is not called again + (call count remains 1). + """ + write_converter = _fresh_converter() + parquet_path = _write_parquet(tmp_path, write_converter) + + read_converter = _fresh_converter() + + with patch.object( + DataclassLogicalTypeFactory, + "reconstruct_from_arrow", + autospec=True, + wraps=DataclassLogicalTypeFactory.reconstruct_from_arrow, + ) as spy: + # First read: factory is called once + read_converter.load_extension_types(pq.read_table(parquet_path)) + assert spy.call_count == 1, f"Expected 1 factory call, got {spy.call_count}" + + # Second read on the same file: registry hit — factory not called again + read_converter.load_extension_types(pq.read_table(parquet_path)) + assert spy.call_count == 1, ( + f"Expected still 1 factory call after second read, got {spy.call_count}" + ) +``` + +- [ ] **Step 2: Run the tests and verify they pass** + +```bash +uv run pytest tests/test_extension_types/test_cache_behavior.py -v +``` + +Expected: both tests pass. + +- [ ] **Step 3: Commit** + +```bash +git add tests/test_extension_types/test_cache_behavior.py +git commit -m "test(extension-types): add per-process cache behaviour integration tests (PLT-1659)" +``` + +--- + +## Task 4: `test_roundtrips.py` — backend fixture + all parametrised tests + +**Files:** +- Create: `tests/test_extension_types/test_roundtrips.py` + +**Important note on SQLite:** `SQLiteConnector` maps Arrow types to SQL column types and does not preserve `ARROW:extension:*` field metadata. `ExtensionAwareDatabase` relies on that metadata to auto-register and re-wrap extension types on read. Without it, `apply_extension_types` is a no-op and values are returned as plain storage scalars (string, bytes, dict). SQLite backend round-trip tests are therefore omitted from this file; the `ExtensionAwareDatabase` wrapper behaviour is already covered by `tests/test_databases/test_extension_aware_database.py`. + +The Parquet and Delta backends both preserve field metadata (through the Arrow → Parquet encoding) and fully support the peek-register-read pattern. + +- [ ] **Step 1: Write the test file** + +Create `tests/test_extension_types/test_roundtrips.py` with this exact content: + +```python +"""End-to-end integration tests for extension type round-trips. + +Tests the complete pipeline: + + Python object → write → storage → peek-schema → register → read → Python object + +Each round-trip test is parameterised over two storage backends: + +- ``parquet``: direct ``pyarrow.parquet`` write/read. +- ``delta``: ``deltalake.write_deltalake`` / ``DeltaTable.to_pyarrow_dataset(as_large_types=True).to_table()``. + +SQLite (``ConnectorArrowDatabase`` + ``SQLiteConnector``) is excluded because +``SQLiteConnector`` maps Arrow types to SQL column types and discards +``ARROW:extension:*`` field metadata. Without that metadata, the +peek-register-read pattern cannot auto-register extension types on the read +path. The ``ExtensionAwareDatabase`` wrapper behaviour over SQLite is already +tested in ``tests/test_databases/test_extension_aware_database.py``. +""" +from __future__ import annotations + +import dataclasses +import pathlib +import uuid as uuid_module +from pathlib import Path +from typing import Callable + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest +from upath import UPath + +from orcapod.contexts import create_registry +from orcapod.semantic_types.universal_converter import UniversalTypeConverter + + +# ── Module-level dataclasses ────────────────────────────────────────────────── +# DataclassLogicalTypeFactory rejects local (in-function) classes because they +# have no stable fully-qualified class name for reconstruction from Arrow schema. + +@dataclasses.dataclass +class _PointA: + x: int + y: int + + +@dataclasses.dataclass +class _PointB: + """Same struct shape as _PointA, different class name.""" + x: int + y: int + + +@dataclasses.dataclass +class _Inner: + value: int + + +@dataclasses.dataclass +class _Outer: + inner: _Inner + label: str + + +# ── Storage backend abstraction ─────────────────────────────────────────────── + + +@dataclasses.dataclass +class _StorageBackend: + """Encapsulates backend-specific write and read logic for parameterised tests. + + Args: + name: Short identifier used in pytest test IDs (e.g. ``"parquet"``). + write: Callable that writes an Arrow table to a directory. + read: Callable that reads from that directory and returns an Arrow table + with extension types registered and applied. Must return only the + original user data columns (no ``__record_id`` or similar). + """ + name: str + write: Callable[[pa.Table, Path], None] + read: Callable[[Path, UniversalTypeConverter], pa.Table] + + +def _parquet_write(table: pa.Table, base_path: Path) -> None: + pq.write_table(table, str(base_path / "data.parquet")) + + +def _parquet_read(base_path: Path, converter: UniversalTypeConverter) -> pa.Table: + return converter.load_extension_types(pq.read_table(str(base_path / "data.parquet"))) + + +def _delta_write(table: pa.Table, base_path: Path) -> None: + import deltalake + deltalake.write_deltalake(str(base_path / "delta"), table) + + +def _delta_read(base_path: Path, converter: UniversalTypeConverter) -> pa.Table: + import deltalake + dt = deltalake.DeltaTable(str(base_path / "delta")) + # as_large_types=True preserves large_string / large_binary rather than + # normalising them to string / binary (Delta Lake's default behaviour). + raw = dt.to_pyarrow_dataset(as_large_types=True).to_table() + return converter.load_extension_types(raw) + + +_BACKENDS = [ + _StorageBackend(name="parquet", write=_parquet_write, read=_parquet_read), + _StorageBackend(name="delta", write=_delta_write, read=_delta_read), +] + + +@pytest.fixture(params=_BACKENDS, ids=lambda b: b.name) +def storage_backend(request: pytest.FixtureRequest) -> _StorageBackend: + """Yield one storage backend per parametrised run.""" + return request.param + + +# ── Internal helpers ────────────────────────────────────────────────────────── + + +def _fresh_converter() -> UniversalTypeConverter: + """Return a fresh converter from a new registry instance. + + Uses ``create_registry()`` instead of ``get_default_context()`` to avoid + cross-test contamination through the global singleton cache. + """ + return create_registry().get_context().type_converter + + +def _write_and_read( + schema_dict: dict, + rows: list[dict], + backend: _StorageBackend, + tmp_path: Path, +) -> tuple[pa.Table, UniversalTypeConverter]: + """Write rows with a fresh write converter and read back with a fresh read converter. + + Returns the resulting Arrow table (with extension types applied) and the + read-side converter (needed for ``arrow_table_to_python_dicts``). + """ + write_converter = _fresh_converter() + arrow_schema = write_converter.python_schema_to_arrow_schema(schema_dict) + table = write_converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + backend.write(table, tmp_path) + + read_converter = _fresh_converter() + result = backend.read(tmp_path, read_converter) + return result, read_converter + + +# ── Built-in type round-trip tests ─────────────────────────────────────────── + + +def test_builtin_path_round_trip(storage_backend: _StorageBackend, tmp_path: Path) -> None: + """pathlib.Path round-trips through storage with extension name ``orcapod.path``. + + Built-in types (Path, UPath, UUID) are pre-registered in the default context + so the read-side converter already knows about them. The test verifies that: + + 1. The Arrow field carries the ``orcapod.path`` extension type after read. + 2. The Python value is reconstructed as a ``pathlib.Path`` instance. + """ + p = pathlib.Path("/tmp/orcapod/integration/test.txt") + result, read_converter = _write_and_read( + {"col": pathlib.Path}, + [{"col": p}], + storage_backend, + tmp_path, + ) + + field = result.schema.field("col") + assert hasattr(field.type, "extension_name"), ( + f"Expected extension type on field 'col', got plain type {field.type!r}" + ) + assert field.type.extension_name == "orcapod.path" + + rows = read_converter.arrow_table_to_python_dicts(result) + assert len(rows) == 1 + assert isinstance(rows[0]["col"], pathlib.Path) + assert rows[0]["col"] == p + + +def test_builtin_upath_round_trip(storage_backend: _StorageBackend, tmp_path: Path) -> None: + """UPath round-trips through storage with extension name ``orcapod.upath``.""" + u = UPath("s3://my-bucket/data/file.parquet") + result, read_converter = _write_and_read( + {"col": UPath}, + [{"col": u}], + storage_backend, + tmp_path, + ) + + field = result.schema.field("col") + assert hasattr(field.type, "extension_name"), ( + f"Expected extension type on field 'col', got plain type {field.type!r}" + ) + assert field.type.extension_name == "orcapod.upath" + + rows = read_converter.arrow_table_to_python_dicts(result) + assert len(rows) == 1 + assert isinstance(rows[0]["col"], UPath) + assert str(rows[0]["col"]) == str(u) + + +def test_builtin_uuid_round_trip(storage_backend: _StorageBackend, tmp_path: Path) -> None: + """uuid.UUID round-trips through storage with extension name ``orcapod.uuid``.""" + u = uuid_module.UUID("12345678-1234-5678-1234-567812345678") + result, read_converter = _write_and_read( + {"col": uuid_module.UUID}, + [{"col": u}], + storage_backend, + tmp_path, + ) + + field = result.schema.field("col") + assert hasattr(field.type, "extension_name"), ( + f"Expected extension type on field 'col', got plain type {field.type!r}" + ) + assert field.type.extension_name == "orcapod.uuid" + + rows = read_converter.arrow_table_to_python_dicts(result) + assert len(rows) == 1 + assert isinstance(rows[0]["col"], uuid_module.UUID) + assert rows[0]["col"] == u + + +# ── Dataclass round-trip tests ──────────────────────────────────────────────── + + +def test_simple_dataclass_round_trip(storage_backend: _StorageBackend, tmp_path: Path) -> None: + """Simple dataclass round-trips with correct FQCN as the Arrow extension name. + + The read-side converter starts with no knowledge of _PointA. After read, + register_discovered_extensions triggers DataclassLogicalTypeFactory which + imports _PointA from its fully-qualified class name and registers it. + """ + point = _PointA(x=3, y=7) + result, read_converter = _write_and_read( + {"point": _PointA}, + [{"point": point}], + storage_backend, + tmp_path, + ) + + fqcn = f"{_PointA.__module__}.{_PointA.__qualname__}" + field = result.schema.field("point") + assert hasattr(field.type, "extension_name"), ( + f"Expected extension type on field 'point', got {field.type!r}" + ) + assert field.type.extension_name == fqcn + + rows = read_converter.arrow_table_to_python_dicts(result) + assert len(rows) == 1 + reconstructed = rows[0]["point"] + assert isinstance(reconstructed, _PointA) + assert reconstructed.x == 3 + assert reconstructed.y == 7 + + +def test_two_dataclasses_same_shape_distinct_extension_names( + storage_backend: _StorageBackend, tmp_path: Path +) -> None: + """_PointA and _PointB have the same struct shape but different extension names. + + Writing _PointA and reading it back must NOT reconstruct a _PointB, even + though their on-disk struct shapes (x: int, y: int) are identical. The + extension name (FQCN) is the sole identity signal. + """ + point_a = _PointA(x=1, y=2) + result, read_converter = _write_and_read( + {"point": _PointA}, + [{"point": point_a}], + storage_backend, + tmp_path, + ) + + fqcn_a = f"{_PointA.__module__}.{_PointA.__qualname__}" + fqcn_b = f"{_PointB.__module__}.{_PointB.__qualname__}" + + field = result.schema.field("point") + assert hasattr(field.type, "extension_name") + assert field.type.extension_name == fqcn_a + assert field.type.extension_name != fqcn_b # distinct from _PointB + + rows = read_converter.arrow_table_to_python_dicts(result) + reconstructed = rows[0]["point"] + assert isinstance(reconstructed, _PointA) + assert not isinstance(reconstructed, _PointB) + + +def test_nested_dataclass_round_trip(storage_backend: _StorageBackend, tmp_path: Path) -> None: + """Nested dataclass: _Outer and _Inner both registered; full object reconstructed. + + register_discovered_extensions triggers DataclassLogicalTypeFactory for _Outer. + That factory's reconstruct_from_arrow calls converter.register_python_class(_Inner) + as a side-effect, so _Inner is also registered without an explicit peek step. + """ + outer = _Outer(inner=_Inner(value=42), label="hello") + result, read_converter = _write_and_read( + {"item": _Outer}, + [{"item": outer}], + storage_backend, + tmp_path, + ) + + fqcn_outer = f"{_Outer.__module__}.{_Outer.__qualname__}" + fqcn_inner = f"{_Inner.__module__}.{_Inner.__qualname__}" + + assert read_converter._logical_type_registry.get_by_arrow_extension_name(fqcn_outer) is not None, ( + "_Outer should be registered after read" + ) + assert read_converter._logical_type_registry.get_by_arrow_extension_name(fqcn_inner) is not None, ( + "_Inner should be registered transitively after read" + ) + + rows = read_converter.arrow_table_to_python_dicts(result) + assert len(rows) == 1 + reconstructed = rows[0]["item"] + assert isinstance(reconstructed, _Outer) + assert isinstance(reconstructed.inner, _Inner) + assert reconstructed.inner.value == 42 + assert reconstructed.label == "hello" +``` + +- [ ] **Step 2: Run the tests and verify they pass** + +```bash +uv run pytest tests/test_extension_types/test_roundtrips.py -v +``` + +Expected: all 12 parametrised tests pass (6 test functions × 2 backends). + +- [ ] **Step 3: Commit** + +```bash +git add tests/test_extension_types/test_roundtrips.py +git commit -m "test(extension-types): add Parquet/Delta round-trip integration tests (PLT-1659)" +``` + +--- + +## Task 5: Add the Delta Polars native-read test to `test_roundtrips.py` + +**Files:** +- Modify: `tests/test_extension_types/test_roundtrips.py` (append one function) + +This test reads a Delta table back via `pl.read_delta` (Polars' native Delta reader) rather than `DeltaTable.to_pyarrow_table()`, verifying that extension type metadata survives the Polars path. + +When the write-side converter calls `register_python_class(_PointA)`, it registers `_PointA` in both PyArrow's and Polars' **global** registries (as a side-effect of `registry.register_logical_type`). That global registration persists for the duration of the test process, so `pl.read_delta` can resolve `_PointA`'s extension type when reading the underlying Parquet files. + +- [ ] **Step 1: Append the Delta Polars test to `test_roundtrips.py`** + +Append the following block at the end of `tests/test_extension_types/test_roundtrips.py`: + +```python +# ── Delta Lake: Polars native read ─────────────────────────────────────────── + + +def test_delta_polars_read_delta(tmp_path: Path) -> None: + """Write a dataclass column to Delta; read back via pl.read_delta; extension type survives. + + The write-side converter registers _PointA in both PyArrow's and Polars' + global registries (``register_python_class`` calls ``make_polars_extension_type`` + which registers with Polars). ``pl.read_delta`` can therefore decode the column + as the correct Polars extension type, not a plain ``Struct``. + + Note: ``pl.DataFrame.to_arrow()`` exports Polars extension types as PyArrow + extension arrays but with empty serialized bytes (Polars does not forward + ``__arrow_ext_metadata__`` through its Arrow export). Python-object + reconstruction via the Polars-to-Arrow path is therefore not possible; that + path is tested by the separate ``parquet`` / ``delta`` parametrised tests + which read underlying Parquet files directly. + """ + import deltalake + import polars as pl + + delta_path = str(tmp_path / "polars_delta") + fqcn = f"{_PointA.__module__}.{_PointA.__qualname__}" + + # Write — registers _PointA in PyArrow + Polars global registries. + write_converter = _fresh_converter() + write_converter.register_python_class(_PointA) + arrow_schema = write_converter.python_schema_to_arrow_schema({"point": _PointA}) + rows = [{"point": _PointA(x=5, y=9)}] + table = write_converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + deltalake.write_deltalake(delta_path, table) + + # Read via Polars native Delta reader. + # _PointA is already in the Polars global registry from the write step above. + df = pl.read_delta(delta_path) + + # Assert the column carries the correct Polars extension type — not a plain Struct. + col_dtype = df.dtypes[0] + assert col_dtype.is_extension(), ( + f"Expected a Polars extension type on column 'point', got {col_dtype!r}" + ) + assert col_dtype.ext_name() == fqcn, ( + f"Expected extension name {fqcn!r}, got {col_dtype.ext_name()!r}" + ) +``` + +- [ ] **Step 2: Run the new test to verify it passes** + +```bash +uv run pytest tests/test_extension_types/test_roundtrips.py::test_delta_polars_read_delta -v +``` + +Expected: 1 test passes. + +- [ ] **Step 3: Run the full roundtrips file to confirm no regressions** + +```bash +uv run pytest tests/test_extension_types/test_roundtrips.py -v +``` + +Expected: 13 tests pass (12 from Task 4 + 1 new). + +- [ ] **Step 4: Commit** + +```bash +git add tests/test_extension_types/test_roundtrips.py +git commit -m "test(extension-types): add Delta Polars native-read round-trip test (PLT-1659)" +``` + +--- + +## Task 6: Full test run and PR + +**Files:** none + +- [ ] **Step 1: Run the full extension-types test suite** + +```bash +uv run pytest tests/test_extension_types/ -v +``` + +Expected: all tests pass. The three new files contribute 17 tests: +- `test_schema_compatibility.py`: 4 tests +- `test_cache_behavior.py`: 2 tests +- `test_roundtrips.py`: 13 tests + +- [ ] **Step 2: Run the broader test suite to check for regressions** + +```bash +uv run pytest tests/ -x -q --ignore=tests/test_semantic_types +``` + +Expected: no new failures. (`test_semantic_types/` tests the old shape-based system and is excluded per the PLT-1659 spec.) + +- [ ] **Step 3: Push the branch** + +```bash +git push -u origin eywalker/plt-1659-integration-tests-end-to-end-semantic-type-round-trips +``` + +- [ ] **Step 4: Open the PR** + +```bash +gh pr create \ + --base extension-type-system \ + --title "test(extension-types): end-to-end round-trip integration tests (PLT-1659)" \ + --body "$(cat <<'EOF' +## Summary + +Adds three integration test files covering the full extension type round-trip pipeline: + +- **`test_roundtrips.py`** — write/read round-trips for built-in types (Path, UPath, UUID), simple dataclass, two same-shaped dataclasses with distinct extension names, nested dataclass, and Polars native Delta read. Parameterised over Parquet and Delta backends. +- **`test_schema_compatibility.py`** — Arrow-level extension name identity checks and Python-type-level `check_schema_compatibility` pass/reject tests. +- **`test_cache_behavior.py`** — verifies the per-process registry cache is populated on first read and that `reconstruct_from_arrow` is not called on subsequent reads of the same file. + +## Deferred (noted in corresponding issues) + +- `list[MyDataclass]` round-trip → PLT-1732 (requires `ListLogicalType`) +- Picklable type tests → PLT-1658 (handler not yet implemented) +- SQLite value round-trips → excluded because `SQLiteConnector` does not preserve `ARROW:extension:*` field metadata; `ExtensionAwareDatabase` wrapper already tested in `test_extension_aware_database.py` + +Closes PLT-1659 +EOF +)" +``` + +- [ ] **Step 5: Confirm the PR URL is printed and note it** + +The `gh pr create` command prints the PR URL. Record it for tracking. diff --git a/superpowers/plans/2026-06-24-plt-1660-hard-cut-extension-type-hashing.md b/superpowers/plans/2026-06-24-plt-1660-hard-cut-extension-type-hashing.md new file mode 100644 index 00000000..a4642fdb --- /dev/null +++ b/superpowers/plans/2026-06-24-plt-1660-hard-cut-extension-type-hashing.md @@ -0,0 +1,2466 @@ +# PLT-1660: Hard Cut Extension Type Hashing — Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use sensei:subagent-driven-development (recommended) or sensei:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Delete the old shape-based `SemanticTypeRegistry` system, wire the new extension-type system into Arrow hashing, and rename all protocol/registry/handler classes to cleaner names. + +**Architecture:** `ArrowTypeDataVisitor` gains a `visit_extension()` hook (default: passthrough). `SemanticHashingVisitor` overrides it: for extension types whose Python counterpart has a registered semantic hasher, it converts the value to a Python object, hashes it, and stores the result as `pa.large_binary()` in the format `:::`. Unrecognized extension types pass through unmodified — starfix still sees their full metadata. All `TypeHandlerProtocol.handle()->Any` handlers are tightened to `PythonTypeSemanticHasherProtocol.hash()->ContentHash`. + +**Tech Stack:** Python 3.10+, PyArrow extension types, starfix-python, uv/pytest + +--- + +## File Map + +**Modified source:** +- `src/orcapod/protocols/hashing_protocols.py` — rename `TypeHandlerProtocol`→`PythonTypeSemanticHasherProtocol`, `handle()`→`hash()->ContentHash`; rename `type_handler_registry`→`type_semantic_hasher_registry` on `SemanticHasherProtocol` +- `src/orcapod/hashing/semantic_hashing/type_handler_registry.py` — rename class + all methods +- `src/orcapod/hashing/semantic_hashing/builtin_handlers.py` — rename 11 handler classes; `handle()`→`hash()->ContentHash`; rename `register_builtin_handlers` +- `src/orcapod/hashing/semantic_hashing/semantic_hasher.py` — rename `BaseSemanticHasher`→`SemanticAwarePythonHasher`; simplify dispatch; rename property +- `src/orcapod/hashing/semantic_hashing/content_identifiable_mixin.py` — update import + type annotations +- `src/orcapod/hashing/semantic_hashing/__init__.py` — update exports +- `src/orcapod/hashing/__init__.py` — update exports +- `src/orcapod/hashing/defaults.py` — rename function; update property access; remove broken `set_cacher` call +- `src/orcapod/hashing/visitors.py` — add `visit_extension` to base class + rewrite `SemanticHashingVisitor` +- `src/orcapod/hashing/arrow_hashers.py` — update `StarfixArrowHasher` constructor + short-circuit; delete `SemanticArrowHasher` +- `src/orcapod/hashing/versioned_hashers.py` — source `StarfixArrowHasher` from context; rename imports +- `src/orcapod/contexts/data/v0.1.json` — reorder components; remove `semantic_registry`; update class names and refs; add `type_converter`+`semantic_hasher` to `arrow_hasher`; remove `pa.Table` handlers (cycle-break) +- `src/orcapod/contexts/data/schemas/context_schema.json` — remove `semantic_registry` property; rename `type_handler_registry`→`python_type_semantic_hasher_registry` +- `src/orcapod/contexts/core.py` — update docstring for renamed property +- `src/orcapod/semantic_types/__init__.py` — remove `SemanticTypeRegistry` export +- `src/orcapod/protocols/semantic_types_protocols.py` — delete `SemanticStructConverterProtocol` + +**Deleted source:** +- `src/orcapod/semantic_types/semantic_struct_converters.py` +- `src/orcapod/semantic_types/semantic_registry.py` + +**Deleted tests:** +- `tests/test_semantic_types/` (all 9 files) +- `tests/test_hashing/test_file_hashing_consistency.py` + +**New tests:** +- `tests/test_hashing/test_extension_type_hashing.py` + +**Updated tests:** +- `tests/test_hashing/test_semantic_hasher.py` +- `tests/test_hashing/test_starfix_arrow_hasher.py` + +--- + +## Task 1: Rename `TypeHandlerProtocol` → `PythonTypeSemanticHasherProtocol` + +**Files:** +- Modify: `src/orcapod/protocols/hashing_protocols.py` + +- [ ] **Step 1: Rewrite the protocol class and update surrounding references** + +Replace the entire `TypeHandlerProtocol` class and update the `SemanticHasherProtocol`'s `type_handler_registry` property: + +```python +# In src/orcapod/protocols/hashing_protocols.py + +# Update TYPE_CHECKING import: +if TYPE_CHECKING: + import pyarrow as pa + from orcapod.hashing.semantic_hashing.type_handler_registry import PythonTypeSemanticHasherRegistry + from orcapod.types import ContentHash # already imported at module level, just noting + +# Replace TypeHandlerProtocol with: +class PythonTypeSemanticHasherProtocol(Protocol): + """Protocol for type-specific semantic hashers used by SemanticAwarePythonHasher. + + A PythonTypeSemanticHasherProtocol hashes a specific Python type to a ``ContentHash``. + Implementations are registered with a ``PythonTypeSemanticHasherRegistry`` and looked + up via MRO-aware resolution. + + Each implementation receives the full ``SemanticAwarePythonHasher`` so it can delegate + hashing of sub-values (e.g. hashing a dict of function metadata) back to the outer + hasher without coupling to a specific hasher instance. + """ + + def hash(self, obj: Any, hasher: "SemanticAwarePythonHasher") -> ContentHash: + """Hash *obj* to a ContentHash. + + Args: + obj: The object to hash. Always matches the registered type. + hasher: The active ``SemanticAwarePythonHasher``. Use + ``hasher.hash_object(sub_value)`` to hash sub-values. + + Returns: + ContentHash: The content-addressed hash of *obj*. + """ + ... + + +# Update SemanticHasherProtocol — rename the property: +class SemanticHasherProtocol(Protocol): + # ... existing methods unchanged ... + + @property + def type_semantic_hasher_registry(self) -> "PythonTypeSemanticHasherRegistry": + """Return the PythonTypeSemanticHasherRegistry used by this hasher.""" + ... +``` + +The full updated `hashing_protocols.py` (only `TypeHandlerProtocol` is renamed and `SemanticHasherProtocol.type_handler_registry` → `type_semantic_hasher_registry`; everything else is unchanged): + +```python +"""Hash strategy protocols for dependency injection.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +from orcapod.types import ContentHash, PathLike, Schema + +if TYPE_CHECKING: + import pyarrow as pa + from orcapod.hashing.semantic_hashing.type_handler_registry import PythonTypeSemanticHasherRegistry + + +@runtime_checkable +class DataContextAwareProtocol(Protocol): + """Protocol for objects aware of their data context.""" + + @property + def data_context_key(self) -> str: + """Return the data context key associated with this object.""" + ... + + +@runtime_checkable +class PipelineElementProtocol(Protocol): + """Protocol for objects that have a stable identity as an element in a pipeline graph.""" + + def pipeline_identity_structure(self) -> Any: + """Return a structure representing this element's pipeline identity.""" + ... + + def pipeline_hash(self, hasher=None) -> ContentHash: + """Return the pipeline-level hash of this element.""" + ... + + +@runtime_checkable +class ContentIdentifiableProtocol(Protocol): + """Protocol for objects that can express their semantic identity as a plain Python structure.""" + + def identity_structure(self) -> Any: + """Return a structure that represents the semantic identity of this object.""" + ... + + def content_hash(self, hasher: "SemanticHasherProtocol | None" = None) -> ContentHash: + """Returns the content hash.""" + ... + + +class PythonTypeSemanticHasherProtocol(Protocol): + """Protocol for type-specific semantic hashers used by SemanticAwarePythonHasher. + + A ``PythonTypeSemanticHasherProtocol`` hashes a specific Python type to a + ``ContentHash``. Implementations are registered with a + ``PythonTypeSemanticHasherRegistry`` and looked up via MRO-aware resolution. + + Each implementation receives the full ``SemanticAwarePythonHasher`` so it can + delegate hashing of sub-values back to the outer hasher without coupling to a + specific hasher instance. + """ + + def hash(self, obj: Any, hasher: "SemanticAwarePythonHasher") -> ContentHash: + """Hash *obj* to a ContentHash. + + Args: + obj: The object to hash. Always matches the registered type. + hasher: The active ``SemanticAwarePythonHasher``. Use + ``hasher.hash_object(sub_value)`` to hash sub-values. + + Returns: + ContentHash: The content-addressed hash of *obj*. + """ + ... + + +class SemanticHasherProtocol(Protocol): + """Protocol for the semantic content-based hasher.""" + + def hash_object( + self, + obj: Any, + resolver: Callable[[Any], ContentHash] | None = None, + ) -> ContentHash: + """Hash *obj* based on its semantic content.""" + ... + + @property + def hasher_id(self) -> str: + """Returns a unique identifier/name for this hasher instance.""" + ... + + @property + def type_semantic_hasher_registry(self) -> "PythonTypeSemanticHasherRegistry": + """Return the PythonTypeSemanticHasherRegistry used by this hasher.""" + ... + + +class FileContentHasherProtocol(Protocol): + """Protocol for file-related hashing.""" + + def hash_file(self, file_path: PathLike) -> ContentHash: ... + + +@runtime_checkable +class ArrowHasherProtocol(Protocol): + """Protocol for hashing arrow data.""" + + @property + def hasher_id(self) -> str: ... + + def hash_table(self, table: "pa.Table | pa.RecordBatch") -> ContentHash: ... + + +class StringCacherProtocol(Protocol): + """Protocol for caching string key value pairs.""" + + def get_cached(self, cache_key: str) -> str | None: ... + def set_cached(self, cache_key: str, value: str) -> None: ... + def clear_cache(self) -> None: ... + + +class FunctionInfoExtractorProtocol(Protocol): + """Protocol for extracting function information.""" + + def extract_function_info( + self, + func: Callable[..., Any], + function_name: str | None = None, + input_typespec: Schema | None = None, + output_typespec: Schema | None = None, + exclude_function_signature: bool = False, + exclude_function_body: bool = False, + ) -> dict[str, Any]: ... + + +class SemanticTypeHasherProtocol(Protocol): + """Abstract base class for semantic type-specific hashers.""" + + @property + def hasher_id(self) -> str: + """Unique identifier for this semantic type hasher.""" + ... + + def hash_column(self, column: "pa.Array") -> "pa.Array": + """Hash a column with this semantic type and return the hash bytes as an array.""" + ... + + def set_cacher(self, cacher: StringCacherProtocol) -> None: + """Add a string cacher for caching hash values.""" + ... +``` + +- [ ] **Step 2: Commit** + +```bash +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + add src/orcapod/protocols/hashing_protocols.py +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + commit -m "refactor(hashing_protocols): rename TypeHandlerProtocol → PythonTypeSemanticHasherProtocol, tighten hash() → ContentHash" +``` + +--- + +## Task 2: Rename `TypeHandlerRegistry` → `PythonTypeSemanticHasherRegistry` + +**Files:** +- Modify: `src/orcapod/hashing/semantic_hashing/type_handler_registry.py` + +- [ ] **Step 1: Rename the class, subclass, and all methods** + +Write the complete new file: + +```python +""" +PythonTypeSemanticHasherRegistry — MRO-aware registry for PythonTypeSemanticHasherProtocol instances. +""" + +from __future__ import annotations + +import logging +import threading +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from orcapod.protocols.hashing_protocols import ( + ArrowHasherProtocol, + PythonTypeSemanticHasherProtocol, + ) + +logger = logging.getLogger(__name__) + + +class PythonTypeSemanticHasherRegistry: + """Registry mapping Python types to PythonTypeSemanticHasherProtocol instances. + + Lookup is MRO-aware: when no hasher is registered for the exact type of + an object, the registry walks the object's MRO (most-derived first) until + it finds a match. + + Thread safety + ------------- + Registration and lookup are protected by a reentrant lock so that the + global singleton can be safely used from multiple threads. + """ + + def __init__( + self, handlers: list[tuple[type, "PythonTypeSemanticHasherProtocol"]] | None = None + ) -> None: + """ + Args: + handlers: Optional list of ``(target_type, hasher)`` pairs to + register at construction time. + """ + self._handlers: dict[type, "PythonTypeSemanticHasherProtocol"] = {} + self._lock = threading.RLock() + if handlers: + for target_type, handler in handlers: + self.register(target_type, handler) + + def register(self, target_type: type, handler: "PythonTypeSemanticHasherProtocol") -> None: + """Register a hasher for a specific Python type. + + If a hasher is already registered for *target_type*, it is silently + replaced by the new hasher. + + Args: + target_type: The Python type (or class) for which the hasher should be used. + handler: A ``PythonTypeSemanticHasherProtocol`` instance. + + Raises: + TypeError: If ``target_type`` is not a ``type``. + """ + if not isinstance(target_type, type): + raise TypeError( + f"target_type must be a type/class, got {type(target_type)!r}" + ) + with self._lock: + existing = self._handlers.get(target_type) + if existing is not None and existing is not handler: + logger.debug( + "PythonTypeSemanticHasherRegistry: replacing existing hasher for %s (%s -> %s)", + target_type.__name__, + type(existing).__name__, + type(handler).__name__, + ) + self._handlers[target_type] = handler + + def unregister(self, target_type: type) -> bool: + """Remove the hasher registered for *target_type*, if any. + + Args: + target_type: The type whose hasher should be removed. + + Returns: + True if a hasher was removed, False if none was registered. + """ + with self._lock: + if target_type in self._handlers: + del self._handlers[target_type] + return True + return False + + def get_semantic_hasher(self, obj: Any) -> "PythonTypeSemanticHasherProtocol | None": + """Look up the hasher for *obj* using MRO-aware resolution. + + Args: + obj: The object for which a hasher is needed. + + Returns: + The registered ``PythonTypeSemanticHasherProtocol``, or None. + """ + obj_type = type(obj) + with self._lock: + handler = self._handlers.get(obj_type) + if handler is not None: + return handler + for base in obj_type.__mro__[1:]: + handler = self._handlers.get(base) + if handler is not None: + logger.debug( + "PythonTypeSemanticHasherRegistry: resolved hasher for %s via base %s", + obj_type.__name__, + base.__name__, + ) + return handler + return None + + def get_semantic_hasher_for_type( + self, target_type: type + ) -> "PythonTypeSemanticHasherProtocol | None": + """Look up the hasher for a *type object* (rather than an instance). + + Args: + target_type: The type to look up. + + Returns: + The registered ``PythonTypeSemanticHasherProtocol``, or None. + """ + with self._lock: + handler = self._handlers.get(target_type) + if handler is not None: + return handler + for base in target_type.__mro__[1:]: + handler = self._handlers.get(base) + if handler is not None: + return handler + return None + + def has_semantic_hasher(self, target_type: type) -> bool: + """Return True if a hasher is registered for *target_type* or any MRO ancestor. + + Args: + target_type: The type to check. + """ + return self.get_semantic_hasher_for_type(target_type) is not None + + def registered_types(self) -> list[type]: + """Return a list of all directly-registered types (no MRO expansion).""" + with self._lock: + return list(self._handlers.keys()) + + def __repr__(self) -> str: + with self._lock: + names = [t.__name__ for t in self._handlers] + return f"PythonTypeSemanticHasherRegistry(registered={names!r})" + + def __len__(self) -> int: + with self._lock: + return len(self._handlers) + + +def get_default_python_type_semantic_hasher_registry() -> "PythonTypeSemanticHasherRegistry": + """Return the PythonTypeSemanticHasherRegistry from the default data context. + + This is a convenience wrapper; the registry is owned and versioned by the + active ``DataContext``. Importing this function from + ``orcapod.hashing.defaults`` or ``orcapod.hashing`` is equivalent. + """ + from orcapod.hashing.defaults import ( + get_default_python_type_semantic_hasher_registry as _get, + ) + return _get() + + +class BuiltinPythonTypeSemanticHasherRegistry(PythonTypeSemanticHasherRegistry): + """A PythonTypeSemanticHasherRegistry pre-populated with all built-in hashers. + + Constructed via the data context JSON spec so that the default registry + is versioned alongside the rest of the context components. + """ + + def __init__(self, arrow_hasher: "ArrowHasherProtocol | None" = None) -> None: + super().__init__() + from orcapod.hashing.semantic_hashing.builtin_handlers import ( + register_builtin_python_type_semantic_hashers, + ) + register_builtin_python_type_semantic_hashers(self, arrow_hasher=arrow_hasher) +``` + +- [ ] **Step 2: Commit** + +```bash +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + add src/orcapod/hashing/semantic_hashing/type_handler_registry.py +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + commit -m "refactor(type_handler_registry): rename to PythonTypeSemanticHasherRegistry, rename methods" +``` + +--- + +## Task 3: Rename + tighten all builtin handlers + +**Files:** +- Modify: `src/orcapod/hashing/semantic_hashing/builtin_handlers.py` + +- [ ] **Step 1: Write the complete updated file** + +Key changes: +- 11 class renames (all `*Handler`/`*ContentHandler` → `*SemanticHasher`) +- `handle(obj, hasher) -> Any` → `hash(obj, hasher) -> ContentHash` on every class +- `UUIDSemanticHasher`, `BytesSemanticHasher`, `FunctionSemanticHasher`, `TypeObjectSemanticHasher`, `SpecialFormSemanticHasher`, `GenericAliasSemanticHasher`, `UnionTypeSemanticHasher` now call `hasher.hash_object(...)` to return `ContentHash` directly +- `register_builtin_handlers` → `register_builtin_python_type_semantic_hashers` +- Remove `SemanticArrowHasher` fallback construction (it will be deleted); when `arrow_hasher is None`, skip registering `pa.Table`/`pa.RecordBatch` handlers + +```python +""" +Built-in PythonTypeSemanticHasherProtocol implementations. + + PathSemanticHasher -- pathlib.Path: file content hash + UPathSemanticHasher -- upath.UPath: file content hash (remote-aware) + UUIDSemanticHasher -- uuid.UUID: 16-byte binary representation + BytesSemanticHasher -- bytes/bytearray: hex string representation + FunctionSemanticHasher -- callable with __code__: via FunctionInfoExtractorProtocol + TypeObjectSemanticHasher -- type objects: stable "type:." string + SpecialFormSemanticHasher -- typing._SpecialForm + GenericAliasSemanticHasher -- generic alias type annotations + UnionTypeSemanticHasher -- types.UnionType (Python 3.10+ X | Y syntax) + ArrowTableSemanticHasher -- pa.Table / pa.RecordBatch + SchemaSemanticHasher -- Schema objects + +``register_builtin_python_type_semantic_hashers(registry)`` populates a registry +with all of the above. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any +from uuid import UUID + +from upath import UPath + +from orcapod.types import ContentHash, PathLike, Schema + +if TYPE_CHECKING: + from orcapod.hashing.semantic_hashing.type_handler_registry import ( + PythonTypeSemanticHasherRegistry, + ) + from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher + from orcapod.protocols.hashing_protocols import ( + ArrowHasherProtocol, + FileContentHasherProtocol, + ) + +logger = logging.getLogger(__name__) + + +class PathSemanticHasher: + """Hasher for pathlib.Path objects — hashes file *content*. + + Args: + file_hasher: Any object with a ``hash_file(path) -> ContentHash`` method. + """ + + def __init__(self, file_hasher: "FileContentHasherProtocol") -> None: + self.file_hasher = file_hasher + + def hash(self, obj: PathLike, hasher: "SemanticAwarePythonHasher") -> ContentHash: + path: Path = Path(obj) + if not path.exists(): + raise FileNotFoundError( + f"PathSemanticHasher: path does not exist: {path!r}. " + "Paths must refer to existing files for content-based hashing." + ) + if path.is_dir(): + raise IsADirectoryError( + f"PathSemanticHasher: path is a directory: {path!r}. " + "Only regular files are supported for content-based hashing." + ) + logger.debug("PathSemanticHasher: hashing file content at %s", path) + return self.file_hasher.hash_file(path) + + +class UPathSemanticHasher: + """Hasher for universal_pathlib.UPath objects — hashes file content. + + Args: + file_hasher: Any object with a ``hash_file(path) -> ContentHash`` method. + """ + + def __init__(self, file_hasher: "FileContentHasherProtocol") -> None: + self.file_hasher = file_hasher + + def hash(self, obj: Any, hasher: "SemanticAwarePythonHasher") -> ContentHash: + if not isinstance(obj, UPath): + raise TypeError( + f"UPathSemanticHasher: expected a UPath, got {type(obj)!r}." + ) + if not obj.exists(): + raise FileNotFoundError( + f"UPathSemanticHasher: path does not exist: {obj!r}." + ) + if obj.is_dir(): + raise IsADirectoryError( + f"UPathSemanticHasher: path is a directory: {obj!r}." + ) + logger.debug("UPathSemanticHasher: hashing file content at %s", obj) + return self.file_hasher.hash_file(obj) + + +class UUIDSemanticHasher: + """Hasher for ``uuid.UUID`` objects — hashes the raw 16-byte binary representation.""" + + def hash(self, obj: Any, hasher: "SemanticAwarePythonHasher") -> ContentHash: + return hasher.hash_object(obj.bytes) + + +class BytesSemanticHasher: + """Hasher for bytes and bytearray objects — hashes the lowercase hex representation.""" + + def hash(self, obj: Any, hasher: "SemanticAwarePythonHasher") -> ContentHash: + if isinstance(obj, (bytes, bytearray)): + return hasher.hash_object(obj.hex()) + raise TypeError( + f"BytesSemanticHasher: expected bytes or bytearray, got {type(obj)!r}" + ) + + +class FunctionSemanticHasher: + """Hasher for Python functions/callables with a ``__code__`` attribute. + + Args: + function_info_extractor: Any object with an + ``extract_function_info(func) -> dict`` method. + """ + + def __init__(self, function_info_extractor: Any) -> None: + self.function_info_extractor = function_info_extractor + + def hash(self, obj: Any, hasher: "SemanticAwarePythonHasher") -> ContentHash: + if not (callable(obj) and hasattr(obj, "__code__")): + raise TypeError( + f"FunctionSemanticHasher: expected a callable with __code__, got {type(obj)!r}" + ) + func_name = getattr(obj, "__name__", repr(obj)) + logger.debug("FunctionSemanticHasher: extracting info for function %r", func_name) + info: dict[str, Any] = self.function_info_extractor.extract_function_info(obj) + return hasher.hash_object(info) + + +class TypeObjectSemanticHasher: + """Hasher for type objects (classes passed as values). + + Returns a stable string of the form ``"type:."``. + """ + + def hash(self, obj: Any, hasher: "SemanticAwarePythonHasher") -> ContentHash: + if not isinstance(obj, type): + raise TypeError( + f"TypeObjectSemanticHasher: expected a type/class, got {type(obj)!r}" + ) + module: str = obj.__module__ or "" + qualname: str = obj.__qualname__ + return hasher.hash_object(f"type:{module}.{qualname}") + + +class SpecialFormSemanticHasher: + """Hasher for ``typing._SpecialForm`` objects such as ``typing.Union``.""" + + def hash(self, obj: Any, hasher: "SemanticAwarePythonHasher") -> ContentHash: + name = getattr(obj, "_name", None) or repr(obj) + return hasher.hash_object(f"special_form:typing.{name}") + + +class GenericAliasSemanticHasher: + """Hasher for generic alias type annotations (``dict[int, str]``, ``Optional[X]``, etc.).""" + + def hash(self, obj: Any, hasher: "SemanticAwarePythonHasher") -> ContentHash: + import typing + + origin = getattr(obj, "__origin__", None) + args = getattr(obj, "__args__", None) or () + if origin is None: + return hasher.hash_object(f"generic_alias:{obj!r}") + if origin is typing.Union: + hashed_args = sorted(hasher.hash_object(arg).to_string() for arg in args) + return hasher.hash_object({"__type__": "union", "args": hashed_args}) + return hasher.hash_object({ + "__type__": "generic_alias", + "origin": hasher.hash_object(origin).to_string(), + "args": [hasher.hash_object(arg).to_string() for arg in args], + }) + + +class UnionTypeSemanticHasher: + """Hasher for ``types.UnionType`` objects (Python 3.10+ ``X | Y`` syntax).""" + + def hash(self, obj: Any, hasher: "SemanticAwarePythonHasher") -> ContentHash: + args = getattr(obj, "__args__", None) or () + hashed_args = sorted(hasher.hash_object(arg).to_string() for arg in args) + return hasher.hash_object({"__type__": "union", "args": hashed_args}) + + +class ArrowTableSemanticHasher: + """Hasher for ``pa.Table`` and ``pa.RecordBatch`` objects. + + Args: + arrow_hasher: Any object satisfying ``ArrowHasherProtocol``. + """ + + def __init__(self, arrow_hasher: "ArrowHasherProtocol") -> None: + self.arrow_hasher = arrow_hasher + + def hash(self, obj: Any, hasher: "SemanticAwarePythonHasher") -> ContentHash: + import pyarrow as _pa + + if isinstance(obj, _pa.RecordBatch): + obj = _pa.Table.from_batches([obj]) + if not isinstance(obj, _pa.Table): + raise TypeError( + f"ArrowTableSemanticHasher: expected pa.Table or pa.RecordBatch, got {type(obj)!r}" + ) + return self.arrow_hasher.hash_table(obj) + + +class SchemaSemanticHasher: + """Hasher for ``Schema`` objects.""" + + def hash(self, obj: Any, hasher: "SemanticAwarePythonHasher") -> ContentHash: + if not isinstance(obj, Schema): + raise TypeError( + f"SchemaSemanticHasher: expected a Schema, got {type(obj)!r}" + ) + raise NotImplementedError("SchemaSemanticHasher is not yet implemented.") + + +def register_builtin_python_type_semantic_hashers( + registry: "PythonTypeSemanticHasherRegistry", + file_hasher: Any = None, + function_info_extractor: Any = None, + arrow_hasher: "ArrowHasherProtocol | None" = None, +) -> None: + """Register all built-in semantic hashers into *registry*. + + When ``arrow_hasher`` is None, ``pa.Table`` and ``pa.RecordBatch`` handlers + are **not** registered (to avoid circular dependency in the JSON context + construction — the default context's ``python_type_semantic_hasher_registry`` + is built before ``arrow_hasher``). + + Args: + registry: The ``PythonTypeSemanticHasherRegistry`` to populate. + file_hasher: Optional ``FileContentHasherProtocol`` for path hashing. + Defaults to ``BasicFileHasher(sha256)``. + function_info_extractor: Optional ``FunctionInfoExtractorProtocol``. + Defaults to ``FunctionSignatureExtractor``. + arrow_hasher: Optional ``ArrowHasherProtocol`` for nested table hashing. + When None, Arrow table handlers are skipped. + """ + if file_hasher is None: + from orcapod.hashing.file_hashers import BasicFileHasher + file_hasher = BasicFileHasher(algorithm="sha256") + + if function_info_extractor is None: + from orcapod.hashing.semantic_hashing.function_info_extractors import ( + FunctionSignatureExtractor, + ) + function_info_extractor = FunctionSignatureExtractor( + include_module=True, + include_defaults=True, + ) + + bytes_hasher = BytesSemanticHasher() + registry.register(bytes, bytes_hasher) + registry.register(bytearray, bytes_hasher) + + registry.register(Path, PathSemanticHasher(file_hasher)) + registry.register(UPath, UPathSemanticHasher(file_hasher)) + registry.register(UUID, UUIDSemanticHasher()) + + import types as _types + + function_hasher = FunctionSemanticHasher(function_info_extractor) + registry.register(_types.FunctionType, function_hasher) + registry.register(_types.BuiltinFunctionType, function_hasher) + registry.register(_types.MethodType, function_hasher) + + registry.register(type, TypeObjectSemanticHasher()) + registry.register(_types.UnionType, UnionTypeSemanticHasher()) + + generic_alias_hasher = GenericAliasSemanticHasher() + registry.register(_types.GenericAlias, generic_alias_hasher) + try: + import typing as _typing + registry.register(_typing._GenericAlias, generic_alias_hasher) # type: ignore[attr-defined] + registry.register(_typing._SpecialForm, SpecialFormSemanticHasher()) # type: ignore[attr-defined] + except AttributeError: + pass + + registry.register(Schema, SchemaSemanticHasher()) + + if arrow_hasher is not None: + import pyarrow as _pa + arrow_table_hasher = ArrowTableSemanticHasher(arrow_hasher) + registry.register(_pa.Table, arrow_table_hasher) + registry.register(_pa.RecordBatch, arrow_table_hasher) + + logger.debug( + "register_builtin_python_type_semantic_hashers: registered %d hashers", + len(registry), + ) +``` + +- [ ] **Step 2: Commit** + +```bash +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + add src/orcapod/hashing/semantic_hashing/builtin_handlers.py +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + commit -m "refactor(builtin_handlers): rename handler classes, tighten hash() → ContentHash" +``` + +--- + +## Task 4: Rename `BaseSemanticHasher` → `SemanticAwarePythonHasher`, simplify dispatch + +**Files:** +- Modify: `src/orcapod/hashing/semantic_hashing/semantic_hasher.py` + +- [ ] **Step 1: Apply renames and simplify hash_object dispatch** + +Changes: +1. Class name `BaseSemanticHasher` → `SemanticAwarePythonHasher` +2. `__init__` parameter `type_handler_registry` → `type_semantic_hasher_registry` +3. `self._registry = get_default_type_handler_registry()` → `get_default_python_type_semantic_hasher_registry()` +4. `type_handler_registry` property → `type_semantic_hasher_registry` +5. Return type annotation `TypeHandlerRegistry` → `PythonTypeSemanticHasherRegistry` +6. `hash_object` dispatch: `get_handler` → `get_semantic_hasher`; remove double-wrap (handler now returns `ContentHash` directly) + +The dispatch block in `hash_object` changes from: +```python +handler = self._registry.get_handler(obj) +if handler is not None: + return self.hash_object(handler.handle(obj, self), resolver=resolver) +``` +to: +```python +semantic_hasher = self._registry.get_semantic_hasher(obj) +if semantic_hasher is not None: + return semantic_hasher.hash(obj, self) +``` + +Full updated file (only showing the changed parts — keep everything else identical): + +```python +# At top of file, update import: +from orcapod.hashing.semantic_hashing.type_handler_registry import PythonTypeSemanticHasherRegistry + +# Class rename: +class SemanticAwarePythonHasher: + """ + Content-based recursive hasher. + [same docstring, just update BaseSemanticHasher references to SemanticAwarePythonHasher] + """ + + def __init__( + self, + hasher_id: str, + type_semantic_hasher_registry: PythonTypeSemanticHasherRegistry | None = None, + strict: bool = True, + ) -> None: + self._hasher_id = hasher_id + self._strict = strict + + if type_semantic_hasher_registry is None: + from orcapod.hashing.defaults import get_default_python_type_semantic_hasher_registry + self._registry = get_default_python_type_semantic_hasher_registry() + else: + self._registry = type_semantic_hasher_registry + + @property + def hasher_id(self) -> str: + return self._hasher_id + + @property + def strict(self) -> bool: + return self._strict + + @property + def type_semantic_hasher_registry(self) -> PythonTypeSemanticHasherRegistry: + """Return the ``PythonTypeSemanticHasherRegistry`` used by this hasher.""" + return self._registry + + def hash_object(self, obj, resolver=None): + # ... keep all existing logic, EXCEPT replace the handler dispatch block: + + # Old: + # handler = self._registry.get_handler(obj) + # if handler is not None: + # return self.hash_object(handler.handle(obj, self), resolver=resolver) + + # New: + # semantic_hasher = self._registry.get_semantic_hasher(obj) + # if semantic_hasher is not None: + # return semantic_hasher.hash(obj, self) + ... +``` + +The complete updated `hash_object` method (copy the full existing body, changing only the handler dispatch): + +```python +def hash_object( + self, + obj: Any, + resolver: Callable[[Any], ContentHash] | None = None, +) -> ContentHash: + """Hash *obj* based on its semantic content.""" + # Terminal: already a hash -- return as-is. + if isinstance(obj, ContentHash): + return obj + + # Primitives: hash their direct JSON representation. + if isinstance(obj, (type(None), bool, int, float, str)): + return self._hash_to_content_hash(obj) + + # Structures: expand into a tagged tree, then hash the tree. + if _is_structure(obj): + expanded = self._expand_structure( + obj, _visited=frozenset(), resolver=resolver + ) + return self._hash_to_content_hash(expanded) + + # Semantic hasher dispatch: the hasher produces a ContentHash directly. + semantic_hasher = self._registry.get_semantic_hasher(obj) + if semantic_hasher is not None: + logger.debug( + "hash_object: dispatching %s to semantic hasher %s", + type(obj).__name__, + type(semantic_hasher).__name__, + ) + return semantic_hasher.hash(obj, self) + + # ContentIdentifiableProtocol: use resolver if provided, else content_hash(). + if isinstance(obj, hp.ContentIdentifiableProtocol): + if resolver is not None: + logger.debug( + "hash_object: resolving ContentIdentifiableProtocol %s via resolver", + type(obj).__name__, + ) + return resolver(obj) + else: + logger.debug( + "hash_object: using ContentIdentifiableProtocol %s's content_hash", + type(obj).__name__, + ) + return obj.content_hash() + + # Fallback for unhandled types. + fallback = self._handle_unknown(obj) + return self._hash_to_content_hash(fallback) +``` + +- [ ] **Step 2: Commit** + +```bash +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + add src/orcapod/hashing/semantic_hashing/semantic_hasher.py +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + commit -m "refactor(semantic_hasher): rename BaseSemanticHasher → SemanticAwarePythonHasher, simplify dispatch" +``` + +--- + +## Task 5: Update `content_identifiable_mixin.py` and `contexts/core.py` + +**Files:** +- Modify: `src/orcapod/hashing/semantic_hashing/content_identifiable_mixin.py` +- Modify: `src/orcapod/contexts/core.py` + +- [ ] **Step 1: Update `content_identifiable_mixin.py`** + +Three changes: +1. Line 68: `from orcapod.hashing.semantic_hashing.semantic_hasher import BaseSemanticHasher` → `SemanticAwarePythonHasher` +2. Line 97: parameter `semantic_hasher: BaseSemanticHasher | None` → `SemanticAwarePythonHasher | None` +3. Line 218 (approximately): `def _get_hasher(self) -> BaseSemanticHasher:` → `SemanticAwarePythonHasher` +4. Update the class docstring reference from `BaseSemanticHasher` to `SemanticAwarePythonHasher` + +```python +# Old line 68: +from orcapod.hashing.semantic_hashing.semantic_hasher import BaseSemanticHasher + +# New: +from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher +``` + +```python +# Old __init__ signature: +def __init__( + self, *, semantic_hasher: BaseSemanticHasher | None = None, **kwargs: Any +) -> None: + +# New: +def __init__( + self, *, semantic_hasher: SemanticAwarePythonHasher | None = None, **kwargs: Any +) -> None: +``` + +Also update the `_get_hasher` return type annotation and any docstring mentions of `BaseSemanticHasher`. + +- [ ] **Step 2: Update `contexts/core.py` docstring** + +Update the `DataContext` docstring — replace `semantic_hasher.type_handler_registry` with `semantic_hasher.type_semantic_hasher_registry`: + +```python +@dataclass +class DataContext: + """Data context containing all versioned components needed for data interpretation. + + Attributes: + context_key: Unique identifier (e.g., "std:v0.1:default") + version: Version string (e.g., "v0.1") + description: Human-readable description + type_converter: Type converter for Python ↔ Arrow conversion and registration. + arrow_hasher: Arrow table hasher for this context. + semantic_hasher: General semantic hasher for this context. The + ``PythonTypeSemanticHasherRegistry`` used for hashing is accessible via + ``semantic_hasher.type_semantic_hasher_registry``. + """ +``` + +- [ ] **Step 3: Commit** + +```bash +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + add src/orcapod/hashing/semantic_hashing/content_identifiable_mixin.py \ + src/orcapod/contexts/core.py +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + commit -m "refactor: update BaseSemanticHasher → SemanticAwarePythonHasher refs in mixin and core" +``` + +--- + +## Task 6: Update `__init__.py` exports and `defaults.py` + +**Files:** +- Modify: `src/orcapod/hashing/semantic_hashing/__init__.py` +- Modify: `src/orcapod/hashing/__init__.py` +- Modify: `src/orcapod/hashing/defaults.py` + +- [ ] **Step 1: Update `semantic_hashing/__init__.py`** + +```python +""" +orcapod.hashing.semantic_hashing +================================= + SemanticAwarePythonHasher -- content-based recursive object hasher + PythonTypeSemanticHasherRegistry -- MRO-aware registry mapping types → PythonTypeSemanticHasherProtocol + BuiltinPythonTypeSemanticHasherRegistry -- pre-populated registry with built-in hashers + ContentIdentifiableMixin -- convenience mixin for content-identifiable objects + +Built-in PythonTypeSemanticHasherProtocol implementations: + PathSemanticHasher -- pathlib.Path → file-content hash + UUIDSemanticHasher -- uuid.UUID → canonical bytes + BytesSemanticHasher -- bytes/bytearray → hex string + FunctionSemanticHasher -- callable → via FunctionInfoExtractorProtocol + TypeObjectSemanticHasher -- type objects → "type:." + register_builtin_python_type_semantic_hashers -- populate a registry with all of the above + +Function info extractors (used by FunctionSemanticHasher): + FunctionNameExtractor + FunctionSignatureExtractor + FunctionInfoExtractorFactory +""" + +from orcapod.hashing.semantic_hashing.builtin_handlers import ( + BytesSemanticHasher, + FunctionSemanticHasher, + PathSemanticHasher, + TypeObjectSemanticHasher, + UUIDSemanticHasher, + register_builtin_python_type_semantic_hashers, +) +from orcapod.hashing.semantic_hashing.content_identifiable_mixin import ( + ContentIdentifiableMixin, +) +from orcapod.hashing.semantic_hashing.function_info_extractors import ( + FunctionInfoExtractorFactory, + FunctionNameExtractor, + FunctionSignatureExtractor, +) +from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher +from orcapod.hashing.semantic_hashing.type_handler_registry import ( + BuiltinPythonTypeSemanticHasherRegistry, + PythonTypeSemanticHasherRegistry, +) + +__all__ = [ + "SemanticAwarePythonHasher", + "PythonTypeSemanticHasherRegistry", + "BuiltinPythonTypeSemanticHasherRegistry", + "ContentIdentifiableMixin", + "PathSemanticHasher", + "UUIDSemanticHasher", + "BytesSemanticHasher", + "FunctionSemanticHasher", + "TypeObjectSemanticHasher", + "register_builtin_python_type_semantic_hashers", + "FunctionNameExtractor", + "FunctionSignatureExtractor", + "FunctionInfoExtractorFactory", +] +``` + +- [ ] **Step 2: Update `hashing/__init__.py`** + +```python +""" +OrcaPod hashing package. + +Public API +---------- + SemanticAwarePythonHasher -- content-based recursive object hasher + SemanticHasherProtocol -- protocol for semantic hashers + PythonTypeSemanticHasherRegistry -- registry mapping types to PythonTypeSemanticHasherProtocol instances + get_default_semantic_hasher -- global default SemanticHasherProtocol factory + get_default_python_type_semantic_hasher_registry -- global default registry factory + ContentIdentifiableMixin -- convenience mixin for content-identifiable objects + +Built-in hashers (importable for custom registry setup): + PathSemanticHasher + UUIDSemanticHasher + BytesSemanticHasher + FunctionSemanticHasher + TypeObjectSemanticHasher + register_builtin_python_type_semantic_hashers + +Utility: + FileContentHasherProtocol + StringCacherProtocol + FunctionInfoExtractorProtocol + ArrowHasherProtocol +""" + +from orcapod.hashing.defaults import ( + get_default_arrow_hasher, + get_default_python_type_semantic_hasher_registry, + get_default_semantic_hasher, +) +from orcapod.hashing.file_hashers import BasicFileHasher, CachedFileHasher +from orcapod.hashing.hash_utils import hash_file +from orcapod.hashing.semantic_hashing.builtin_handlers import ( + BytesSemanticHasher, + FunctionSemanticHasher, + PathSemanticHasher, + TypeObjectSemanticHasher, + UUIDSemanticHasher, + register_builtin_python_type_semantic_hashers, +) +from orcapod.hashing.semantic_hashing.content_identifiable_mixin import ( + ContentIdentifiableMixin, +) +from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher +from orcapod.hashing.semantic_hashing.type_handler_registry import ( + BuiltinPythonTypeSemanticHasherRegistry, + PythonTypeSemanticHasherRegistry, +) +from orcapod.protocols.hashing_protocols import ( + ArrowHasherProtocol, + ContentIdentifiableProtocol, + FileContentHasherProtocol, + FunctionInfoExtractorProtocol, + PythonTypeSemanticHasherProtocol, + SemanticHasherProtocol, + SemanticTypeHasherProtocol, + StringCacherProtocol, +) + +try: + from orcapod.hashing.legacy_core import ( + HashableMixin, + function_content_hash, + get_function_signature, + hash_function, + hash_data, + hash_pathset, + hash_to_hex, + hash_to_int, + hash_to_uuid, + ) +except ImportError: + HashableMixin = None # type: ignore[assignment,misc] + function_content_hash = None # type: ignore[assignment] + get_function_signature = None # type: ignore[assignment] + hash_function = None # type: ignore[assignment] + hash_data = None # type: ignore[assignment] + hash_pathset = None # type: ignore[assignment] + hash_to_hex = None # type: ignore[assignment] + hash_to_int = None # type: ignore[assignment] + hash_to_uuid = None # type: ignore[assignment] + +__all__ = [ + "SemanticAwarePythonHasher", + "PythonTypeSemanticHasherRegistry", + "BuiltinPythonTypeSemanticHasherRegistry", + "get_default_python_type_semantic_hasher_registry", + "get_default_semantic_hasher", + "ContentIdentifiableMixin", + "PathSemanticHasher", + "UUIDSemanticHasher", + "BytesSemanticHasher", + "FunctionSemanticHasher", + "TypeObjectSemanticHasher", + "register_builtin_python_type_semantic_hashers", + "SemanticHasherProtocol", + "ContentIdentifiableProtocol", + "PythonTypeSemanticHasherProtocol", + "FileContentHasherProtocol", + "ArrowHasherProtocol", + "StringCacherProtocol", + "FunctionInfoExtractorProtocol", + "SemanticTypeHasherProtocol", + "BasicFileHasher", + "CachedFileHasher", + "hash_file", + "get_default_arrow_hasher", + "HashableMixin", + "hash_to_hex", + "hash_to_int", + "hash_to_uuid", + "hash_function", + "get_function_signature", + "function_content_hash", + "hash_pathset", + "hash_data", +] +``` + +- [ ] **Step 3: Update `hashing/defaults.py`** + +```python +# Default hasher accessors for the OrcaPod hashing system. + +from orcapod.hashing.semantic_hashing.type_handler_registry import PythonTypeSemanticHasherRegistry +from orcapod.protocols import hashing_protocols as hp + + +def get_default_python_type_semantic_hasher_registry() -> PythonTypeSemanticHasherRegistry: + """Return the PythonTypeSemanticHasherRegistry from the default data context's semantic hasher. + + Returns: + PythonTypeSemanticHasherRegistry: The registry from the default data context. + """ + from orcapod.contexts import get_default_context + return get_default_context().semantic_hasher.type_semantic_hasher_registry + + +def get_default_semantic_hasher() -> hp.SemanticHasherProtocol: + """Return the SemanticHasherProtocol from the default data context.""" + from orcapod.contexts import get_default_context + return get_default_context().semantic_hasher + + +def get_default_arrow_hasher() -> hp.ArrowHasherProtocol: + """Return the ArrowHasherProtocol from the default data context. + + Note: file-hash caching (formerly via ``set_cacher``) has been removed. + ``StarfixArrowHasher`` does not support per-path caching. Use + ``CachedFileHasher`` when constructing a custom context if caching is needed. + """ + from orcapod.contexts import get_default_context + return get_default_context().arrow_hasher +``` + +- [ ] **Step 4: Commit** + +```bash +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + add src/orcapod/hashing/semantic_hashing/__init__.py \ + src/orcapod/hashing/__init__.py \ + src/orcapod/hashing/defaults.py +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + commit -m "refactor(hashing): update __init__.py exports and defaults for rename" +``` + +--- + +## Task 7: Update `test_semantic_hasher.py` → run tests + +**Files:** +- Modify: `tests/test_hashing/test_semantic_hasher.py` + +- [ ] **Step 1: Update imports at the top of the file** + +```python +# Old: +from orcapod.hashing.semantic_hashing.builtin_handlers import register_builtin_handlers +from orcapod.hashing.semantic_hashing.semantic_hasher import ( + BaseSemanticHasher, + _is_namedtuple, +) +from orcapod.hashing.semantic_hashing.type_handler_registry import ( + TypeHandlerRegistry, + get_default_type_handler_registry, +) + +# New: +from orcapod.hashing.semantic_hashing.builtin_handlers import ( + register_builtin_python_type_semantic_hashers, +) +from orcapod.hashing.semantic_hashing.semantic_hasher import ( + SemanticAwarePythonHasher, + _is_namedtuple, +) +from orcapod.hashing.semantic_hashing.type_handler_registry import ( + PythonTypeSemanticHasherRegistry, + get_default_python_type_semantic_hasher_registry, +) +``` + +- [ ] **Step 2: Update `make_hasher()` fixture and type annotations** + +```python +def make_hasher(strict: bool = True) -> SemanticAwarePythonHasher: + """Create a fresh SemanticAwarePythonHasher with an isolated registry.""" + registry = PythonTypeSemanticHasherRegistry() + register_builtin_python_type_semantic_hashers(registry) + return SemanticAwarePythonHasher( + hasher_id="test_v1", type_semantic_hasher_registry=registry, strict=strict + ) + + +@pytest.fixture +def hasher() -> SemanticAwarePythonHasher: + return make_hasher(strict=True) + + +@pytest.fixture +def lenient_hasher() -> SemanticAwarePythonHasher: + return make_hasher(strict=False) +``` + +- [ ] **Step 3: Update `_DummyHandler` in `TestTypeHandlerRegistry` (near line 827)** + +```python +# Old: +class _DummyHandler: + def __init__(self, tag: str) -> None: + self.tag = tag + + def handle(self, obj: Any, hasher: Any) -> Any: + return f"{self.tag}:{obj}" + +# New: +class _DummySemanticHasher: + def __init__(self, tag: str) -> None: + self.tag = tag + + def hash(self, obj: Any, hasher: Any) -> Any: + # Returns a ContentHash by delegating to the outer hasher + return hasher.hash_object(f"{self.tag}:{obj}") +``` + +- [ ] **Step 4: Update `TestTypeHandlerRegistry` class — rename class, method calls, and dummy handler** + +Rename the test class to `TestPythonTypeSemanticHasherRegistry` and update every reference: +- `TypeHandlerRegistry()` → `PythonTypeSemanticHasherRegistry()` +- `_DummyHandler(...)` → `_DummySemanticHasher(...)` +- `reg.get_handler(...)` → `reg.get_semantic_hasher(...)` +- `reg.has_handler(...)` → `reg.has_semantic_hasher(...)` +- `reg.get_handler_for_type(...)` → `reg.get_semantic_hasher_for_type(...)` + +Example of updated test methods: +```python +class TestPythonTypeSemanticHasherRegistry: + def test_register_and_get_exact(self): + reg = PythonTypeSemanticHasherRegistry() + h = _DummySemanticHasher("base") + reg.register(Base, h) + assert reg.get_semantic_hasher(Base()) is h + + def test_mro_lookup_child(self): + reg = PythonTypeSemanticHasherRegistry() + h = _DummySemanticHasher("base") + reg.register(Base, h) + assert reg.get_semantic_hasher(Child()) is h + + def test_mro_lookup_grandchild(self): + reg = PythonTypeSemanticHasherRegistry() + h = _DummySemanticHasher("base") + reg.register(Base, h) + assert reg.get_semantic_hasher(GrandChild()) is h + + def test_more_specific_handler_wins(self): + reg = PythonTypeSemanticHasherRegistry() + h_base = _DummySemanticHasher("base") + h_child = _DummySemanticHasher("child") + reg.register(Base, h_base) + reg.register(Child, h_child) + assert reg.get_semantic_hasher(Child()) is h_child + assert reg.get_semantic_hasher(GrandChild()) is h_child + + def test_unregistered_returns_none(self): + reg = PythonTypeSemanticHasherRegistry() + assert reg.get_semantic_hasher(Base()) is None + + def test_unregister_removes_handler(self): + reg = PythonTypeSemanticHasherRegistry() + h = _DummySemanticHasher("base") + reg.register(Base, h) + assert reg.unregister(Base) is True + assert reg.get_semantic_hasher(Base()) is None + + def test_unregister_nonexistent_returns_false(self): + reg = PythonTypeSemanticHasherRegistry() + assert reg.unregister(Base) is False + + def test_replace_existing_handler(self): + reg = PythonTypeSemanticHasherRegistry() + h1 = _DummySemanticHasher("first") + h2 = _DummySemanticHasher("second") + reg.register(Base, h1) + reg.register(Base, h2) + assert reg.get_semantic_hasher(Base()) is h2 + + def test_register_non_type_raises(self): + reg = PythonTypeSemanticHasherRegistry() + with pytest.raises(TypeError): + reg.register("not_a_type", _DummySemanticHasher("x")) # type: ignore[arg-type] + + def test_has_semantic_hasher_exact(self): + reg = PythonTypeSemanticHasherRegistry() + reg.register(Base, _DummySemanticHasher("b")) + assert reg.has_semantic_hasher(Base) is True + + def test_has_semantic_hasher_via_mro(self): + reg = PythonTypeSemanticHasherRegistry() + reg.register(Base, _DummySemanticHasher("b")) + assert reg.has_semantic_hasher(Child) is True + + def test_has_semantic_hasher_false(self): + reg = PythonTypeSemanticHasherRegistry() + assert reg.has_semantic_hasher(Base) is False + + def test_registered_types_snapshot(self): + reg = PythonTypeSemanticHasherRegistry() + reg.register(Base, _DummySemanticHasher("b")) + reg.register(Child, _DummySemanticHasher("c")) + types = reg.registered_types() + assert Base in types + assert Child in types + + def test_len(self): + reg = PythonTypeSemanticHasherRegistry() + assert len(reg) == 0 + reg.register(Base, _DummySemanticHasher("b")) + assert len(reg) == 1 + reg.register(Child, _DummySemanticHasher("c")) + assert len(reg) == 2 + + def test_get_semantic_hasher_for_type(self): + reg = PythonTypeSemanticHasherRegistry() + h = _DummySemanticHasher("b") + reg.register(Base, h) + assert reg.get_semantic_hasher_for_type(Base) is h + assert reg.get_semantic_hasher_for_type(Child) is h # via MRO + assert reg.get_semantic_hasher_for_type(int) is None +``` + +Also update any remaining references in the file body to `get_default_type_handler_registry` → `get_default_python_type_semantic_hasher_registry`, and any fixture type annotations. + +- [ ] **Step 5: Run tests** + +```bash +uv run --project /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + pytest tests/test_hashing/test_semantic_hasher.py -x -v +``` + +Expected: all tests pass. + +- [ ] **Step 6: Commit** + +```bash +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + add tests/test_hashing/test_semantic_hasher.py +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + commit -m "test(semantic_hasher): update for registry rename and hash() protocol tightening" +``` + +--- + +## Task 8: Add `visit_extension` to `ArrowTypeDataVisitor` + rewrite `SemanticHashingVisitor` + +**Files:** +- Modify: `src/orcapod/hashing/visitors.py` + +- [ ] **Step 1: Write a failing test for `visit_extension` dispatch** + +Create `tests/test_hashing/test_extension_type_hashing.py`: + +```python +"""Tests for extension type column hashing via SemanticHashingVisitor.""" + +from __future__ import annotations + +import pyarrow as pa +import pytest +from pathlib import Path + +from orcapod.hashing.visitors import SemanticHashingVisitor +from orcapod.contexts import get_default_context + + +@pytest.fixture +def ctx(): + return get_default_context() + + +class TestArrowTypeDataVisitorExtension: + def test_visit_dispatches_to_visit_extension_for_extension_types(self, ctx): + """visit() routes ExtensionType columns to visit_extension(), not visit_struct().""" + arrow_type = ctx.type_converter.register_python_class(Path) + assert isinstance(arrow_type, pa.ExtensionType), ( + "Path must be registered as an Arrow extension type" + ) + + calls = [] + + class TrackingVisitor(SemanticHashingVisitor): + def visit_extension(self, ext_type, storage_value): + calls.append("visit_extension") + return super().visit_extension(ext_type, storage_value) + + def visit_struct(self, struct_type, data): + calls.append("visit_struct") + return super().visit_struct(struct_type, data) + + visitor = TrackingVisitor(ctx.type_converter, ctx.semantic_hasher) + # Any value is fine for this dispatch test — use a dummy string (storage for Path is str) + visitor.visit(arrow_type, "/tmp/dummy") + assert "visit_extension" in calls + assert "visit_struct" not in calls + + +class TestSemanticHashingVisitorExtension: + def test_path_column_hashed_to_large_binary(self, ctx, tmp_path): + """Path extension columns are replaced with pa.large_binary() hash tokens.""" + file = tmp_path / "test.txt" + file.write_text("hello") + + arrow_type = ctx.type_converter.register_python_class(Path) + storage_val = ctx.type_converter.python_to_storage(Path(file), Path) + + visitor = SemanticHashingVisitor(ctx.type_converter, ctx.semantic_hasher) + new_type, new_data = visitor.visit(arrow_type, storage_val) + + assert new_type == pa.large_binary() + assert isinstance(new_data, bytes) + + def test_same_content_same_hash(self, ctx, tmp_path): + """Two paths pointing to files with identical content produce the same hash bytes.""" + file1 = tmp_path / "a.txt" + file2 = tmp_path / "b.txt" + file1.write_text("identical content") + file2.write_text("identical content") + + arrow_type = ctx.type_converter.register_python_class(Path) + storage1 = ctx.type_converter.python_to_storage(Path(file1), Path) + storage2 = ctx.type_converter.python_to_storage(Path(file2), Path) + + visitor = SemanticHashingVisitor(ctx.type_converter, ctx.semantic_hasher) + _, hash1 = visitor.visit(arrow_type, storage1) + _, hash2 = visitor.visit(arrow_type, storage2) + + assert hash1 == hash2 + + def test_different_content_different_hash(self, ctx, tmp_path): + """Files with different content produce different hash bytes.""" + file1 = tmp_path / "x.txt" + file2 = tmp_path / "y.txt" + file1.write_text("content A") + file2.write_text("content B") + + arrow_type = ctx.type_converter.register_python_class(Path) + storage1 = ctx.type_converter.python_to_storage(Path(file1), Path) + storage2 = ctx.type_converter.python_to_storage(Path(file2), Path) + + visitor = SemanticHashingVisitor(ctx.type_converter, ctx.semantic_hasher) + _, hash1 = visitor.visit(arrow_type, storage1) + _, hash2 = visitor.visit(arrow_type, storage2) + + assert hash1 != hash2 + + def test_binary_encoding_format(self, ctx, tmp_path): + """Hash bytes have format b':::'.""" + file = tmp_path / "test.txt" + file.write_text("test") + + arrow_type = ctx.type_converter.register_python_class(Path) + storage_val = ctx.type_converter.python_to_storage(Path(file), Path) + + visitor = SemanticHashingVisitor(ctx.type_converter, ctx.semantic_hasher) + _, hash_bytes = visitor.visit(arrow_type, storage_val) + + assert b"::" in hash_bytes + type_prefix, hash_part = hash_bytes.split(b"::", 1) + # Extension name "orcapod.path" → dots replaced with colons + assert type_prefix == b"orcapod:path" + # hash_part should be "method:digest" — at least one colon + assert b":" in hash_part + + def test_null_value_passthrough(self, ctx): + """Null storage values pass through as-is.""" + arrow_type = ctx.type_converter.register_python_class(Path) + + visitor = SemanticHashingVisitor(ctx.type_converter, ctx.semantic_hasher) + new_type, new_data = visitor.visit(arrow_type, None) + + assert new_type == arrow_type + assert new_data is None +``` + +- [ ] **Step 2: Run tests — verify they fail** + +```bash +uv run --project /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + pytest tests/test_hashing/test_extension_type_hashing.py -x -v +``` + +Expected: ImportError or AttributeError (methods don't exist yet). + +- [ ] **Step 3: Rewrite `visitors.py`** + +```python +""" +Generic visitor pattern for traversing Arrow types and data simultaneously. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa + from orcapod.semantic_types.universal_converter import UniversalTypeConverter + from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher +else: + pa = LazyModule("pyarrow") + + +class ArrowTypeDataVisitor(ABC): + """Base visitor for traversing Arrow types and data simultaneously.""" + + @abstractmethod + def visit_struct( + self, struct_type: "pa.StructType", data: dict | None + ) -> tuple["pa.DataType", Any]: + """Visit a struct type with its data.""" + pass + + @abstractmethod + def visit_list( + self, list_type: "pa.ListType", data: list | None + ) -> tuple["pa.DataType", Any]: + """Visit a list type with its data.""" + pass + + @abstractmethod + def visit_map( + self, map_type: "pa.MapType", data: dict | None + ) -> tuple["pa.DataType", Any]: + """Visit a map type with its data.""" + pass + + @abstractmethod + def visit_primitive( + self, primitive_type: "pa.DataType", data: Any + ) -> tuple["pa.DataType", Any]: + """Visit a primitive type with its data.""" + pass + + def visit_extension( + self, + extension_type: "pa.ExtensionType", + storage_value: Any, + ) -> tuple["pa.DataType", Any]: + """Handle an Arrow extension type. + + Default implementation: passthrough — preserves the extension type and its + storage value unchanged so that the downstream ``StarfixArrowHasher`` / + ``ArrowDigester`` sees the full extension metadata when it receives the + pre-processed table. + + Subclasses may override to convert recognised extension types to a hashed + ``pa.large_binary()`` value. + + Args: + extension_type: The Arrow extension type. + storage_value: The storage-level value (result of ``to_pylist()`` on the column). + + Returns: + Tuple of ``(new_arrow_type, new_data)``. + """ + return extension_type, storage_value + + def visit(self, arrow_type: "pa.DataType", data: Any) -> tuple["pa.DataType", Any]: + """Main dispatch method that routes to the appropriate visit method. + + Extension types are checked **first** — before the struct check — because + extension types with struct storage would otherwise be incorrectly routed + into ``visit_struct``. After ``visit_extension``, the result is re-visited + only if the type changed AND is no longer an extension type (enables + composability, avoids infinite recursion). + + Args: + arrow_type: Arrow data type to process. + data: Corresponding data value. + + Returns: + Tuple of ``(new_arrow_type, new_data)``. + """ + if isinstance(arrow_type, pa.ExtensionType): + new_type, new_data = self.visit_extension(arrow_type, data) + if new_type is not arrow_type and not isinstance(new_type, pa.ExtensionType): + return self.visit(new_type, new_data) + return new_type, new_data + + if pa.types.is_struct(arrow_type): + return self.visit_struct(arrow_type, data) + elif pa.types.is_list(arrow_type) or pa.types.is_large_list(arrow_type): + return self.visit_list(arrow_type, data) + elif pa.types.is_fixed_size_list(arrow_type): + return self.visit_list(arrow_type, data) + elif pa.types.is_map(arrow_type): + return self.visit_map(arrow_type, data) + else: + return self.visit_primitive(arrow_type, data) + + def _visit_struct_fields( + self, struct_type: "pa.StructType", data: dict | None + ) -> tuple["pa.StructType", dict]: + """Recursively process struct fields. Default behavior for regular structs.""" + if data is None: + return struct_type, None + + new_fields = [] + new_data = {} + + for field in struct_type: + field_data = data.get(field.name) + new_field_type, new_field_data = self.visit(field.type, field_data) + new_fields.append(pa.field(field.name, new_field_type)) + new_data[field.name] = new_field_data + + return pa.struct(new_fields), new_data + + def _visit_list_elements( + self, list_type: "pa.ListType", data: list | None + ) -> tuple["pa.DataType", list]: + """Recursively process list elements.""" + if data is None: + return list_type, None + + element_type = list_type.value_type + processed_elements = [] + new_element_type = None + + for item in data: + current_element_type, processed_item = self.visit(element_type, item) + processed_elements.append(processed_item) + if new_element_type is None: + new_element_type = current_element_type + + if new_element_type is None: + new_element_type = element_type + + if pa.types.is_large_list(list_type): + return pa.large_list(new_element_type), processed_elements + elif pa.types.is_fixed_size_list(list_type): + return pa.list_(new_element_type, list_type.list_size), processed_elements + else: + return pa.list_(new_element_type), processed_elements + + +class SemanticHashingError(Exception): + """Exception raised when semantic hashing fails.""" + pass + + +class SemanticHashingVisitor(ArrowTypeDataVisitor): + """Visitor that replaces extension-typed columns with their content hashes. + + For each Arrow column whose type is a ``pa.ExtensionType``: + + 1. Look up the corresponding Python type via ``type_converter``. + 2. If the Python type has a semantic hasher registered in ``python_hasher``, + convert the storage value to a Python object and hash it, replacing the + column with a ``pa.large_binary()`` value of the form:: + + + b"::" + content_hash.to_prefixed_digest() + + where ``type_name`` is the extension name with dots replaced by colons + (e.g. ``"orcapod.path"`` → ``"orcapod:path"``), and + ``to_prefixed_digest()`` = ``method_bytes + b":" + digest``. + 3. If no hasher is registered (or the converter doesn't know the type), + return the extension type and storage value unchanged. The downstream + ``StarfixArrowHasher`` / ``ArrowDigester`` will see the full extension + metadata intact and hash it in a type-aware way. + + Args: + type_converter: The active ``UniversalTypeConverter`` for resolving + extension type → Python type and storage → Python conversion. + python_hasher: The active ``SemanticAwarePythonHasher`` for hashing + Python objects. + """ + + def __init__( + self, + type_converter: "UniversalTypeConverter", + python_hasher: "SemanticAwarePythonHasher", + ) -> None: + self._type_converter = type_converter + self._python_hasher = python_hasher + self._current_field_path: list[str] = [] + + def visit_extension( + self, + extension_type: "pa.ExtensionType", + storage_value: Any, + ) -> tuple["pa.DataType", Any]: + """Hash an extension type value to pa.large_binary(), or passthrough.""" + if storage_value is None: + return extension_type, None + + from typing import Any as _Any + + # Resolve extension type → Python type. + python_type = self._type_converter.arrow_type_to_python_type(extension_type) + + # If the converter couldn't resolve to a concrete class, passthrough. + if python_type is _Any or not isinstance(python_type, type): + return extension_type, storage_value + + # Only hash if a semantic hasher is registered for this Python type. + if not self._python_hasher.type_semantic_hasher_registry.has_semantic_hasher( + python_type + ): + return extension_type, storage_value + + # Convert storage value → Python object and hash it. + python_obj = self._type_converter.storage_to_python(storage_value, python_type) + content_hash = self._python_hasher.hash_object(python_obj) + + # Encode as binary: ":::" + # Dots in the extension name → colons (e.g. "orcapod.path" → "orcapod:path"). + # The "::" separator is unambiguous because to_prefixed_digest() uses only ":". + type_name = extension_type.extension_name.replace(".", ":") + hash_bytes = ( + type_name.encode("ascii") + + b"::" + + content_hash.to_prefixed_digest() + ) + return pa.large_binary(), hash_bytes + + def visit_struct( + self, struct_type: "pa.StructType", data: dict | None + ) -> tuple["pa.DataType", Any]: + """Regular struct (no extension identity) — recurse into fields.""" + if data is None: + return struct_type, None + return self._visit_struct_fields(struct_type, data) + + def visit_list( + self, list_type: "pa.ListType", data: list | None + ) -> tuple["pa.DataType", Any]: + """Recurse into list elements.""" + if data is None: + return list_type, None + self._current_field_path.append("[*]") + try: + return self._visit_list_elements(list_type, data) + finally: + self._current_field_path.pop() + + def visit_map( + self, map_type: "pa.MapType", data: dict | None + ) -> tuple["pa.DataType", Any]: + """Pass map types through unchanged.""" + return map_type, data + + def visit_primitive( + self, primitive_type: "pa.DataType", data: Any + ) -> tuple["pa.DataType", Any]: + """Pass primitive types through unchanged.""" + return primitive_type, data + + def _visit_struct_fields( + self, struct_type: "pa.StructType", data: dict | None + ) -> tuple["pa.StructType", dict]: + """Override to add field path tracking for better error messages.""" + if data is None: + return struct_type, None + + new_fields = [] + new_data = {} + + for field in struct_type: + self._current_field_path.append(field.name) + try: + field_data = data.get(field.name) + new_field_type, new_field_data = self.visit(field.type, field_data) + new_fields.append(pa.field(field.name, new_field_type)) + new_data[field.name] = new_field_data + finally: + self._current_field_path.pop() + + return pa.struct(new_fields), new_data +``` + +- [ ] **Step 4: Run tests — verify they pass** + +```bash +uv run --project /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + pytest tests/test_hashing/test_extension_type_hashing.py -x -v +``` + +Expected: all tests pass. + +- [ ] **Step 5: Commit** + +```bash +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + add src/orcapod/hashing/visitors.py \ + tests/test_hashing/test_extension_type_hashing.py +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + commit -m "feat(visitors): add visit_extension dispatch; rewrite SemanticHashingVisitor for extension types" +``` + +--- + +## Task 9: Update `StarfixArrowHasher`, delete `SemanticArrowHasher` + +**Files:** +- Modify: `src/orcapod/hashing/arrow_hashers.py` + +- [ ] **Step 1: Rewrite `arrow_hashers.py`** + +Delete the entire `SemanticArrowHasher` class. Update `StarfixArrowHasher`: + +```python +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pyarrow as pa +from starfix import ArrowDigester + +from orcapod.hashing.schema_cleaner import clean_schema_for_hashing, has_extension_metadata +from orcapod.hashing.visitors import SemanticHashingVisitor +from orcapod.types import ContentHash +from orcapod.utils import arrow_utils + +if TYPE_CHECKING: + from orcapod.semantic_types.universal_converter import UniversalTypeConverter + from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher + + +class StarfixArrowHasher: + """Arrow table hasher backed by the starfix-python ``ArrowDigester``. + + Pipeline + -------- + 1. **Semantic pre-processing** — the ``SemanticHashingVisitor`` traverses + every column. Extension-typed columns whose Python type has a registered + semantic hasher are replaced with ``pa.large_binary()`` hash tokens + (e.g. ``Path`` columns are replaced by their file-content hash). + Extension-typed columns without a registered hasher pass through with + their full extension metadata intact. + 2. **Starfix hashing** — ``ArrowDigester.hash_table`` produces a 35-byte + versioned SHA-256 digest that is byte-for-byte identical to the Rust + ``starfix`` crate output. + + Parameters + ---------- + type_converter: + ``UniversalTypeConverter`` used to resolve extension types to Python + types and convert storage values back to Python objects. + semantic_hasher: + ``SemanticAwarePythonHasher`` used to hash Python objects extracted + from extension-typed columns. + hasher_id: + String identifier embedded in every ``ContentHash`` produced by this + hasher. + """ + + def __init__( + self, + type_converter: "UniversalTypeConverter", + semantic_hasher: "SemanticAwarePythonHasher", + hasher_id: str, + ) -> None: + self._type_converter = type_converter + self._semantic_hasher = semantic_hasher + self._hasher_id = hasher_id + + @property + def hasher_id(self) -> str: + return self._hasher_id + + def _process_table_columns(self, table: "pa.Table | pa.RecordBatch") -> "pa.Table": + """Replace semantic-typed columns with their content-hash bytes.""" + new_columns: list[pa.Array] = [] + new_fields: list[pa.Field] = [] + + for i, field in enumerate(table.schema): + # Short-circuit: columns that cannot contain semantic types skip + # the costly Python round-trip. Extension types must pass through + # so visit_extension can process them. + if not ( + isinstance(field.type, pa.ExtensionType) + or pa.types.is_struct(field.type) + or pa.types.is_list(field.type) + or pa.types.is_large_list(field.type) + or pa.types.is_fixed_size_list(field.type) + or pa.types.is_map(field.type) + ): + new_columns.append(table.column(i)) + new_fields.append(field) + continue + + column_data = table.column(i).to_pylist() + visitor = SemanticHashingVisitor(self._type_converter, self._semantic_hasher) + + try: + new_type: pa.DataType | None = None + processed_data: list[Any] = [] + for value in column_data: + processed_type, processed_value = visitor.visit(field.type, value) + if new_type is None and processed_value is not None: + new_type = processed_type + processed_data.append(processed_value) + + if new_type is None: + new_type = field.type + new_columns.append(pa.array(processed_data, type=new_type)) + new_fields.append(field.with_type(new_type)) + + except Exception as exc: + raise RuntimeError( + f"Failed to process column '{field.name}': {exc}" + ) from exc + + return pa.table( + new_columns, + schema=pa.schema(new_fields, metadata=table.schema.metadata), + ) + + def hash_schema(self, schema: "pa.Schema") -> ContentHash: + """Hash an Arrow schema using the starfix canonical algorithm.""" + include_meta = has_extension_metadata(schema) + if include_meta: + schema = clean_schema_for_hashing(schema) + digest = ArrowDigester.hash_schema(schema, include_metadata=include_meta) + return ContentHash(method=self._hasher_id, digest=digest) + + def hash_table(self, table: "pa.Table | pa.RecordBatch") -> ContentHash: + """Hash an Arrow table (or ``RecordBatch``) using starfix.""" + if isinstance(table, pa.RecordBatch): + table = pa.Table.from_batches([table]) + + processed_table = self._process_table_columns(table) + include_meta = has_extension_metadata(processed_table.schema) + if include_meta: + clean_schema = clean_schema_for_hashing(processed_table.schema) + clean_table = pa.Table.from_arrays( + processed_table.columns, schema=clean_schema + ) + else: + clean_table = processed_table + digest = ArrowDigester.hash_table(clean_table, include_metadata=include_meta) + return ContentHash(method=self._hasher_id, digest=digest) +``` + +- [ ] **Step 2: Commit** + +```bash +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + add src/orcapod/hashing/arrow_hashers.py +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + commit -m "refactor(arrow_hashers): update StarfixArrowHasher for extension types, delete SemanticArrowHasher" +``` + +--- + +## Task 10: Update `test_starfix_arrow_hasher.py`, run tests + +**Files:** +- Modify: `tests/test_hashing/test_starfix_arrow_hasher.py` + +- [ ] **Step 1: Update `_make_hasher()` and remove `SemanticTypeRegistry` import** + +```python +# Remove this import: +# from orcapod.semantic_types import SemanticTypeRegistry + +# Update _make_hasher(): +def _make_hasher() -> StarfixArrowHasher: + from orcapod.contexts import get_default_context + ctx = get_default_context() + return StarfixArrowHasher( + type_converter=ctx.type_converter, + semantic_hasher=ctx.semantic_hasher, + hasher_id=HASHER_ID, + ) +``` + +- [ ] **Step 2: Run the hashing test suite** + +```bash +uv run --project /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + pytest tests/test_hashing/ -x -v +``` + +Expected: all tests pass (golden digests unchanged for plain-schema tables; extension type tests pass). + +- [ ] **Step 3: Commit** + +```bash +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + add tests/test_hashing/test_starfix_arrow_hasher.py +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + commit -m "test(starfix_arrow_hasher): update _make_hasher() for new constructor, remove SemanticTypeRegistry import" +``` + +--- + +## Task 11: Update `v0.1.json`, `context_schema.json`, and `versioned_hashers.py` + +**Files:** +- Modify: `src/orcapod/contexts/data/v0.1.json` +- Modify: `src/orcapod/contexts/data/schemas/context_schema.json` +- Modify: `src/orcapod/hashing/versioned_hashers.py` + +- [ ] **Step 1: Rewrite `v0.1.json`** + +Key design note: `arrow_hasher` now depends on `semantic_hasher`, and `semantic_hasher` depends on `python_type_semantic_hasher_registry`. To avoid a circular dependency, the `pa.Table`/`pa.RecordBatch` handler entries are **removed** from the registry's handlers list (those entries previously referenced `arrow_hasher`). The JSON construction order is: `file_hasher` → `type_converter` → `function_info_extractor` → `python_type_semantic_hasher_registry` → `semantic_hasher` → `arrow_hasher`. + +```json +{ + "context_key": "std:v0.1:default", + "version": "v0.1", + "description": "Initial stable release with extension type hashing support", + "file_hasher": { + "_class": "orcapod.hashing.file_hashers.BasicFileHasher", + "_config": { + "algorithm": "sha256" + } + }, + "type_converter": { + "_class": "orcapod.semantic_types.universal_converter.UniversalTypeConverter", + "_config": { + "logical_type_registry": { + "_class": "orcapod.extension_types.registry.LogicalTypeRegistry", + "_config": { + "logical_types": [ + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalPath", + "_config": {} + }, + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalUPath", + "_config": {} + }, + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalUUID", + "_config": {} + } + ], + "factories": [ + { + "factory": { + "_class": "orcapod.extension_types.dataclass_logical_type_factory.DataclassLogicalTypeFactory", + "_config": {} + }, + "category": "orcapod.dataclass", + "python_bases": [{"_type": "builtins.object"}] + }, + { + "factory": { + "_class": "orcapod.extension_types.pydantic_logical_type_factory.PydanticLogicalTypeFactory", + "_config": {} + }, + "category": "orcapod.pydantic", + "python_bases": [{"_type": "pydantic.BaseModel"}] + } + ] + } + } + } + }, + "function_info_extractor": { + "_class": "orcapod.hashing.semantic_hashing.function_info_extractors.FunctionSignatureExtractor", + "_config": { + "include_module": true, + "include_defaults": true + } + }, + "python_type_semantic_hasher_registry": { + "_class": "orcapod.hashing.semantic_hashing.type_handler_registry.PythonTypeSemanticHasherRegistry", + "_config": { + "handlers": [ + [{"_type": "builtins.bytes"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.BytesSemanticHasher", "_config": {}}], + [{"_type": "builtins.bytearray"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.BytesSemanticHasher", "_config": {}}], + [{"_type": "pathlib.Path"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.PathSemanticHasher", "_config": {"file_hasher": {"_ref": "file_hasher"}}}], + [{"_type": "upath.core.UPath"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.UPathSemanticHasher", "_config": {"file_hasher": {"_ref": "file_hasher"}}}], + [{"_type": "uuid.UUID"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.UUIDSemanticHasher", "_config": {}}], + [{"_type": "types.FunctionType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionSemanticHasher", "_config": {"function_info_extractor": {"_ref": "function_info_extractor"}}}], + [{"_type": "types.BuiltinFunctionType"},{"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionSemanticHasher", "_config": {"function_info_extractor": {"_ref": "function_info_extractor"}}}], + [{"_type": "types.MethodType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionSemanticHasher", "_config": {"function_info_extractor": {"_ref": "function_info_extractor"}}}], + [{"_type": "builtins.type"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.TypeObjectSemanticHasher", "_config": {}}], + [{"_type": "types.GenericAlias"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.GenericAliasSemanticHasher", "_config": {}}], + [{"_type": "types.UnionType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.UnionTypeSemanticHasher", "_config": {}}], + [{"_type": "typing._GenericAlias"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.GenericAliasSemanticHasher", "_config": {}}], + [{"_type": "typing._SpecialForm"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.SpecialFormSemanticHasher", "_config": {}}] + ] + } + }, + "semantic_hasher": { + "_class": "orcapod.hashing.semantic_hashing.semantic_hasher.SemanticAwarePythonHasher", + "_config": { + "hasher_id": "semantic_v0.1", + "type_semantic_hasher_registry": { + "_ref": "python_type_semantic_hasher_registry" + } + } + }, + "arrow_hasher": { + "_class": "orcapod.hashing.arrow_hashers.StarfixArrowHasher", + "_config": { + "hasher_id": "arrow_v0.1", + "type_converter": {"_ref": "type_converter"}, + "semantic_hasher": {"_ref": "semantic_hasher"} + } + }, + "metadata": { + "created_date": "2026-06-24", + "author": "OrcaPod Core Team", + "changelog": [ + "Initial release with Path semantic type support", + "Basic SHA-256 hashing for files and objects", + "Arrow logical serialization method", + "Introduced arrow_v0.1 StarfixArrowHasher using starfix ArrowDigester for cross-language-compatible Arrow hashing", + "Hard cut: replaced shape-based SemanticTypeRegistry with extension-type hashing; renamed all hashing classes to clearer names" + ] + } +} +``` + +- [ ] **Step 2: Update `context_schema.json`** + +Two changes: +1. Remove the `semantic_registry` property from `properties`. +2. Rename `type_handler_registry` → `python_type_semantic_hasher_registry` in `properties`. + +```json +"python_type_semantic_hasher_registry": { + "$ref": "#/$defs/objectspec", + "description": "ObjectSpec for the PythonTypeSemanticHasherRegistry used by the semantic hasher" +}, +``` + +Also update the `examples` section references and remove the `"semantic_registry"` entry. + +- [ ] **Step 3: Update `versioned_hashers.py`** + +```python +""" +Versioned hasher factories for OrcaPod. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from orcapod.protocols import hashing_protocols as hp + +logger = logging.getLogger(__name__) + +_CURRENT_SEMANTIC_HASHER_ID = "semantic_v0.1" +_CURRENT_ARROW_HASHER_ID = "arrow_v0.1" + + +def get_versioned_semantic_hasher( + hasher_id: str = _CURRENT_SEMANTIC_HASHER_ID, + strict: bool = True, + type_semantic_hasher_registry: "Any | None" = None, +) -> hp.SemanticHasherProtocol: + """Return a SemanticAwarePythonHasher configured for the current version. + + Parameters + ---------- + hasher_id: + Identifier embedded in every ContentHash produced by this hasher. + strict: + When True raises TypeError for unhandled types. When False falls back + to a best-effort string representation. + type_semantic_hasher_registry: + Optional ``PythonTypeSemanticHasherRegistry`` to inject. When None the + global default registry is used. + """ + from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher + + if type_semantic_hasher_registry is None: + from orcapod.hashing.semantic_hashing.type_handler_registry import ( + get_default_python_type_semantic_hasher_registry, + ) + type_semantic_hasher_registry = get_default_python_type_semantic_hasher_registry() + + logger.debug( + "get_versioned_semantic_hasher: creating SemanticAwarePythonHasher " + "(hasher_id=%r, strict=%r)", + hasher_id, + strict, + ) + return SemanticAwarePythonHasher( + hasher_id=hasher_id, + type_semantic_hasher_registry=type_semantic_hasher_registry, + strict=strict, + ) + + +def get_versioned_semantic_arrow_hasher( + hasher_id: str = _CURRENT_ARROW_HASHER_ID, +) -> hp.ArrowHasherProtocol: + """Return a StarfixArrowHasher configured for the current version. + + Sources ``type_converter`` and ``semantic_hasher`` from the default + ``DataContext`` so that the arrow hasher is consistent with all other + versioned components. + + Parameters + ---------- + hasher_id: + Identifier embedded in every ContentHash produced by this hasher. + """ + from orcapod.hashing.arrow_hashers import StarfixArrowHasher + from orcapod.contexts import resolve_context + + ctx = resolve_context(None) # default context + logger.debug( + "get_versioned_semantic_arrow_hasher: creating StarfixArrowHasher " + "(hasher_id=%r)", + hasher_id, + ) + return StarfixArrowHasher( + hasher_id=hasher_id, + type_converter=ctx.type_converter, + semantic_hasher=ctx.semantic_hasher, + ) +``` + +- [ ] **Step 4: Run the full test suite (except test_semantic_types)** + +```bash +uv run --project /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + pytest tests/test_hashing/ tests/test_extension_types/ tests/test_core/ -x -v +``` + +Expected: all tests pass. + +- [ ] **Step 5: Commit** + +```bash +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + add src/orcapod/contexts/data/v0.1.json \ + src/orcapod/contexts/data/schemas/context_schema.json \ + src/orcapod/hashing/versioned_hashers.py +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + commit -m "feat(v0.1): wire extension type hashing into default context; remove semantic_registry" +``` + +--- + +## Task 12: Delete old semantic type system + grep sweep + final test run + +**Files:** +- Delete: `src/orcapod/semantic_types/semantic_struct_converters.py` +- Delete: `src/orcapod/semantic_types/semantic_registry.py` +- Delete: `tests/test_semantic_types/` (all 9 files) +- Delete: `tests/test_hashing/test_file_hashing_consistency.py` +- Modify: `src/orcapod/semantic_types/__init__.py` +- Modify: `src/orcapod/protocols/semantic_types_protocols.py` + +- [ ] **Step 1: Delete old source files** + +```bash +rm /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python/src/orcapod/semantic_types/semantic_struct_converters.py +rm /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python/src/orcapod/semantic_types/semantic_registry.py +``` + +- [ ] **Step 2: Update `semantic_types/__init__.py`** — remove `SemanticTypeRegistry` export + +```python +from .universal_converter import UniversalTypeConverter +from .type_inference import infer_python_schema_from_pylist_data + +__all__ = [ + "UniversalTypeConverter", + "infer_python_schema_from_pylist_data", +] +``` + +- [ ] **Step 3: Remove `SemanticStructConverterProtocol` from `semantic_types_protocols.py`** + +Delete the `SemanticStructConverterProtocol` class and any imports that only support it. Keep `TypeConverterProtocol` and all other classes. + +- [ ] **Step 4: Delete old test files** + +```bash +rm /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python/tests/test_hashing/test_file_hashing_consistency.py +rm -r /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python/tests/test_semantic_types/ +``` + +- [ ] **Step 5: Grep sweep for stale references** + +```bash +grep -rn \ + "SemanticTypeRegistry\|semantic_registry\|SemanticStructConverter\ +\|BaseSemanticHasher\|TypeHandlerRegistry\|BuiltinTypeHandlerRegistry\ +\|TypeHandlerProtocol\|PathContentHandler\|UPathContentHandler\ +\|UUIDHandler\|BytesHandler\|FunctionHandler\|TypeObjectHandler\ +\|SpecialFormHandler\|GenericAliasHandler\|UnionTypeHandler\|ArrowTableHandler\ +\|SchemaHandler\|register_builtin_handlers\|get_default_type_handler_registry\ +\|type_handler_registry\|get_handler\b\|has_handler\b\|SemanticArrowHasher" \ + /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python/src/ \ + /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python/tests/ \ + 2>/dev/null +``` + +Expected: zero matches (fix any that appear before continuing). + +- [ ] **Step 6: Run full test suite** + +```bash +uv run --project /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + pytest tests/test_hashing/ tests/test_extension_types/ tests/test_core/ -x -v +``` + +Expected: all tests pass. + +- [ ] **Step 7: Final commit** + +```bash +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + add -u +git -C /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python \ + commit -m "feat(PLT-1660): hard cut — delete SemanticTypeRegistry and old struct-based hashing system" +``` + +--- + +## Self-Review + +**Spec coverage:** +- ✅ §1 `visit_extension` added to `ArrowTypeDataVisitor`, `visit()` updated (Task 8) +- ✅ §2 `SemanticHashingVisitor` rewritten with binary encoding (Task 8) +- ✅ §3 `StarfixArrowHasher` constructor updated + short-circuit + `SemanticArrowHasher` deleted (Task 9) +- ✅ §4 `SemanticArrowHasher` deleted (Task 9) +- ✅ §5 All class/method renames applied (Tasks 1–6) +- ✅ §6 Protocol tightened: `hash() -> ContentHash` (Tasks 1, 3, 4) +- ✅ §7 `v0.1.json` updated (Task 11) — note: `pa.Table`/`pa.RecordBatch` handlers removed to break circular dep +- ✅ §8 `context_schema.json` updated (Task 11) +- ✅ §9 `DataContext.core` docstring updated (Task 5) +- ✅ §10 `versioned_hashers.py` sources from context (Task 11) +- ✅ Files to delete: all covered (Task 12) +- ✅ Files to update: covered across Tasks 1–11 + +**Circular dependency note (§7 deviation):** The spec says to add `"semantic_hasher": {"_ref": "semantic_hasher"}` to `arrow_hasher._config`. This is correct and implemented. However, to avoid a construction-order cycle (`arrow_hasher` → `semantic_hasher` → `registry` → `arrow_hasher` via `ArrowTableSemanticHasher`), the `pa.Table` and `pa.RecordBatch` handler entries are removed from the `python_type_semantic_hasher_registry` handlers list in `v0.1.json`. These handlers depended on `arrow_hasher` creating the cycle. The `register_builtin_python_type_semantic_hashers()` function still supports them when `arrow_hasher` is passed explicitly (e.g., for custom registry construction in tests). + +**Type consistency check:** +- `SemanticAwarePythonHasher.__init__` takes `type_semantic_hasher_registry` → `v0.1.json` uses key `type_semantic_hasher_registry` ✅ +- `SemanticHashingVisitor.__init__` takes `type_converter, python_hasher` → `_process_table_columns` passes `self._type_converter, self._semantic_hasher` ✅ +- `StarfixArrowHasher.__init__` takes `type_converter, semantic_hasher, hasher_id` → `versioned_hashers.py` passes these by keyword ✅ +- `PythonTypeSemanticHasherRegistry.get_semantic_hasher(obj)` → `SemanticAwarePythonHasher.hash_object()` calls this ✅ +- `PythonTypeSemanticHasherRegistry.has_semantic_hasher(target_type)` → `SemanticHashingVisitor.visit_extension()` calls this ✅ diff --git a/superpowers/plans/2026-06-24-rename-semantic-hasher-to-handler.md b/superpowers/plans/2026-06-24-rename-semantic-hasher-to-handler.md new file mode 100644 index 00000000..d33489a3 --- /dev/null +++ b/superpowers/plans/2026-06-24-rename-semantic-hasher-to-handler.md @@ -0,0 +1,422 @@ +# Rename *SemanticHasher → *Handler, PythonTypeSemanticHasherRegistry → PythonTypeHandlerRegistry + +> **For agentic workers:** REQUIRED SUB-SKILL: Use sensei:subagent-driven-development (recommended) or sensei:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Mechanically rename all `*SemanticHasher` handler classes to `*Handler`, all `PythonTypeSemanticHasherRegistry` variants to `PythonTypeHandlerRegistry`, and the `type_semantic_hasher_registry` param/property to `type_handler_registry` — no logic changes. + +**Architecture:** Pure find-and-replace of identifiers across ~10 source files and 2 JSON configs. Every old name maps 1-to-1 to a new name. No logic, no interface changes, no backward-compat shims (greenfield project). + +**Tech Stack:** Python, JSON, uv/pytest + +--- + +## File Map + +| File | What changes | +|---|---| +| `src/orcapod/hashing/semantic_hashing/builtin_handlers.py` | 11 class names + function name + docstring/string literals | +| `src/orcapod/hashing/semantic_hashing/type_handler_registry.py` | 3 class/function names + docstrings + internal log strings | +| `src/orcapod/hashing/semantic_hashing/semantic_hasher.py` | param + property name `type_semantic_hasher_registry` → `type_handler_registry` + docstring | +| `src/orcapod/hashing/semantic_hashing/__init__.py` | imports + `__all__` | +| `src/orcapod/hashing/__init__.py` | imports + `__all__` | +| `src/orcapod/hashing/defaults.py` | function name + import + docstring | +| `src/orcapod/hashing/versioned_hashers.py` | param name + import | +| `src/orcapod/protocols/hashing_protocols.py` | property name in `SemanticHasherProtocol` + TYPE_CHECKING import | +| `src/orcapod/contexts/data/v0.1.json` | top-level key, `_class` values, `_ref` value, sub-key | +| `src/orcapod/contexts/data/schemas/context_schema.json` | property key | +| `tests/test_hashing/test_semantic_hasher.py` | imports + usage | +| `tests/test_hashing/test_uuid_handler.py` | imports + usage | +| `tests/test_hashing/test_extension_type_hashing.py` | no old names (already clean) | +| `test-objective/unit/test_hashing.py` | imports, class names, type annotations, comments | + +--- + +## Rename Reference Table + +### Handler classes (builtin_handlers.py + all callers) + +| Old | New | +|---|---| +| `PathSemanticHasher` | `PathHandler` | +| `UPathSemanticHasher` | `UPathHandler` | +| `UUIDSemanticHasher` | `UUIDHandler` | +| `BytesSemanticHasher` | `BytesHandler` | +| `FunctionSemanticHasher` | `FunctionHandler` | +| `TypeObjectSemanticHasher` | `TypeObjectHandler` | +| `SpecialFormSemanticHasher` | `SpecialFormHandler` | +| `GenericAliasSemanticHasher` | `GenericAliasHandler` | +| `UnionTypeSemanticHasher` | `UnionTypeHandler` | +| `ArrowTableSemanticHasher` | `ArrowTableHandler` | +| `SchemaSemanticHasher` | `SchemaHandler` | +| `register_builtin_python_type_semantic_hashers` | `register_builtin_python_type_handlers` | + +### Registry classes (type_handler_registry.py + all callers) + +| Old | New | +|---|---| +| `PythonTypeSemanticHasherRegistry` | `PythonTypeHandlerRegistry` | +| `BuiltinPythonTypeSemanticHasherRegistry` | `BuiltinPythonTypeHandlerRegistry` | +| `get_default_python_type_semantic_hasher_registry` | `get_default_python_type_handler_registry` | + +### Parameter/property (semantic_hasher.py + all callers) + +| Old | New | +|---|---| +| `type_semantic_hasher_registry` | `type_handler_registry` | + +--- + +## Task 1: Rename class definitions and internal strings in `builtin_handlers.py` + +**Files:** +- Modify: `src/orcapod/hashing/semantic_hashing/builtin_handlers.py` + +- [ ] **Step 1: Apply all renames in builtin_handlers.py** + + Changes needed (all are identifier or string-literal renames only): + - Module docstring: update all `*SemanticHasher` names and `register_builtin_python_type_semantic_hashers` + - TYPE_CHECKING import: `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry` + - All 11 class definitions: `class PathSemanticHasher` → `class PathHandler`, etc. + - Error messages inside class bodies: e.g. `"PathSemanticHasher: path does not exist"` → `"PathHandler: path does not exist"` + - `logger.debug` strings: e.g. `"PathSemanticHasher: hashing file content"` → `"PathHandler: hashing file content"` + - Function `register_builtin_python_type_semantic_hashers` → `register_builtin_python_type_handlers` + - Docstring inside that function: update `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry` + - Final `logger.debug` string: `"register_builtin_python_type_semantic_hashers: registered %d hashers"` → `"register_builtin_python_type_handlers: registered %d hashers"` + +- [ ] **Step 2: Verify file parses correctly** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + uv run python -c "from orcapod.hashing.semantic_hashing import builtin_handlers; print('OK')" + ``` + Expected: `OK` + +--- + +## Task 2: Rename class definitions in `type_handler_registry.py` + +**Files:** +- Modify: `src/orcapod/hashing/semantic_hashing/type_handler_registry.py` + +- [ ] **Step 1: Apply all renames in type_handler_registry.py** + + Changes needed: + - Module docstring: `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry` + - Class `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry` + - `__repr__` method: `"PythonTypeSemanticHasherRegistry(registered=..."` → `"PythonTypeHandlerRegistry(registered=..."` + - `logger.debug` strings that mention `PythonTypeSemanticHasherRegistry` + - Function `get_default_python_type_semantic_hasher_registry` → `get_default_python_type_handler_registry` + - The function body's import: `get_default_python_type_semantic_hasher_registry as _get` → `get_default_python_type_handler_registry as _get` + - Class `BuiltinPythonTypeSemanticHasherRegistry` → `BuiltinPythonTypeHandlerRegistry` + - Docstring: `"A PythonTypeSemanticHasherRegistry pre-populated..."` → `"A PythonTypeHandlerRegistry pre-populated..."` + - `super().__init__()` call — no change needed + - Import inside `__init__`: `register_builtin_python_type_semantic_hashers` → `register_builtin_python_type_handlers` + - Call: `register_builtin_python_type_semantic_hashers(self, ...)` → `register_builtin_python_type_handlers(self, ...)` + +- [ ] **Step 2: Verify file parses correctly** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + uv run python -c "from orcapod.hashing.semantic_hashing.type_handler_registry import PythonTypeHandlerRegistry; print('OK')" + ``` + Expected: `OK` + +--- + +## Task 3: Rename param/property in `semantic_hasher.py` + +**Files:** +- Modify: `src/orcapod/hashing/semantic_hashing/semantic_hasher.py` + +- [ ] **Step 1: Apply renames in semantic_hasher.py** + + Changes needed: + - Import: `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry` + - Docstring parameter: `type_semantic_hasher_registry:` → `type_handler_registry:` + - Constructor param: `type_semantic_hasher_registry: PythonTypeHandlerRegistry | None = None` → `type_handler_registry: PythonTypeHandlerRegistry | None = None` + - Constructor body: `if type_semantic_hasher_registry is None:` → `if type_handler_registry is None:` + - Constructor body: `from orcapod.hashing.defaults import get_default_python_type_semantic_hasher_registry` → `get_default_python_type_handler_registry` + - Constructor body: `self._registry = get_default_python_type_semantic_hasher_registry()` → `get_default_python_type_handler_registry()` + - Constructor body: `else: self._registry = type_semantic_hasher_registry` → `else: self._registry = type_handler_registry` + - Property `type_semantic_hasher_registry` → `type_handler_registry` + - Property docstring: `"Return the ``PythonTypeSemanticHasherRegistry``..."` → `"Return the ``PythonTypeHandlerRegistry``..."` + - Property return type annotation: `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry` + - Error message in `_handle_unknown`: `"via the PythonTypeSemanticHasherRegistry or"` → `"via the PythonTypeHandlerRegistry or"` + +- [ ] **Step 2: Verify file parses correctly** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + uv run python -c "from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher; print('OK')" + ``` + Expected: `OK` + +--- + +## Task 4: Update `semantic_hashing/__init__.py` + +**Files:** +- Modify: `src/orcapod/hashing/semantic_hashing/__init__.py` + +- [ ] **Step 1: Apply renames** + + Changes needed: + - Module docstring: all `*SemanticHasher` names → `*Handler` equivalents + - Import from `builtin_handlers`: `BytesSemanticHasher` → `BytesHandler`, etc.; `register_builtin_python_type_semantic_hashers` → `register_builtin_python_type_handlers` + - Import from `type_handler_registry`: `BuiltinPythonTypeSemanticHasherRegistry` → `BuiltinPythonTypeHandlerRegistry`, `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry` + - `__all__`: update all entries to new names + +- [ ] **Step 2: Verify** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + uv run python -c "from orcapod.hashing.semantic_hashing import PathHandler, PythonTypeHandlerRegistry, register_builtin_python_type_handlers; print('OK')" + ``` + Expected: `OK` + +--- + +## Task 5: Update `hashing/__init__.py` + +**Files:** +- Modify: `src/orcapod/hashing/__init__.py` + +- [ ] **Step 1: Apply renames** + + Changes needed: + - Module docstring: update all old names + - Import from `defaults`: `get_default_python_type_semantic_hasher_registry` → `get_default_python_type_handler_registry` + - Import from `builtin_handlers`: `BytesSemanticHasher` → `BytesHandler`, etc.; `register_builtin_python_type_semantic_hashers` → `register_builtin_python_type_handlers` + - Import from `type_handler_registry`: `BuiltinPythonTypeSemanticHasherRegistry` → `BuiltinPythonTypeHandlerRegistry`, `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry` + - `__all__`: update all entries to new names + +- [ ] **Step 2: Verify** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + uv run python -c "from orcapod.hashing import PythonTypeHandlerRegistry, get_default_python_type_handler_registry, BytesHandler; print('OK')" + ``` + Expected: `OK` + +--- + +## Task 6: Update `hashing/defaults.py` + +**Files:** +- Modify: `src/orcapod/hashing/defaults.py` + +- [ ] **Step 1: Apply renames** + + Changes needed: + - Import: `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry` + - Function name: `get_default_python_type_semantic_hasher_registry` → `get_default_python_type_handler_registry` + - Return type annotation: `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry` + - Docstring: update class name references + - Function body: `get_default_context().semantic_hasher.type_semantic_hasher_registry` → `get_default_context().semantic_hasher.type_handler_registry` + +- [ ] **Step 2: Verify** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + uv run python -c "from orcapod.hashing.defaults import get_default_python_type_handler_registry; print('OK')" + ``` + Expected: `OK` + +--- + +## Task 7: Update `hashing/versioned_hashers.py` + +**Files:** +- Modify: `src/orcapod/hashing/versioned_hashers.py` + +- [ ] **Step 1: Apply renames** + + Changes needed: + - Function param: `type_semantic_hasher_registry: "Any | None" = None` → `type_handler_registry: "Any | None" = None` + - Docstring param description: `type_semantic_hasher_registry:` → `type_handler_registry:` + - Import inside function: `get_default_python_type_semantic_hasher_registry` → `get_default_python_type_handler_registry` + - Variable: `type_semantic_hasher_registry = get_default_python_type_semantic_hasher_registry()` → `type_handler_registry = get_default_python_type_handler_registry()` + - `SemanticAwarePythonHasher(... type_semantic_hasher_registry=type_semantic_hasher_registry ...)` → `... type_handler_registry=type_handler_registry ...` + +- [ ] **Step 2: Verify** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + uv run python -c "from orcapod.hashing.versioned_hashers import get_versioned_semantic_hasher; print('OK')" + ``` + Expected: `OK` + +--- + +## Task 8: Update `protocols/hashing_protocols.py` + +**Files:** +- Modify: `src/orcapod/protocols/hashing_protocols.py` + +- [ ] **Step 1: Apply renames** + + Changes needed: + - TYPE_CHECKING import: `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry` + - `SemanticHasherProtocol.type_semantic_hasher_registry` property → `type_handler_registry` + - Property docstring: `"Return the PythonTypeSemanticHasherRegistry..."` → `"Return the PythonTypeHandlerRegistry..."` + - Property return type annotation: `"PythonTypeSemanticHasherRegistry"` → `"PythonTypeHandlerRegistry"` + +- [ ] **Step 2: Verify** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + uv run python -c "from orcapod.protocols.hashing_protocols import SemanticHasherProtocol; print('OK')" + ``` + Expected: `OK` + +--- + +## Task 9: Update `contexts/data/v0.1.json` + +**Files:** +- Modify: `src/orcapod/contexts/data/v0.1.json` + +- [ ] **Step 1: Apply renames** + + Changes needed (4 renames): + 1. Top-level key `"python_type_semantic_hasher_registry"` → `"python_type_handler_registry"` + 2. All `"_class"` values with `*SemanticHasher` suffix — e.g.: + - `"...builtin_handlers.BytesSemanticHasher"` → `"...builtin_handlers.BytesHandler"` + - `"...builtin_handlers.PathSemanticHasher"` → `"...builtin_handlers.PathHandler"` + - `"...builtin_handlers.UPathSemanticHasher"` → `"...builtin_handlers.UPathHandler"` + - `"...builtin_handlers.UUIDSemanticHasher"` → `"...builtin_handlers.UUIDHandler"` + - `"...builtin_handlers.FunctionSemanticHasher"` → `"...builtin_handlers.FunctionHandler"` + - `"...builtin_handlers.TypeObjectSemanticHasher"` → `"...builtin_handlers.TypeObjectHandler"` + - `"...builtin_handlers.GenericAliasSemanticHasher"` → `"...builtin_handlers.GenericAliasHandler"` + - `"...builtin_handlers.UnionTypeSemanticHasher"` → `"...builtin_handlers.UnionTypeHandler"` + - `"...builtin_handlers.SpecialFormSemanticHasher"` → `"...builtin_handlers.SpecialFormHandler"` + - `"...builtin_handlers.ArrowTableSemanticHasher"` → `"...builtin_handlers.ArrowTableHandler"` + - `"...type_handler_registry.PythonTypeSemanticHasherRegistry"` → `"...type_handler_registry.PythonTypeHandlerRegistry"` + 3. Inside `semantic_hasher._config`: sub-key `"type_semantic_hasher_registry"` → `"type_handler_registry"` + 4. Inside `semantic_hasher._config.type_handler_registry`: `"_ref": "python_type_semantic_hasher_registry"` → `"_ref": "python_type_handler_registry"` + +- [ ] **Step 2: Verify JSON is valid and context loads** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + uv run python -c "import json; json.load(open('src/orcapod/contexts/data/v0.1.json')); print('JSON OK')" + uv run python -c "from orcapod.contexts import get_default_context; ctx = get_default_context(); print('Context OK')" + ``` + Expected: `JSON OK` then `Context OK` + +--- + +## Task 10: Update `contexts/data/schemas/context_schema.json` + +**Files:** +- Modify: `src/orcapod/contexts/data/schemas/context_schema.json` + +- [ ] **Step 1: Apply renames** + + Changes needed: + - Property key `"python_type_semantic_hasher_registry"` → `"python_type_handler_registry"` (in `properties` section) + - Description string within that property: `"ObjectSpec for the PythonTypeSemanticHasherRegistry..."` → `"ObjectSpec for the PythonTypeHandlerRegistry..."` + - In the `examples` section: `"type_semantic_hasher_registry"` sub-key → `"type_handler_registry"`, and `"_ref": "python_type_semantic_hasher_registry"` → `"_ref": "python_type_handler_registry"` + +- [ ] **Step 2: Verify JSON is valid** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + uv run python -c "import json; json.load(open('src/orcapod/contexts/data/schemas/context_schema.json')); print('Schema JSON OK')" + ``` + Expected: `Schema JSON OK` + +--- + +## Task 11: Update test files + +**Files:** +- Modify: `tests/test_hashing/test_semantic_hasher.py` +- Modify: `tests/test_hashing/test_uuid_handler.py` +- Modify: `test-objective/unit/test_hashing.py` + +- [ ] **Step 1: Update `tests/test_hashing/test_semantic_hasher.py`** + + Changes needed: + - Import: `register_builtin_python_type_semantic_hashers` → `register_builtin_python_type_handlers` + - Import: `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry` + - Import: `get_default_python_type_semantic_hasher_registry` → `get_default_python_type_handler_registry` + - `make_hasher` body: `registry = PythonTypeSemanticHasherRegistry()` → `PythonTypeHandlerRegistry()`, `register_builtin_python_type_semantic_hashers(registry)` → `register_builtin_python_type_handlers(registry)`, `type_semantic_hasher_registry=registry` → `type_handler_registry=registry` + - All other usages of these names throughout the file (type annotations, variable names, docstrings, comments) + +- [ ] **Step 2: Update `tests/test_hashing/test_uuid_handler.py`** + + Changes needed: + - Import: `register_builtin_python_type_semantic_hashers` → `register_builtin_python_type_handlers` + - Import: `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry` + - `_make_hasher` body: same pattern as above + - `type_semantic_hasher_registry=registry` → `type_handler_registry=registry` + +- [ ] **Step 3: Update `test-objective/unit/test_hashing.py`** + + Changes needed (this file has many occurrences — all follow the same pattern): + - Imports: `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry`, `BuiltinPythonTypeSemanticHasherRegistry` → `BuiltinPythonTypeHandlerRegistry` + - All fixture/function type annotations: `PythonTypeSemanticHasherRegistry` → `PythonTypeHandlerRegistry` + - All constructor calls: `type_semantic_hasher_registry=registry` → `type_handler_registry=registry` + - All class names in test bodies: `PythonTypeSemanticHasherRegistry()` → `PythonTypeHandlerRegistry()` + - All `BuiltinPythonTypeSemanticHasherRegistry()` → `BuiltinPythonTypeHandlerRegistry()` + - All comments/docstrings mentioning old names + +- [ ] **Step 4: Verify test files parse** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + uv run python -m py_compile tests/test_hashing/test_semantic_hasher.py && echo "OK" + uv run python -m py_compile tests/test_hashing/test_uuid_handler.py && echo "OK" + uv run python -m py_compile test-objective/unit/test_hashing.py && echo "OK" + ``` + Expected: three `OK` lines + +--- + +## Task 12: Run tests and commit + +- [ ] **Step 1: Run hashing tests** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + uv run pytest tests/test_hashing/ -x -q + ``` + Expected: all tests pass + +- [ ] **Step 2: Run full test suite (excluding deleted semantic types)** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + uv run pytest tests/ -x -q --ignore=tests/test_semantic_types + ``` + Expected: all tests pass + +- [ ] **Step 3: Confirm no remaining old names in source** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + grep -rn "PathSemanticHasher\|UPathSemanticHasher\|UUIDSemanticHasher\|BytesSemanticHasher\|FunctionSemanticHasher\|TypeObjectSemanticHasher\|SpecialFormSemanticHasher\|GenericAliasSemanticHasher\|UnionTypeSemanticHasher\|ArrowTableSemanticHasher\|SchemaSemanticHasher\|PythonTypeSemanticHasherRegistry\|BuiltinPythonTypeSemanticHasherRegistry\|get_default_python_type_semantic_hasher_registry\|register_builtin_python_type_semantic_hashers\|type_semantic_hasher_registry" src/ tests/ test-objective/ --include="*.py" --include="*.json" | grep -v "^Binary" + ``` + Expected: no matches (zero lines) + +- [ ] **Step 4: Commit** + + ```bash + cd /home/kurouto/kurouto-jobs/dc15d84f-7281-48b5-9e17-435e9a04f175/orcapod-python + git add src/orcapod/hashing/semantic_hashing/builtin_handlers.py + git add src/orcapod/hashing/semantic_hashing/type_handler_registry.py + git add src/orcapod/hashing/semantic_hashing/semantic_hasher.py + git add src/orcapod/hashing/semantic_hashing/__init__.py + git add src/orcapod/hashing/__init__.py + git add src/orcapod/hashing/defaults.py + git add src/orcapod/hashing/versioned_hashers.py + git add src/orcapod/protocols/hashing_protocols.py + git add src/orcapod/contexts/data/v0.1.json + git add src/orcapod/contexts/data/schemas/context_schema.json + git add tests/test_hashing/test_semantic_hasher.py + git add tests/test_hashing/test_uuid_handler.py + git add test-objective/unit/test_hashing.py + git add superpowers/plans/2026-06-24-rename-semantic-hasher-to-handler.md + git commit -m "refactor(hashing): rename *SemanticHasher → *Handler, PythonTypeSemanticHasherRegistry → PythonTypeHandlerRegistry" + ``` diff --git a/superpowers/plans/2026-06-25-plt-1663-merge-extension-type-system-into-main.md b/superpowers/plans/2026-06-25-plt-1663-merge-extension-type-system-into-main.md new file mode 100644 index 00000000..12f8d3f3 --- /dev/null +++ b/superpowers/plans/2026-06-25-plt-1663-merge-extension-type-system-into-main.md @@ -0,0 +1,41 @@ +# PLT-1663: Merge extension-type-system → main + +## Overview + +The `extension-type-system` integration branch contains all work from PLT-1652 through +PLT-1660, PLT-1668, and PLT-1672. This plan covers the final steps to bring the branch +up-to-date with `main` and create the merge PR. + +## Situation + +- `extension-type-system` is 205 commits ahead of `main` +- It is 5 commits **behind** `main` (all PLT-1773: pyspiral `0.11.7 → 0.14.9` upgrade) +- The missing commits cause `spiral-integration` CI to fail (external service issue, + not a code bug — fixed by the pyspiral version bump on main) +- All other CI checks pass (unit tests 3.11/3.12, license check) +- Code audit: all old naming patterns removed from production code + - `ExtensionTypeConverter` — gone ✅ + - `ExtensionTypeRegistry` — gone ✅ + - `SemanticTypeRegistry` — only in v0.1.json changelog comment ✅ + - `BaseSemanticHasher` — only in v0.1.json changelog comment ✅ + - Shape-based code — only in explanatory comments ✅ + +## Steps + +1. **Rebase** `extension-type-system` onto `origin/main` + - Brings in 5 PLT-1773 commits (pyspiral fix + lock file updates) + - No conflicts expected (verified via dry-run) + - Will fix the `spiral-integration` CI failure + +2. **Force-push** `extension-type-system` to origin + - Required after rebase; targets feature branch only (not main) + +3. **Create PR** `extension-type-system` → `main` + - Comprehensive description listing all sub-issues resolved + - References PLT-1663 and all related issues (PLT-1652 through PLT-1660, PLT-1668, PLT-1672) + +## Success Criteria + +- CI passes on the updated `extension-type-system` branch +- PR is open and ready for review +- PR description references PLT-1663 and all sub-issues diff --git a/superpowers/specs/2026-06-14-extension-type-registry-design.md b/superpowers/specs/2026-06-14-extension-type-registry-design.md new file mode 100644 index 00000000..788b872f --- /dev/null +++ b/superpowers/specs/2026-06-14-extension-type-registry-design.md @@ -0,0 +1,277 @@ +# ExtensionTypeRegistry Design + +**Date:** 2026-06-14 +**Linear issue:** PLT-1653 +**Status:** Approved + +--- + +## Overview + +The `extension_types/` subpackage has a protocol (`ExtensionTypeConverter`) but no registry. +This spec adds `ExtensionTypeRegistry` — a class that maps `extension_name` strings to converter +instances and, as a side effect of each `register()` call, populates both PyArrow's and Polars' +process-global extension type registries so that columns using these types round-trip correctly +through Arrow IPC, Parquet, and Polars DataFrames. + +--- + +## Goals & Success Criteria + +- `ExtensionTypeRegistry.register(converter)` stores the converter and registers the extension + type in both PyArrow and Polars global registries in a single call. +- Registering a converter with a duplicate `extension_name` raises a clear `ValueError`. +- Converters are retrievable by `extension_name` (primary lookup) or `python_type` (secondary, + for the write path). Subclass relationships are honoured in the python-type lookup. +- A module-level `extension_type_registry` instance is created when + `orcapod.extension_types` is imported. It starts empty; PLT-1656 adds the built-in + registrations (`Path`, `UPath`, `UUID`). +- `pyproject.toml` is updated from `polars>=1.31.0` to `polars>=1.36.0`, the minimum version + that ships `pl.BaseExtension` and `pl.register_extension_type`. + +--- + +## Architecture + +### File map + +| File | Change | +|---|---| +| `pyproject.toml` | Update `polars>=1.31.0` → `polars>=1.36.0` | +| `src/orcapod/extension_types/registry.py` | **New** — `ExtensionTypeRegistry` class + private helpers | +| `src/orcapod/extension_types/__init__.py` | Export `ExtensionTypeRegistry`; create `extension_type_registry` | +| `tests/test_extension_types/test_registry.py` | **New** — unit and integration tests | + +--- + +## `registry.py` Module + +### Internal storage + +```python +self._by_name: dict[str, ExtensionTypeConverter] +self._by_python_type: dict[type, ExtensionTypeConverter] +``` + +Both dicts are populated together on every `register()` call. Neither has a reverse mapping +(no need to look up `extension_name` from `python_type` — that path is not required by this +issue). + +### Public API + +```python +class ExtensionTypeRegistry: + def register(self, converter: ExtensionTypeConverter) -> None + def get_converter_for_name(self, name: str) -> ExtensionTypeConverter | None + def get_converter_for_python_type(self, python_type: type) -> ExtensionTypeConverter | None + def has_extension_name(self, name: str) -> bool + def has_python_type(self, python_type: type) -> bool + def list_extension_names(self) -> list[str] + def list_python_types(self) -> list[type] +``` + +**`register(converter)`** — the only mutating method: + +1. Look up `converter.extension_name` in `_by_name`. If found, raise: + `ValueError: Extension type '{name}' is already registered.` +2. Store `_by_name[name] = converter` and `_by_python_type[converter.python_type] = converter`. +3. Call `_register_arrow_ext_type(converter)`. +4. Call `_register_polars_ext_type(converter)`. + +**`get_converter_for_python_type(python_type)`** — exact match first, then `issubclass` scan. +Returns the first registered type for which `issubclass(python_type, registered_type)` is true. +If multiple registered types are superclasses of `python_type`, the one encountered first in +insertion order wins (Python 3.7+ dict ordering). Returns `None` if nothing matches. + +All other public methods are straightforward dict lookups or list returns. + +### Module-level shadow dicts + +Two module-level dicts track what this code has registered in the process-global PA and Polars +registries: + +```python +_ARROW_REGISTRY: dict[str, tuple[pa.DataType, bytes]] = {} +# extension_name → (storage_type, metadata_bytes) + +_POLARS_REGISTRY: dict[str, tuple[pl.DataType, str | None]] = {} +# extension_name → (pl_storage_dtype, metadata_str) +``` + +These are the only mechanism used for equivalence checking. Neither PyArrow nor Polars exposes +a stable public API for looking up a previously registered extension type by name — PyArrow has +no such API at all, and Polars' `get_extension_type` is marked `@unstable()`. Maintaining our +own dicts avoids any dependency on library internals and keeps correctness entirely within our +control. + +Limitation: equivalence can only be verified for types registered via `ExtensionTypeRegistry`. +A type registered externally (directly via `pa.register_extension_type` or +`pl.register_extension_type`, bypassing our code) will not appear in these dicts. When a +subsequent `register()` call hits the library-level duplicate error for such a name, we raise +rather than silently continuing. This is intentional: without knowing what was registered +externally we cannot guarantee that the same extension name maps to the same Python type and +underlying storage type. Silently proceeding could cause silent data corruption or misrouted +conversions at read time. + +### Private helpers + +**`_register_arrow_ext_type(converter)`** + +1. Compute `metadata = converter.extension_metadata or b""` and `storage = converter.storage_type`. +2. If `converter.extension_name` is in `_ARROW_REGISTRY`: + - Compare `(existing_storage, existing_metadata)` with `(storage, metadata)` using `==`. + - Match → return immediately (idempotent; safe for module reload and test-suite reuse). + - Mismatch → raise `ValueError` showing both the existing and attempted values. +3. Dynamically create a `pa.ExtensionType` subclass via `type()`: + +```python +# Pseudocode for the dynamically created class +class _ArrowExt_(pa.ExtensionType): + def __init__(self): + pa.ExtensionType.__init__(self, storage, converter.extension_name) + + def __arrow_ext_serialize__(self) -> bytes: + return metadata # captured from converter at registration time + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + return cls() +``` + +4. Call `pa.register_extension_type(instance)`. If PyArrow raises `ArrowKeyError`, the name + was registered externally — re-raise as `ValueError` with a message explaining that the + name is already taken by an external registration and equivalence cannot be verified. +5. On success: `_ARROW_REGISTRY[name] = (storage, metadata)`. + +Name sanitization: replace all non-alphanumeric characters with `_` (e.g. +`pathlib.Path` → `_ArrowExt_pathlib_Path`). Cosmetic only — PyArrow identifies types by +`extension_name`, not by class name. + +**`_register_polars_ext_type(converter)`** + +1. Derive Polars storage dtype by converting an empty PA array: + +```python +pl_storage = pl.from_arrow(pa.array([], type=converter.storage_type)).dtype +``` + +This handles all Arrow → Polars mappings (`pa.large_utf8()` → `pl.String`, +`pa.binary(16)` → `pl.Binary`, `pa.struct(...)` → `pl.Struct({...})`) without a +manually-maintained table. + +2. Compute `metadata_str = converter.extension_metadata.decode("utf-8") if converter.extension_metadata else None`. +3. If `converter.extension_name` is in `_POLARS_REGISTRY`: + - Compare `(existing_pl_storage, existing_metadata_str)` with `(pl_storage, metadata_str)` using `==`. + - Match → return immediately (idempotent). + - Mismatch → raise `ValueError` showing both the existing and attempted values. +4. Dynamically create a `pl.BaseExtension` subclass via `type()`: + +```python +# Pseudocode for the dynamically created class +class _PolarsExt_(pl.BaseExtension): + def __init__(self): + super().__init__(converter.extension_name, pl_storage, metadata_str) + + @classmethod + def ext_from_params(cls, name, storage, metadata): + return cls() +``` + +5. Call `pl.register_extension_type(converter.extension_name, _PolarsExtType)`. If Polars raises + `ValueError` (name already registered externally) — re-raise as `ValueError` with the same + explanation as the PyArrow case. +6. On success: `_POLARS_REGISTRY[name] = (pl_storage, metadata_str)`. + +Note: `pl.BaseExtension` and `pl.register_extension_type` are marked `@unstable()` in Polars. +The `polars>=1.36.0` constraint is a forward commitment; if the API changes, `registry.py` is +the only place to update. + +--- + +## `__init__.py` + +```python +from .registry import ExtensionTypeRegistry + +extension_type_registry = ExtensionTypeRegistry() +# PLT-1656 adds: extension_type_registry.register(), etc. + +__all__ = ["ExtensionTypeRegistry", "extension_type_registry"] +``` + +The module-level `extension_type_registry` is the process default. It is not yet referenced by +`DataContext` (that wiring is PLT-1660). + +--- + +## `pyproject.toml` + +```toml +# Before +"polars>=1.31.0", + +# After +"polars>=1.36.0", +``` + +Polars 1.36.0 is the first release that exports `pl.BaseExtension` and +`pl.register_extension_type`. The currently installed version in CI is 1.41.2. + +--- + +## Error Handling + +| Situation | Behaviour | +|---|---| +| Duplicate `extension_name` in `register()` (same `ExtensionTypeRegistry` instance) | `ValueError` with the offending name | +| PA/Polars name in shadow dict, same params | Idempotent — return silently (safe for module reload and test-suite reuse) | +| PA/Polars name in shadow dict, different params | `ValueError` showing existing vs. attempted `storage_type` and `metadata` | +| PA/Polars name NOT in shadow dict but already in global registry (external registration) | `ValueError` — raised deliberately because we cannot guarantee the externally registered type maps to the same Python class and underlying storage type; silently proceeding risks data corruption or misrouted conversions at read time | +| `get_converter_for_name` / `get_converter_for_python_type` miss | Returns `None` | +| Non-`ExtensionTypeConverter` passed to `register()` | `beartype` raises `BeartypeCallHintParamViolation` at the call site | + +--- + +## Tests + +File: `tests/test_extension_types/test_registry.py` + +A `_StubConverter` factory (similar to the one in `test_protocols.py`) creates minimal +conforming `ExtensionTypeConverter` instances with `pa.large_utf8()` as `storage_type`. Each +test that touches the process-global PA/Polars registries uses a unique `extension_name` to +avoid cross-test interference (since those globals persist for the process lifetime). + +| Test | What it verifies | +|---|---| +| `test_register_stores_converter` | `get_converter_for_name` returns the converter after `register()` | +| `test_register_populates_arrow_registry` | After `register()`, attempting to re-register the same name with PyArrow raises `pa.lib.ArrowKeyError` (proving it is registered) | +| `test_register_populates_polars_registry` | After `register()`, `pl.from_arrow(pa.array([...], type=ext_type_instance)).dtype` is a `pl.BaseExtension` instance | +| `test_register_duplicate_raises` | Second `register()` on the same registry instance with same `extension_name` → `ValueError` | +| `test_register_global_collision_same_params` | Fresh registry instance registers same name+params as a previous registry → idempotent (no error) | +| `test_register_global_collision_different_params` | Fresh registry instance registers same name but different `storage_type` → `ValueError` with both old and new params shown | +| `test_get_converter_for_name_miss` | Unknown name returns `None` | +| `test_get_converter_for_python_type_exact` | Exact type lookup returns converter | +| `test_get_converter_for_python_type_subclass` | Subclass of registered type returns converter | +| `test_get_converter_for_python_type_miss` | Unrelated type returns `None` | +| `test_has_extension_name` | Returns `True` after register, `False` before | +| `test_has_python_type` | Returns `True` after register, `False` before | +| `test_list_extension_names` | Returns correct list of registered names | +| `test_list_python_types` | Returns correct list of registered types | +| `test_python_class_round_trip` | A concrete Python class (e.g., a `Color` wrapper around a hex string) is serialised to an Arrow extension array via `converter.python_to_storage`, then deserialised back via `converter.storage_to_python`; the recovered objects equal the originals. Exercises the full converter contract end-to-end. | +| `test_arrow_polars_round_trip` | PA ext array → `pl.from_arrow` → `to_arrow()` preserves extension type and values | +| `test_parquet_round_trip` | PA ext array written to Parquet, read back via `pq.read_table` — extension type restored, `storage_to_python` recovers original Python objects | +| `test_extension_type_registry_module_instance` | `extension_types.extension_type_registry` is an `ExtensionTypeRegistry` instance and starts empty | + +--- + +## Out of Scope + +- Registering built-in converters (`Path`, `UPath`, `UUID`) — that is PLT-1656. +- Wiring `extension_type_registry` into `DataContext` — that is PLT-1660. +- Schema analysis helpers (finding extension-type columns in a schema) — not needed until PLT-1660. +- Thread safety — registration is expected to happen at import time before any concurrent I/O. +- Interop with extension types registered externally by third-party libraries (e.g., GeoPandas, + GeoArrow) — tracked in PLT-1665. The current design deliberately errors on external + registrations because we cannot guarantee the same name maps to the same Python class and + storage type; a future `register_external` opt-in will require the user to supply an explicit + converter. diff --git a/superpowers/specs/2026-06-14-plt-1654-schema-walker-design.md b/superpowers/specs/2026-06-14-plt-1654-schema-walker-design.md new file mode 100644 index 00000000..d081f85b --- /dev/null +++ b/superpowers/specs/2026-06-14-plt-1654-schema-walker-design.md @@ -0,0 +1,202 @@ +# PLT-1654: Recursive Arrow Schema Walker Design + +**Date:** 2026-06-14 +**Linear issue:** PLT-1654 +**Status:** Approved + +--- + +## Overview + +Add `src/orcapod/extension_types/schema_walker.py` — a pure discovery utility that +walks an Arrow schema (or a single field) recursively and returns all extension-typed +fields found at any depth of nesting (struct, list, map, etc.). + +This is the third piece of the `extension_types/` subpackage, sitting between +`registry.py` (PLT-1653) and the database peek-schema helper (PLT-1655). It produces the +`(extension_name, extension_metadata, storage_type)` information that PLT-1655 feeds into +`ExtensionTypeRegistry` at read time. + +**Strictly additive.** No existing code is modified. This aligns with the project-wide +parallel-build strategy: old semantic type code is untouched until PLT-1660 (the hard cut). + +--- + +## Goals & Success Criteria + +- `walk_schema(schema)` returns all extension types found in a `pa.Schema` at any depth, + deduplicated by `(extension_name, extension_metadata)`. +- `walk_field(field)` does the same for a single `pa.Field`. +- Both channels are handled: in-memory `pa.ExtensionType` instances + (`isinstance(field.type, pa.ExtensionType)` — no global registration required) and + field-metadata types (raw `ARROW:extension:name` field metadata after a Parquet/IPC + round-trip). +- All container nesting cases work: top-level column, list value, struct field, map + key/value, and arbitrary combinations thereof. +- Empty bytes `b""` from `__arrow_ext_serialize__()` is normalised to `None` so callers + never see an empty-bytes sentinel. +- No registration triggered — purely inspection. +- Works on `DeltaTable.schema().to_arrow()` output. + +--- + +## Scope & Boundaries + +In scope: +- New `src/orcapod/extension_types/schema_walker.py` +- Additive exports in `src/orcapod/extension_types/__init__.py` +- New `tests/test_extension_types/test_schema_walker.py` + +Out of scope: +- Database read path changes (PLT-1655) +- Built-in converter registrations (PLT-1656) +- Any modification to existing `semantic_types/` code +- Thread safety (registration is import-time, before concurrent I/O) + +--- + +## Architecture + +### File map + +| File | Change | +|---|---| +| `src/orcapod/extension_types/schema_walker.py` | **New** | +| `src/orcapod/extension_types/__init__.py` | Additive — new exports appended | +| `tests/test_extension_types/test_schema_walker.py` | **New** | + +No other files are touched. + +--- + +## `schema_walker.py` + +### `ExtensionTypeInfo` data container + +```python +@dataclasses.dataclass(frozen=True) +class ExtensionTypeInfo: + extension_name: str + extension_metadata: bytes | None + storage_type: pa.DataType +``` + +A frozen dataclass (not a NamedTuple): immutable, hashable, attribute access only. +`b""` is normalised to `None` at construction time — no caller ever sees an +empty-bytes metadata value. + +### Public API + +```python +def walk_schema(schema: pa.Schema) -> list[ExtensionTypeInfo]: ... +def walk_field(field: pa.Field) -> list[ExtensionTypeInfo]: ... +``` + +Both return a deduplicated list in depth-first, first-seen order. The deduplication key +is `(extension_name, extension_metadata)`. When the same pair appears in multiple +columns, only the first occurrence (and its `storage_type`) is kept. + +### Internal helpers + +**`_collect(field, seen, results)`** — the recursive core. Mutates `seen` (a +`set[tuple[str, bytes | None]]`) and `results` (a `list[ExtensionTypeInfo]`) in place: + +1. Call `_detect_extension(field)`. If it returns an `ExtensionTypeInfo`: + - Add to `results` if `(extension_name, extension_metadata)` is not in `seen`. + - Update `seen`. + - **Return immediately** — do not descend into the storage type. +2. Otherwise inspect `field.type` and recurse: + - `is_struct` → `t.field(i)` for each `i` in `range(t.num_fields)` + - `is_list` / `is_large_list` / `is_fixed_size_list` / `is_list_view` / + `is_large_list_view` → `t.value_field` + - `is_map` → `t.key_field` and `t.item_field` (via `getattr`; available in + PyArrow ≥ 14, project requires ≥ 20) + - Primitives and unrecognised types → no-op + +**`_detect_extension(field) -> ExtensionTypeInfo | None`** — detects whether a field +carries extension type information via either channel: + +**Channel 1 — In-memory ExtensionType** (`isinstance(field.type, pa.ExtensionType)` is True): + +A `pa.ExtensionType` instance is attached to the field. No global registry registration +is required — this branch fires for any `pa.ExtensionType` subclass instance, whether +registered or not. The type object carries everything: + +```python +ext_type = field.type +raw_meta = ext_type.__arrow_ext_serialize__() +return ExtensionTypeInfo( + extension_name=ext_type.extension_name, + extension_metadata=raw_meta or None, + storage_type=ext_type.storage_type, +) +``` + +**Channel 2 — Unregistered** (`field.metadata` contains `b"ARROW:extension:name"`): + +The type was registered elsewhere and survived a Parquet/IPC round-trip. The raw +`field.type` is the storage type; name and metadata are in the field's Arrow metadata: + +```python +name = field.metadata[b"ARROW:extension:name"].decode("utf-8") +raw_meta = field.metadata.get(b"ARROW:extension:metadata") +return ExtensionTypeInfo( + extension_name=name, + extension_metadata=raw_meta or None, + storage_type=field.type, +) +``` + +Channel 1 is checked first. `None` is returned if neither applies. + +--- + +## `__init__.py` additions + +```python +from .schema_walker import ExtensionTypeInfo, walk_field, walk_schema + +__all__ = [ + "ExtensionTypeConverter", + "ExtensionTypeRegistry", + "default_extension_type_registry", + # PLT-1654 + "ExtensionTypeInfo", + "walk_schema", + "walk_field", +] +``` + +--- + +## Tests — `tests/test_extension_types/test_schema_walker.py` + +Uses `_unique_name()` and `_make_reg_field()` / `_make_unreg_field()` helpers. +In-memory-channel tests construct `pa.ExtensionType` subclass instances directly via +`type()` and attach them to a `pa.Field` — **no global registration** is performed. +Field-metadata-channel tests construct `pa.Field` objects with explicit +`metadata={b"ARROW:extension:name": ..., b"ARROW:extension:metadata": ...}`. + +| Test | What it covers | +|---|---| +| `test_empty_schema` | Empty schema → `[]` | +| `test_no_extension_types` | Schema with only primitives → `[]` | +| `test_top_level_registered` | Registered ext type as top-level column | +| `test_top_level_unregistered` | Unregistered ext type via raw field metadata | +| `test_list_of_registered` | Registered ext type as list value field | +| `test_list_of_unregistered` | Unregistered ext type as list value field | +| `test_struct_containing_registered` | Registered ext type inside a struct field | +| `test_struct_containing_unregistered` | Unregistered ext type inside a struct field | +| `test_nested_list_struct` | `list>>` — arbitrary nesting | +| `test_deduplication` | Same `(name, metadata)` in two columns → one result | +| `test_empty_metadata_normalised_to_none` | `b""` from `__arrow_ext_serialize__` → `None` | +| `test_walk_field` | `walk_field` on a single field returns correct result | +| `test_map_type` | Extension type as map item value | + +--- + +## PLT-1660 cleanup items (deferred) + +- Remove `SemanticTypeRegistry.find_semantic_fields_in_schema` (shape-based — replaced by + `walk_schema`). +- Remove `SemanticTypeRegistry.get_semantic_field_info` (shape-based — same fate). diff --git a/superpowers/specs/2026-06-14-plt-1655-database-hooks-design.md b/superpowers/specs/2026-06-14-plt-1655-database-hooks-design.md new file mode 100644 index 00000000..b4b680a9 --- /dev/null +++ b/superpowers/specs/2026-06-14-plt-1655-database-hooks-design.md @@ -0,0 +1,387 @@ +# PLT-1655: Peek-Schema → Register → Read Pattern with Per-Process Cache + +**Date:** 2026-06-14 +**Linear issue:** PLT-1655 +**Status:** Implemented + +> **Implementation note (2026-06-15):** During implementation the design was +> refined: rather than wiring hooks directly into `DeltaTableDatabase` and +> `ConnectorArrowDatabase`, a dedicated `ExtensionAwareDatabase` wrapper was +> introduced. Database classes remain pure storage; the wrapper applies +> `register_discovered_extensions` + `apply_extension_types` on every read. +> The table below documents the actual shipped API. + +--- + +## Overview + +Two complementary utilities in `extension_types/database_hooks.py` handle +extension type awareness at database read time: + +1. **`register_discovered_extensions(registry, schema)`** — walks the schema and + registers any unknown extension types via the registry's factory dispatch. No-op + when `registry` is `None` or the schema contains no extension types. Repeated reads + are cheap: already-registered types are detected and skipped inside the registry. + +2. **`apply_extension_types(table, registry)`** — Arrow preserves + `ARROW:extension:name` / `ARROW:extension:metadata` field metadata even when an + extension type is not registered at read time (columns load as plain storage types). + After registration, this function re-wraps those storage columns into their correct + Arrow extension types using `pa.ExtensionArray.from_storage` per chunk — zero-copy + and no data movement. Struct columns are handled recursively. + +Callers use these through the **`ExtensionAwareDatabase`** wrapper, which applies +both steps on every read result automatically. + +--- + +## Goals & Success Criteria + +* `register_discovered_extensions(registry, schema)` in + `extension_types/database_hooks.py` correctly discovers all extension type fields + at any nesting depth and delegates registration to the registry. +* `apply_extension_types(table, registry)` correctly re-wraps storage columns into + extension types per-chunk without data copies; preserves schema-level metadata; + handles struct columns recursively; skips structs with no extension children. +* When the schema contains no extension types both calls are no-ops; existing tests + continue to pass unchanged. +* For each extension type found in the schema, `ensure_extension_type` applies checks + in this order: + 1. **Already registered** (by Arrow extension name) → silent no-op. This is the + common fast path for all types after first registration. + Metadata value is irrelevant — `None` metadata on an already-registered type + never causes an error. + 2. **Not registered, non-`None` metadata, matching factory** → factory constructs a + `LogicalTypeProtocol` and it is registered in PyArrow, Polars, and the registry + before the table is returned. + 3. **Not registered, non-`None` metadata, no matching factory** → clear `ValueError` + naming the extension name and metadata tag, with a pointer to + `register_logical_type_factory`. + 4. **Not registered, `None` metadata** → clear `ValueError` explaining that types + without a category tag cannot be auto-registered via a factory and must be + pre-registered explicitly via `registry.register_logical_type(logical_type)`. +* `ExtensionAwareDatabase` correctly wraps any `ArrowDatabaseProtocol` backend, + applies both steps on every read, and passes writes through unchanged. +* Sufficient `DEBUG`-level logging throughout so that extension type discovery, + registration decisions, and factory dispatch are observable without code changes. + +--- + +## Scope & Boundaries + +In scope: +* New `src/orcapod/extension_types/database_hooks.py` + — `register_discovered_extensions` and `apply_extension_types` +* New `src/orcapod/databases/extension_aware_database.py` — `ExtensionAwareDatabase` +* New `LogicalTypeFactoryProtocol` Protocol in + `src/orcapod/extension_types/protocols.py` +* New methods on `LogicalTypeRegistry` (`registry.py`): + `register_logical_type_factory` and `ensure_extension_type` +* Additive exports in `src/orcapod/extension_types/__init__.py` +* Tests for all new code + +Out of scope (database classes are pure storage, unchanged): +* `src/orcapod/databases/delta_lake_databases.py` — no extension type hooks +* `src/orcapod/databases/connector_arrow_database.py` — no extension type hooks +* Implementing concrete `LogicalTypeFactoryProtocol` instances (PLT-1657 + `dataclass_handler`, PLT-1658 `picklable_handler`) +* Built-in logical type registrations (PLT-1656) +* Thread safety of the global registry dicts (deferred) +* Any change to `semantic_types/` (old system, untouched until PLT-1660) + +--- + +## Architecture + +### File map + +| File | Change | +|---|---| +| `src/orcapod/extension_types/protocols.py` | Add `LogicalTypeFactoryProtocol` Protocol | +| `src/orcapod/extension_types/registry.py` | Add `register_logical_type_factory`, `ensure_extension_type` | +| `src/orcapod/extension_types/database_hooks.py` | **New** — `register_discovered_extensions`, `apply_extension_types` | +| `src/orcapod/extension_types/__init__.py` | Additive exports | +| `src/orcapod/databases/extension_aware_database.py` | **New** — `ExtensionAwareDatabase` wrapper | +| `tests/test_extension_types/test_database_hooks.py` | **New** | +| `tests/test_databases/test_extension_aware_database.py` | **New** | + +--- + +## `LogicalTypeFactoryProtocol` Protocol + +**Location:** `src/orcapod/extension_types/protocols.py` + +`LogicalTypeFactoryProtocol` is a pure factory. Given an Arrow extension name, its +storage type, and the full parsed metadata dict, it constructs a fully-formed +`LogicalTypeProtocol` instance ready to pass to +`LogicalTypeRegistry.register_logical_type()`. + +The `category` string that routes to this factory is declared by the caller at +registration time — the factory itself has no knowledge of its dispatch key, but +receives the full metadata dict so it can read additional hints (e.g. version, +serialisation format) beyond just the category. + +### Metadata format + +`extension_metadata` bytes are expected to be **UTF-8-encoded JSON** with at least a +`"category"` key: + +```json +{"category": "Dataclass"} +{"category": "Pickle", "protocol": 5} +{"category": "Pydantic", "pydantic_version": 2} +``` + +The `category` value is the factory dispatch key. All other fields are passed through +to the factory as-is and interpreted by the factory implementation. + +### Protocol definition + +```python +class LogicalTypeFactoryProtocol(Protocol): + def create_logical_type( + self, + arrow_extension_name: str, + storage_type: pa.DataType, + metadata: dict, + ) -> LogicalTypeProtocol: + """Construct a ``LogicalTypeProtocol`` for the given Arrow extension name. + + Args: + arrow_extension_name: The Arrow extension type name extracted from the + schema (i.e. the value of ``ARROW:extension:name`` field metadata). + storage_type: The underlying Arrow storage type for this extension field. + metadata: The full parsed JSON metadata dict. Always contains at least a + ``"category"`` key. May contain additional keys the factory uses. + + Returns: + A fully constructed ``LogicalTypeProtocol`` ready to be passed to + ``LogicalTypeRegistry.register_logical_type()``. + + Raises: + ValueError: If this factory cannot construct a logical type for the given + extension name. + """ + ... +``` + +This protocol is `@runtime_checkable`, consistent with `LogicalTypeProtocol`. + +--- + +## `LogicalTypeRegistry` additions + +**Location:** `src/orcapod/extension_types/registry.py` + +Two new methods are added to `LogicalTypeRegistry`. The existing public API is +unchanged. + +### `register_logical_type_factory` + +```python +def register_logical_type_factory( + self, + category: str, + factory: LogicalTypeFactoryProtocol, +) -> None: + """Register a factory for the given metadata category string. + + When ``ensure_extension_type`` encounters an Arrow extension type whose + ``extension_metadata`` JSON contains ``{"category": "", ...}``, + it calls ``factory.create_logical_type(arrow_extension_name, storage_type, + metadata_dict)`` to construct the logical type and then registers it. + + Args: + category: The ``"category"`` value from the extension metadata JSON. + factory: A ``LogicalTypeFactoryProtocol`` instance responsible for + constructing logical types for this category. + + Raises: + ValueError: If ``category`` is already registered to a different factory. + """ +``` + +### `ensure_extension_type` + +```python +def ensure_extension_type( + self, + arrow_extension_name: str, + extension_metadata: bytes | None, + storage_type: pa.DataType, +) -> None: + """Ensure the Arrow extension type identified by ``arrow_extension_name`` + is registered as a ``LogicalTypeProtocol``. + + This is the single entry point called by ``register_discovered_extensions`` + in ``database_hooks``. The registry owns all dispatch logic: + + 1. Already registered → return immediately (per-process cache hit). + 2. ``extension_metadata`` is ``None`` → ``ValueError``. + 3. Decode metadata as UTF-8 JSON → ``ValueError`` on failure. + 4. Extract ``"category"`` key → ``ValueError`` if absent. + 5. Look up factory by category → ``ValueError`` if not found. + 6. Call factory.create_logical_type(...) → ``LogicalTypeProtocol``. + 7. Call self.register_logical_type(logical_type). + """ +``` + +Error messages direct callers to use `registry.register_logical_type(logical_type)` or +`registry.register_logical_type_factory(category, factory)` on the registry instance +used for reads — no references to any module-level singleton. + +--- + +## `database_hooks.py` + +**Location:** `src/orcapod/extension_types/database_hooks.py` + +### `register_discovered_extensions` + +```python +def register_discovered_extensions( + registry: LogicalTypeRegistry | None, + schema: pa.Schema, +) -> None: + """Register any extension types found in ``schema`` that are not yet known. + + Walks ``schema`` recursively; for each discovered type calls + ``registry.ensure_extension_type``. No-op when ``registry`` is ``None`` + or the schema has no extension types. + """ +``` + +This function is intentionally stateless and contains no dispatch logic. + +### `apply_extension_types` + +```python +def apply_extension_types( + table: pa.Table, + registry: LogicalTypeRegistry, +) -> pa.Table: + """Re-wrap *table* columns into their registered Arrow extension types. + + Arrow preserves ``ARROW:extension:name`` / ``ARROW:extension:metadata`` + field metadata even when an extension type was not registered at read time. + Once registered, this function reconstructs extension-typed columns from + storage using ``pa.ExtensionArray.from_storage`` per chunk (zero-copy). + Struct columns are handled recursively; structs with no extension children + are skipped entirely. + + Returns the original table unchanged when no columns need re-wrapping. + Schema-level metadata is preserved on the rebuilt table. + """ +``` + +--- + +## `ExtensionAwareDatabase` wrapper + +**Location:** `src/orcapod/databases/extension_aware_database.py` + +```python +class ExtensionAwareDatabase: + """ArrowDatabaseProtocol wrapper that auto-registers and applies extension types. + + Takes any ArrowDatabaseProtocol backend and a LogicalTypeRegistry. Every + read result flows through: + 1. register_discovered_extensions(registry, table.schema) + 2. apply_extension_types(table, registry) + + Write methods and structural methods (at, flush, base_path) delegate + directly to the wrapped database without modification. + """ + + def __init__(self, db: ArrowDatabaseProtocol, registry: LogicalTypeRegistry) -> None: ... + def at(self, *path_components: str) -> ExtensionAwareDatabase: ... + # All ArrowDatabaseProtocol read/write methods delegated +``` + +Database classes (`DeltaTableDatabase`, `ConnectorArrowDatabase`) remain pure +storage with no extension type awareness. Callers that need extension type handling +wrap their database explicitly: + +```python +db = DeltaTableDatabase("/path/to/store") +ext_db = ExtensionAwareDatabase(db, registry=data_context.logical_type_registry) +table = ext_db.get_all_records(("results", "my_fn")) +# table columns have proper extension types applied +``` + +--- + +## Per-process cache design + +The per-process cache is `LogicalTypeRegistry._by_arrow_name`. The first call to +`ensure_extension_type` for a given `arrow_extension_name` performs factory dispatch +and registers the `LogicalTypeProtocol`. Every subsequent call for the same name hits +the `get_by_arrow_extension_name` check and returns immediately. + +Because the registry instance lives for the process lifetime (typically as +`data_context.logical_type_registry`), this provides exactly the per-process caching +semantics described in PLT-1655. No separate `set` is needed in `database_hooks.py` +— the registry is the cache. + +--- + +## Logging summary + +| Location | Level | Message | +|---|---|---| +| `database_hooks.register_discovered_extensions` | DEBUG | No extension types found in schema | +| `database_hooks.register_discovered_extensions` | DEBUG | N extension types found, lists names | +| `database_hooks.apply_extension_types` | DEBUG | Wrapped column X as extension type Y | +| `registry.ensure_extension_type` | DEBUG | Already registered — skipping | +| `registry.ensure_extension_type` | DEBUG | Not registered — dispatching to category factory | +| `registry.ensure_extension_type` | DEBUG | Successfully registered via factory for category | +| `registry.register_logical_type_factory` | DEBUG | Factory registered for category string | + +All messages use `%r`/`%s` lazy formatting (no f-strings in log calls). + +--- + +## Tests + +**`tests/test_extension_types/test_database_hooks.py`** + +| Test | What it covers | +|---|---| +| `test_no_extension_types_is_noop` | Schema with only primitives — `register_discovered_extensions` returns without touching registry | +| `test_known_type_is_registered` | Schema with one extension type whose factory is registered — logical type registered | +| `test_already_registered_is_skipped` | Call `register_discovered_extensions` twice — second call is no-op | +| `test_unknown_metadata_raises` | Unregistered extension type with valid JSON metadata but no matching factory — `ValueError` | +| `test_metadata_not_json_raises` | Unregistered type with non-JSON metadata — `ValueError` with raw bytes | +| `test_metadata_json_missing_category_raises` | Valid JSON but no `"category"` key — `ValueError` | +| `test_none_metadata_not_registered_raises` | `None` metadata on unregistered type — `ValueError` | +| `test_none_metadata_already_registered_noop` | `None` metadata on already-registered type — silent no-op | +| `test_nested_extension_type` | Extension type inside a struct column — walker descends and registers it | +| `test_noop_when_no_extension_metadata` | `apply_extension_types`: plain-types table returned as-is (same object) | +| `test_wraps_storage_column_into_extension_type` | `apply_extension_types`: storage column with metadata re-wrapped | +| `test_zero_copy_single_chunk` | `apply_extension_types`: from_storage shares the underlying buffer | +| `test_zero_copy_multiple_chunks` | `apply_extension_types`: multi-chunk columns wrapped per-chunk | +| `test_already_extension_type_passthrough` | Column already extension-typed returned as-is | +| `test_unregistered_extension_metadata_left_as_storage` | Unregistered ext metadata column stays as storage type | +| `test_nested_struct_extension_type` | Extension type inside struct child field reconstructed recursively | +| `test_mixed_columns_only_ext_columns_changed` | Plain columns untouched when an extension column is processed | + +**`tests/test_databases/test_extension_aware_database.py`** + +| Test | What it covers | +|---|---| +| `test_get_all_records_applies_extension_types` | Wrapper applies extension types on `get_all_records` | +| `test_get_record_by_id_applies_extension_types` | Wrapper applies extension types on `get_record_by_id` | +| `test_get_records_by_ids_applies_extension_types` | Wrapper applies extension types on `get_records_by_ids` | +| `test_get_all_records_returns_none_when_no_records` | Returns `None` when inner DB has no records | +| `test_write_methods_passthrough` | `add_record` / `add_records` write correctly through wrapper | +| `test_at_returns_extension_aware_database` | `at()` returns `ExtensionAwareDatabase` with same registry | +| `test_base_path_delegates_to_inner` | `base_path` reflects inner database's `base_path` | +| `test_plain_table_passthrough_unchanged` | Tables with no extension metadata returned as-is | + +--- + +## Dependencies + +* PLT-1653 (`ExtensionTypeRegistry` → `LogicalTypeRegistry`) — **merged** +* PLT-1654 (`schema_walker`) — **merged** +* PLT-1668 (`LogicalTypeProtocol` / `LogicalTypeRegistry` redesign) — **merged** (unblocked) diff --git a/superpowers/specs/2026-06-14-plt-1656-builtin-logical-types-design.md b/superpowers/specs/2026-06-14-plt-1656-builtin-logical-types-design.md new file mode 100644 index 00000000..7d4843f6 --- /dev/null +++ b/superpowers/specs/2026-06-14-plt-1656-builtin-logical-types-design.md @@ -0,0 +1,308 @@ +# PLT-1656: Built-in LogicalType Implementations (Path, UPath, UUID) + +**Date:** 2026-06-14 +**Issue:** PLT-1656 +**Depends on:** PLT-1668 (LogicalType protocol + LogicalTypeRegistry — completed) + +--- + +## Overview + +Implement the three built-in `LogicalType` instances (`LogicalPath`, +`LogicalUPath`, `LogicalUUID`) in a new module +`src/orcapod/extension_types/builtin_logical_types.py`. + +Wire the default registry into `DataContext` via `v0.1.json` using the existing +`parse_objectspec()` JSON object spec mechanism — exactly as `semantic_registry`, +`type_converter`, and the other default objects are built. The primary access path +for the default registry is `get_default_context().logical_type_registry`, with a +`get_default_logical_type_registry()` convenience function added to `contexts`. + +These are the first concrete implementations of the `LogicalType` protocol +introduced by PLT-1668. The naming convention is `LogicalXXX` (no "Type" suffix): +`LogicalType` is the abstract protocol; `LogicalPath`, `LogicalUPath`, `LogicalUUID` +are the concrete descriptors. The old `PythonPathStructConverter`, +`UPathStructConverter`, and `UUIDStructConverter` in +`semantic_types/semantic_struct_converters.py` remain untouched until PLT-1660 +(hard cut). + +--- + +## New file: `src/orcapod/extension_types/builtin_logical_types.py` + +### `LogicalPath` + +| Property / Method | Value | +|---|---| +| `logical_type_name` | `"pathlib.Path"` | +| `python_type` | `pathlib.Path` | +| Arrow extension name | `"pathlib.Path"` (custom — created via `make_arrow_extension_type`) | +| Arrow storage type | `pa.large_string()` | +| Arrow extension metadata | `b"orcapod.builtin"` | +| `python_to_storage(path)` | `str(path)` | +| `storage_to_python(s)` | `Path(s)` | + +`get_arrow_extension_type()` uses +`make_arrow_extension_type("pathlib.Path", pa.large_string(), b"orcapod.builtin")` +to obtain the class (called once), then returns a cached instance. + +`get_polars_extension_type()` uses +`make_polars_extension_type("pathlib.Path", pa.large_string(), "orcapod.builtin")` +similarly. + +### `LogicalUPath` + +Identical structure to `LogicalPath` with: + +| Property / Method | Value | +|---|---| +| `logical_type_name` | `"upath.UPath"` | +| `python_type` | `upath.UPath` | +| Arrow extension name | `"upath.UPath"` | +| `python_to_storage(upath)` | `str(upath)` | +| `storage_to_python(s)` | `UPath(s)` | + +### `LogicalUUID` + +| Property / Method | Value | +|---|---| +| `logical_type_name` | `"uuid.UUID"` | +| `python_type` | `uuid.UUID` | +| Arrow extension name | `"uuid.UUID"` (custom — created via `make_arrow_extension_type`) | +| Arrow storage type | `pa.large_binary()` | +| Arrow extension metadata | `None` (empty bytes) | +| `python_to_storage(uuid_val)` | `uuid_val.bytes` | +| `storage_to_python(bytes_val)` | `uuid.UUID(bytes=bytes(bytes_val))` | + +`get_arrow_extension_type()` uses +`make_arrow_extension_type("uuid.UUID", pa.large_binary())`, following the +same pattern as `LogicalPath` and `LogicalUPath`. `logical_type_name` and the +Arrow extension name are both `"uuid.UUID"`. + +`pa.large_binary()` is used rather than `pa.binary(16)` (fixed-size) because +Polars maps fixed-size binary to variable-length on the round-trip, which +would conflict with the deserializer's storage-type check. + +PyArrow's built-in `pa.uuid()` (`"arrow.uuid"`) is intentionally **not** used: +it is a C++ built-in type (`UuidType(BaseExtensionType)`) that Polars has +hardcoded in its Rust layer at startup and cannot be overridden from Python, +causing Arrow → Polars → Arrow round-trips to silently strip the extension. + +`get_polars_extension_type()` uses +`make_polars_extension_type("uuid.UUID", pa.large_binary())`. + +### Caching strategy + +Each class caches its Arrow and Polars extension type instances as class-level +attributes to avoid re-creating dynamic subclasses on every `get_*` call: + +```python +class LogicalPath: + _arrow_ext_class = make_arrow_extension_type("pathlib.Path", pa.large_string()) + _arrow_ext: pa.ExtensionType | None = None + + def get_arrow_extension_type(self) -> pa.ExtensionType: + if LogicalPath._arrow_ext is None: + LogicalPath._arrow_ext = LogicalPath._arrow_ext_class() + return LogicalPath._arrow_ext +``` + +Imports inside `builtin_logical_types.py` must use direct submodule paths +(e.g. `from orcapod.extension_types.registry import make_arrow_extension_type`), +not the package `__init__` (`from orcapod.extension_types import ...`), to avoid +a circular import when the context system loads this module. + +--- + +## New helper: `make_polars_extension_type` in `registry.py` + +Add alongside the existing `make_arrow_extension_type`: + +```python +def make_polars_extension_type( + extension_name: str, + arrow_storage_type: pa.DataType, + metadata: str | None = None, +) -> type[pl.BaseExtension]: + """Synthesise and return a ``pl.BaseExtension`` subclass. + + Derives the Polars storage dtype from *arrow_storage_type* via + ``pl.from_arrow``. Returns the *class*; callers instantiate it inside + ``get_polars_extension_type()``. + """ +``` + +Polars dtype is computed once via +`pl.from_arrow(pa.array([], type=arrow_storage_type)).dtype` and captured +in the closure, mirroring the `make_arrow_extension_type` pattern. + +Export `make_polars_extension_type` from `__init__.py` alongside +`make_arrow_extension_type`. + +--- + +## `LogicalTypeRegistry.__init__` — add `logical_types` parameter + +Small backward-compatible addition to `registry.py` so that +`parse_objectspec()` can populate the registry via `_config`: + +```python +def __init__(self, logical_types: list[LogicalType] | None = None) -> None: + self._by_logical_name: dict[str, LogicalType] = {} + self._by_arrow_name: dict[str, LogicalType] = {} + self._by_python_type: dict[type, LogicalType] = {} + for lt in (logical_types or []): + self.register(lt) +``` + +Same pattern as `SemanticTypeRegistry`'s `converters` constructor argument. + +--- + +## `DataContext` — add `logical_type_registry` field + +In `src/orcapod/contexts/core.py`, add field to the `DataContext` dataclass: + +```python +from orcapod.extension_types.registry import LogicalTypeRegistry + +@dataclass +class DataContext: + ... + logical_type_registry: LogicalTypeRegistry +``` + +--- + +## `v0.1.json` — add `logical_type_registry` entry + +Add before the `"metadata"` key: + +```json +"logical_type_registry": { + "_class": "orcapod.extension_types.registry.LogicalTypeRegistry", + "_config": { + "logical_types": [ + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalPath", + "_config": {} + }, + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalUPath", + "_config": {} + }, + { + "_class": "orcapod.extension_types.builtin_logical_types.LogicalUUID", + "_config": {} + } + ] + } +} +``` + +--- + +## `contexts/__init__.py` — add convenience accessor + +Add alongside `get_default_type_converter()`: + +```python +def get_default_logical_type_registry() -> LogicalTypeRegistry: + """Get the default logical type registry. + + Returns: + ``LogicalTypeRegistry`` instance from the default context. + """ + return get_default_context().logical_type_registry +``` + +Add to `__all__`. + +--- + +## `context_schema.json` — add `logical_type_registry` + +Add `"logical_type_registry"` to the required/allowed fields in +`src/orcapod/contexts/data/schemas/context_schema.json`. + +--- + +## `extension_types/__init__.py` — remove standalone default registry + +**Remove** the line `default_logical_type_registry = LogicalTypeRegistry()`. + +The standard access paths are now: +- `get_default_context().logical_type_registry` +- `get_default_logical_type_registry()` (from `orcapod.contexts`) + +Removing the module-level variable avoids a circular import: if `__init__.py` +called `get_default_context()` at import time, it would force-eager-load all +context components (file hasher, semantic registry, arrow hasher, etc.) whenever +`orcapod.extension_types` is imported. + +Update `__all__` accordingly. + +--- + +## Tests: `tests/test_extension_types/test_builtin_logical_types.py` + +### Protocol conformance +- `isinstance(LogicalPath(), LogicalType)` → `True` (and `LogicalUPath`, `LogicalUUID`) + +### Property values +- `logical_type_name`, `python_type` correct for each class +- `get_arrow_extension_type().extension_name` returns expected Arrow ext name +- UUID: `get_arrow_extension_type().extension_name == "arrow.uuid"` (not `"uuid.UUID"`) + +### Conversion round-trips +- `Path`: `storage_to_python(python_to_storage(Path("/tmp/foo"))) == Path("/tmp/foo")` +- `UPath`: `storage_to_python(python_to_storage(UPath("s3://bucket/key"))) == UPath("s3://bucket/key")` +- `UUID`: `storage_to_python(python_to_storage(some_uuid)) == some_uuid` + +### Default context registration +After `from orcapod.contexts import get_default_context`: +- `get_default_context().logical_type_registry.get_by_logical_name("pathlib.Path")` → `LogicalPath` +- `get_default_context().logical_type_registry.get_by_python_type(Path)` → `LogicalPath` +- `get_default_context().logical_type_registry.get_by_arrow_extension_name("pathlib.Path")` → `LogicalPath` +- Same pattern for UPath +- `get_default_context().logical_type_registry.get_by_logical_name("uuid.UUID")` → `LogicalUUID` +- `get_default_context().logical_type_registry.get_by_arrow_extension_name("arrow.uuid")` → `LogicalUUID` + +### Pre-existing Arrow type tolerance +- Registering `LogicalUUID` succeeds even though `pa.uuid()` (`"arrow.uuid"`) is already registered in PyArrow + +### Idempotence +- Calling `get_default_context()` twice returns the same `LogicalTypeRegistry` instance (context caching) + +--- + +## Summary of files changed + +| File | Change | +|---|---| +| `src/orcapod/extension_types/builtin_logical_types.py` | **New** — three `LogicalType` implementations | +| `src/orcapod/extension_types/registry.py` | Add `make_polars_extension_type` helper; add `logical_types` param to `LogicalTypeRegistry.__init__` | +| `src/orcapod/extension_types/__init__.py` | Remove `default_logical_type_registry`; export `make_polars_extension_type` | +| `src/orcapod/contexts/core.py` | Add `logical_type_registry: LogicalTypeRegistry` to `DataContext` | +| `src/orcapod/contexts/data/v0.1.json` | Add `logical_type_registry` entry | +| `src/orcapod/contexts/data/schemas/context_schema.json` | Add `logical_type_registry` to schema | +| `src/orcapod/contexts/__init__.py` | Add `get_default_logical_type_registry()` | +| `tests/test_extension_types/test_builtin_logical_types.py` | **New** — tests | + +--- + +## Scope boundaries + +**In scope (this issue):** +- `builtin_logical_types.py` with three `LogicalType` implementations +- `make_polars_extension_type` helper in `registry.py` +- `logical_types` constructor param in `LogicalTypeRegistry` +- `DataContext.logical_type_registry` field + `v0.1.json` entry + schema update +- `get_default_logical_type_registry()` in `contexts` +- Tests in `test_builtin_logical_types.py` + +**Out of scope (deferred to PLT-1660):** +- Deleting `PythonPathStructConverter`, `UPathStructConverter`, `UUIDStructConverter` +- Using `logical_type_registry` inside `DataContext`'s other components + (e.g. replacing `UniversalTypeConverter`'s semantic registry lookup) +- File hashing — remains exclusively in `PathContentHandler` / `UPathContentHandler` diff --git a/superpowers/specs/2026-06-14-plt-1668-logical-type-redesign.md b/superpowers/specs/2026-06-14-plt-1668-logical-type-redesign.md new file mode 100644 index 00000000..17331a56 --- /dev/null +++ b/superpowers/specs/2026-06-14-plt-1668-logical-type-redesign.md @@ -0,0 +1,271 @@ +# PLT-1668: Redesign ExtensionTypeConverter → LogicalType protocol with converter-owned extension types and three-way binding in LogicalTypeRegistry + +**Date:** 2026-06-14 +**Issue:** [PLT-1668](https://linear.app/enigma-metamorphic/issue/PLT-1668) +**Branch:** `eywalker/plt-1668-redesign-extensiontypeconverter-logicaltype-protocol-with` +**Target:** `extension-type-system` + +--- + +## Problem + +`ExtensionTypeConverter` and `ExtensionTypeRegistry` have a separation-of-concerns violation: +the registry dynamically synthesises `pa.ExtensionType` and `pl.BaseExtension` subclasses at +registration time, reading raw ingredient properties (`extension_name`, `extension_metadata`, +`storage_type`) directly off the converter. The converter supplies ingredients; the registry +manufactures the types. This is the wrong ownership model. + +It also breaks when the Arrow extension type already exists as a pre-registered type (e.g. +PyArrow's built-in `"arrow.uuid"`) because the registry always tries to create a fresh type and +errors on the resulting `ArrowKeyError`. + +--- + +## Solution + +Introduce **`LogicalType`**: a protocol where each implementation owns and returns its Arrow and +Polars extension types directly. The registry's job shrinks to storing the binding, triggering +side-effect registrations in the PA/Polars global registries, and enforcing that no two logical +types share any member of their three-way identity triplet +`(logical_type_name, arrow_ext_name, python_type)`. + +--- + +## Design + +### `LogicalType` protocol (`extension_types/protocols.py`) + +Replaces `ExtensionTypeConverter`. All six members are required. + +```python +@runtime_checkable +class LogicalType(Protocol): + @property + def logical_type_name(self) -> str: + """Unique orcapod identifier for this logical type. + + By convention the Python FQCN (e.g. ``"uuid.UUID"``), but any unique + string is valid. Does NOT need to match the Arrow extension type name. + """ + + @property + def python_type(self) -> type: + """The Python class this logical type represents.""" + + def get_arrow_extension_type(self) -> pa.ExtensionType: + """Return the Arrow extension type for this logical type. + + ``storage_type``, ``extension_name``, and serialised metadata are + encapsulated inside the returned type; they are no longer top-level + properties on ``LogicalType``. + + For custom types: create and return an instance of a new + ``pa.ExtensionType`` subclass (e.g. via ``make_arrow_extension_type``). + For pre-existing types: return the existing instance directly + (e.g. ``pa.uuid()``). + """ + + def get_polars_extension_type(self) -> pl.BaseExtension: + """Return an instance of the Polars extension type for this logical type. + + The registry calls ``type(instance)`` to obtain the class passed to + ``pl.register_extension_type``. + """ + + def python_to_storage(self, value: Any) -> Any: + """Convert a Python value to its Arrow storage representation.""" + + def storage_to_python(self, storage_value: Any) -> Any: + """Convert an Arrow storage value back to a Python object.""" +``` + +**Removed from protocol** (now encapsulated inside the extension type instances): +- `extension_name` → accessible via `get_arrow_extension_type().extension_name` +- `extension_metadata` → `get_arrow_extension_type().__arrow_ext_serialize__()` +- `storage_type` → `get_arrow_extension_type().storage_type` + +--- + +### `make_arrow_extension_type` helper (`extension_types/registry.py`) + +A module-level convenience factory for custom `LogicalType` implementations that need to +synthesise a new `pa.ExtensionType` subclass. Returns the **class** (not an instance), so +callers can instantiate it on demand and create parameterised variants in the future. + +```python +def make_arrow_extension_type( + extension_name: str, + storage_type: pa.DataType, + metadata: bytes | None = None, +) -> type[pa.ExtensionType]: + """Synthesise and return a ``pa.ExtensionType`` subclass. + + Returns the *class*, not an instance — callers instantiate it inside their + ``get_arrow_extension_type()`` implementation. Returning the class preserves + the option to create multiple instances or future parameterised variants from + the same class. + + This is a low-level building block. The full pattern for binding a Python + type to a specific Arrow/Polars representation — the extension type factory — + is the responsibility of each ``LogicalType`` implementation. See PLT-1656 + for the built-in implementations (``Path``, ``UPath``, ``UUID``). + """ +``` + +Internally uses `type()` dynamic class synthesis (the same technique previously inside +`_register_arrow_ext_type`), now surfaced as a public utility. + +**Typical usage pattern:** + +```python +_MyArrowExt = make_arrow_extension_type("my.Type", pa.large_string(), b"my.category") + +class MyLogicalType: + def get_arrow_extension_type(self) -> pa.ExtensionType: + return _MyArrowExt() +``` + +--- + +### `LogicalTypeRegistry` (`extension_types/registry.py`) + +Replaces `ExtensionTypeRegistry`. + +#### Storage + +Three per-instance dicts — no module-level shadow dicts: + +```python +_by_logical_name: dict[str, LogicalType] +_by_arrow_name: dict[str, LogicalType] # keyed by arrow_ext_type.extension_name +_by_python_type: dict[type, LogicalType] +``` + +Uniqueness is enforced per-instance. The process-global `default_logical_type_registry` +singleton provides effective process-wide uniqueness for normal use. + +#### `register(logical_type: LogicalType)` — full behaviour + +1. Derive `arrow_ext_name = logical_type.get_arrow_extension_type().extension_name` +2. Derive `py_type = logical_type.python_type` +3. **Triplet conflict check** — for each of the three keys (`logical_type_name`, + `arrow_ext_name`, `py_type`): if already bound to a *different* `LogicalType`, + raise `ValueError` naming the conflicting key and both logical type names. +4. **Idempotent check** — if all three keys are already bound to the *same* `LogicalType`, + return immediately (no-op). +5. Attempt `pa.register_extension_type(logical_type.get_arrow_extension_type())`. + If `pa.lib.ArrowKeyError` is raised (type already registered — by a prior call on + another registry instance, or by an external source such as PyArrow itself), accept + silently and continue. Validation of the pre-existing type against the expected class + is deferred to PLT-1669. +6. Derive `polars_ext_class = type(logical_type.get_polars_extension_type())`. + Attempt `pl.register_extension_type(arrow_ext_name, polars_ext_class)`. + If `ValueError` is raised (already registered), accept silently and continue. +7. Store three-way binding: + - `_by_logical_name[logical_type_name] = logical_type` + - `_by_arrow_name[arrow_ext_name] = logical_type` + - `_by_python_type[py_type] = logical_type` + +#### Lookup methods + +| Method | Description | +|---|---| +| `get_by_logical_name(name: str) -> LogicalType \| None` | Direct dict lookup by logical type name | +| `get_by_python_type(python_type: type) -> LogicalType \| None` | Exact match first; falls back to `issubclass` scan (first registered wins) | +| `get_by_arrow_extension_name(arrow_name: str) -> LogicalType \| None` | Direct dict lookup by Arrow extension name; required for the Arrow schema read path | + +#### Removed + +- `_register_arrow_ext_type`, `_register_polars_ext_type` (synthesis logic moved to + `make_arrow_extension_type` and individual `LogicalType` implementations) +- `_ARROW_REGISTRY`, `_POLARS_REGISTRY` module-level shadow dicts +- `get_converter_for_name`, `get_converter_for_python_type` +- `has_extension_name`, `has_python_type`, `list_extension_names`, `list_python_types` + +--- + +### `extension_types/__init__.py` + +```python +from .protocols import LogicalType +from .registry import LogicalTypeRegistry, make_arrow_extension_type +from .schema_walker import ExtensionTypeInfo, walk_field, walk_schema + +default_logical_type_registry = LogicalTypeRegistry() + +__all__ = [ + "LogicalType", + "LogicalTypeRegistry", + "make_arrow_extension_type", + "default_logical_type_registry", + # PLT-1654 + "ExtensionTypeInfo", + "walk_schema", + "walk_field", +] +``` + +`default_extension_type_registry` is removed with no backward-compat alias (greenfield pre-v0.1.0). + +--- + +### `extension_types/schema_walker.py` + +No logic changes. `schema_walker.py` has no imports of `ExtensionTypeConverter` or +`ExtensionTypeRegistry` — it is self-contained around `ExtensionTypeInfo`, which is +unchanged. + +--- + +## Tests + +### `tests/test_extension_types/test_protocols.py` + +Replace `_StubConverter` with a `_StubLogicalType` conforming to the new protocol +(owns a `pa.ExtensionType` subclass and a `pl.BaseExtension` subclass). Three tests: + +- `test_protocol_is_importable` — `LogicalType` can be imported +- `test_protocol_defines_required_members` — `isinstance(stub, LogicalType)` passes +- `test_conforming_class_satisfies_protocol` — exercises all six protocol members + +### `tests/test_extension_types/test_registry.py` + +**Stub rework:** `_make_stub()` produces a `LogicalType` conforming object. Each stub creates +its own `pa.ExtensionType` subclass (via `make_arrow_extension_type`) and `pl.BaseExtension` +subclass, returned from the respective getter methods. Factory gains `logical_name` parameter. + +**Renamed/updated existing tests:** +- `test_register_stores_converter` → `test_register_stores_three_way_binding` (asserts all three + lookup methods return the registered object) +- `test_register_duplicate_raises` → becomes a triplet conflict case +- Lookup tests updated for `get_by_logical_name`, `get_by_python_type`, `get_by_arrow_extension_name` +- Tests for removed methods (`has_extension_name`, `has_python_type`, `list_*`) deleted + +**New tests for three-way binding and conflict detection:** + +| Test | What it verifies | +|---|---| +| `test_register_idempotent_same_instance` | Registering the same `LogicalType` object twice is a no-op | +| `test_triplet_conflict_same_arrow_name_raises` | Different `logical_type_name`, same Arrow ext name → `ValueError` naming conflicting key | +| `test_triplet_conflict_same_python_type_raises` | Shared `python_type` → `ValueError` | +| `test_triplet_conflict_same_logical_name_raises` | Shared `logical_type_name` → `ValueError` | +| `test_register_preexisting_arrow_type_succeeds` | Pre-registered Arrow type (`ArrowKeyError`) → no error; three-way binding stored | +| `test_register_preexisting_polars_type_succeeds` | Pre-registered Polars type (`ValueError`) → no error; three-way binding stored | +| `test_get_by_arrow_extension_name_miss` | Returns `None` for unknown arrow name | +| `test_get_by_python_type_subclass` | `issubclass` fallback still works | + +**End-to-end tests** (round-trip, Parquet) retained — stubs updated to `LogicalType` shape; +`_build_ext_array` uses `conv.get_arrow_extension_type()` directly. + +**Module-level instance test:** `default_logical_type_registry` is a `LogicalTypeRegistry`, +starts empty. + +--- + +## Out of Scope + +- Built-in `LogicalType` implementations (`PathLogicalType`, `UPathLogicalType`, + `UUIDLogicalType`) — PLT-1656 +- Wiring `LogicalTypeRegistry` into `DataContext` — PLT-1660 +- Validation of pre-existing Arrow type class on `ArrowKeyError` — PLT-1669 +- Thread-safety of the global registry instance — deferred diff --git a/superpowers/specs/2026-06-15-plt-1672-write-side-logical-type-factory-design.md b/superpowers/specs/2026-06-15-plt-1672-write-side-logical-type-factory-design.md new file mode 100644 index 00000000..86598ac3 --- /dev/null +++ b/superpowers/specs/2026-06-15-plt-1672-write-side-logical-type-factory-design.md @@ -0,0 +1,510 @@ +# PLT-1672: Write-Side Logical Type Factory Design + +**Issue:** PLT-1672 +**Date:** 2026-06-15 +**Project:** Orcapod: Arrow/Polars Extension Type Semantic Type System +**Depends on:** PLT-1668 (LogicalType/LogicalTypeRegistry — already on `extension-type-system`) + +--- + +## Overview + +The `LogicalTypeFactory` mechanism today only fires on the **read path**: when the database +read hook encounters an Arrow extension type with an unknown name, it dispatches to a factory +keyed by the `"category"` string in the extension metadata JSON. + +The **write path** has no equivalent. When a user declares a function pod whose input or output +is typed with a Python class that is not yet registered in `LogicalTypeRegistry`, there is no +mechanism to detect this and auto-register a `LogicalType` on the fly. This breaks the +ergonomic goal of "declare a dataclass, use it." + +This spec adds a second factory dispatch axis — **Python-class-keyed** — and wires a +write-side trigger at function pod declaration time. It also integrates the new extension type +system into `UniversalTypeConverter` so that complex nested types like `list[dict[A, list[B]]]` +are handled correctly without duplicating the existing recursive machinery. + +--- + +## Design decisions summary + +| Question | Decision | +|---|---| +| Factory protocol extension | Add `create_for_python_type(python_type)` as a new method; rename existing `create_logical_type` → `reconstruct_from_arrow` | +| Registration API | Extend `register_logical_type_factory` signature to accept both `category` and `python_bases` in one call | +| Complex type handling | Extend `UniversalTypeConverter` to check `LogicalTypeRegistry` first at each leaf; use `_extract_leaf_classes` to recursively unwrap generics for the registration trigger | +| Trigger location | `_FunctionPodBase.__init__()` — at pod declaration time | +| Failure mode | Hard `TypeError` at declaration time if no factory matches | +| MRO resolution | Unified MRO walk across both concrete types and factory keys; most-specific wins; concrete beats factory at same MRO level | + +--- + +## Section 1: Protocol changes — `LogicalTypeFactoryProtocol` + +**File:** `src/orcapod/extension_types/protocols.py` + +The existing `create_logical_type` method is **renamed** to `reconstruct_from_arrow` to make its +role unambiguous (read-path reconstructor from Arrow schema). A new `create_for_python_type` +method is added for the write path. + +```python +class LogicalTypeFactoryProtocol(Protocol): + + def reconstruct_from_arrow( + self, + arrow_extension_name: str, + storage_type: pa.DataType, + metadata: dict[str, Any], + ) -> LogicalTypeProtocol: + """Reconstruct a LogicalType from Arrow schema metadata (read path). + + Called by the registry when a schema walk encounters an extension type + whose metadata ``"category"`` value matches this factory's registered + category. All Arrow schema information is already known. + + Args: + arrow_extension_name: The Arrow extension type name from the schema. + storage_type: The underlying Arrow storage type. + metadata: Full parsed metadata JSON dict. Always contains ``"category"``. + + Returns: + A fully constructed ``LogicalTypeProtocol`` ready for registration. + + Raises: + ValueError: If this factory cannot reconstruct a type for the given name. + """ + ... + + def create_for_python_type( + self, + python_type: type, + ) -> LogicalTypeProtocol: + """Synthesize a LogicalType for the given Python class (write path). + + Called by the registry when pod declaration encounters an unregistered + class whose MRO intersects this factory's registered ``python_bases``. + The factory derives all Arrow metadata (extension name, storage type, + metadata dict) from the Python class itself. + + The returned LogicalType must round-trip: the extension name and metadata + it produces must route back to this same factory's ``reconstruct_from_arrow`` + on a subsequent read, ensuring write → Parquet → read consistency. + + Args: + python_type: The concrete Python class to synthesize a LogicalType for. + + Returns: + A fully constructed ``LogicalTypeProtocol`` ready for registration. + + Raises: + ValueError: If this factory cannot construct a type for the given class. + """ + ... +``` + +**Breaking change:** `create_logical_type` → `reconstruct_from_arrow`. The single internal call +site in `registry.ensure_extension_type()` is updated. Any existing factory implementations +(none yet in the codebase beyond tests) must update the method name. + +The existing test stub in `test_protocols.py` (`_StubFactory.create_logical_type`) is updated +to `reconstruct_from_arrow` and a conformance test for `create_for_python_type` is added. + +--- + +## Section 2: Registry API changes — `LogicalTypeRegistry` + +**File:** `src/orcapod/extension_types/registry.py` + +### New internal state + +```python +class LogicalTypeRegistry: + def __init__(self, logical_types=None): + self._by_logical_name: dict[str, LogicalTypeProtocol] = {} + self._by_arrow_name: dict[str, LogicalTypeProtocol] = {} + self._by_python_type: dict[type, LogicalTypeProtocol] = {} + self._category_factories: dict[str, LogicalTypeFactoryProtocol] = {} # was _factories + self._python_class_factories: dict[type, LogicalTypeFactoryProtocol] = {} # new +``` + +`_factories` is renamed to `_category_factories` for clarity. No external API references it +directly. + +### `register_logical_type_factory` — extended signature + +```python +def register_logical_type_factory( + self, + factory: LogicalTypeFactoryProtocol, + *, + category: str | None = None, + python_bases: Iterable[type] = (), +) -> None: + """Register a factory on one or both dispatch axes. + + Args: + factory: The factory to register. + category: If given, registers factory as the read-side handler for + Arrow extension types whose metadata contains this category string. + Raises ``ValueError`` if a different factory is already registered + for this category. + python_bases: Zero or more Python base classes. Registers factory as + the write-side handler for each. The factory's + ``create_for_python_type`` will be called when a pod declares a + type that is a subclass of one of these bases and no concrete + ``LogicalType`` is yet registered for that type. + Raises ``ValueError`` if a different factory is already registered + for a given base. + + At least one of ``category`` or ``python_bases`` must be provided. + Registering the same factory object twice for the same key is a no-op. + """ +``` + +**Signature change:** `factory` becomes the first positional argument and `category` becomes +keyword-only. Existing call sites using `register_logical_type_factory("Dataclass", factory)` +(positional) update to `register_logical_type_factory(factory, category="Dataclass")`. + +A typical dual-axis registration (as the dataclass factory will use): + +```python +registry.register_logical_type_factory( + dataclass_factory, + category="Dataclass", + python_bases=[DataclassSentinelABC], +) +``` + +### `ensure_extension_type` — one-line update + +The internal call changes from `factory.create_logical_type(...)` to +`factory.reconstruct_from_arrow(...)`. No other logic changes. + +### New: `ensure_logical_type_for_python_class` + +```python +def ensure_logical_type_for_python_class( + self, + python_type: type, +) -> LogicalTypeProtocol: + """Ensure a LogicalType exists for python_type, synthesizing via factory if needed. + + This is the write-side counterpart to ``ensure_extension_type`` (the read-side + trigger). It is called at function pod declaration time for every non-native + leaf class extracted from the pod's input and output schemas. + + Resolution algorithm (unified MRO walk): + + 1. Walk ``python_type.__mro__``. At each MRO step, check: + - ``_by_python_type`` for a concrete registered ``LogicalType`` + - ``_python_class_factories`` for a registered factory + Track the first (most-specific) hit in each dict separately. + + 2. After the MRO walk, if no factory was found in step 1, do a fallback + ``issubclass`` scan over ``_python_class_factories`` keys. This catches + ABCs that use ``__subclasshook__`` for structural dispatch (e.g. a + ``_DataclassSentinelABC`` whose hook returns ``is_dataclass(C)``). + + 3. Resolution rule: + - If only a concrete type found → return it immediately (O(1) after first hit). + - If only a factory found → call ``factory.create_for_python_type(python_type)``, + register the result via ``register_logical_type()``, return it. + Registration caches in ``_by_python_type[python_type]`` — next lookup is O(1). + - If both found at the same MRO level (same class in MRO) → concrete wins. + - If concrete is more specific (lower MRO index) → return concrete. + - If factory is more specific (lower MRO index) → synthesize and register. + + 4. If nothing found (no concrete type, no factory): raise ``TypeError``. + + Args: + python_type: The Python class to resolve. + + Returns: + The registered or newly synthesized ``LogicalTypeProtocol``. + + Raises: + TypeError: If no ``LogicalType`` and no factory is found for ``python_type``. + Message includes guidance on how to register a factory. + """ +``` + +**Caching:** once a factory synthesizes a `LogicalType` for a concrete class and +`register_logical_type` stores it in `_by_python_type[python_type]`, all future calls for that +exact class are O(1) exact-match dict lookups. No factory call, no MRO walk. This per-process +cache is intentionally shared with the read-side cache — they are one and the same +`_by_python_type` dict. + +--- + +## Section 3: Complex type handling — `_extract_leaf_classes` and `UniversalTypeConverter` + +Handling complex nested annotations like `list[dict[A, list[B]]]` requires two complementary +pieces: recursive leaf extraction for the **registration phase**, and a priority check in +`UniversalTypeConverter` for the **encoding phase**. Crucially, the existing recursive machinery +in `UniversalTypeConverter` already handles generic nesting — we tap into it rather than +duplicate it. + +### 3a: `_extract_leaf_classes` — recursive annotation unwrapper + +**File:** `src/orcapod/extension_types/type_utils.py` (new module) + +```python +def _extract_leaf_classes(annotation: Any) -> Iterator[type]: + """Recursively yield all concrete leaf Python classes from a type annotation. + + Unwraps generic aliases (``list[T]``, ``dict[K, V]``, ``Optional[T]``, + ``Union[A, B]``, etc.) using ``typing.get_origin`` / ``typing.get_args`` + and yields every non-generic, non-None leaf class found. + + Examples: + ``list[MyEvent]`` → ``[MyEvent]`` + ``dict[str, MyEvent]`` → ``[str, MyEvent]`` + ``list[dict[A, list[B]]]`` → ``[A, B]`` + ``Optional[MyEvent]`` → ``[MyEvent]`` + ``Union[A, B, None]`` → ``[A, B]`` + ``int`` → ``[int]`` + """ +``` + +Used by `_trigger_write_side_registration` (Section 4) to discover all leaf classes in a schema +column's type annotation before attempting factory dispatch. The function lives in +`extension_types/type_utils.py` so it is importable by both `function_pod.py` and future +callers without a circular-import risk. + +### 3b: `UniversalTypeConverter` — priority check for `LogicalTypeRegistry` + +**File:** `src/orcapod/semantic_types/universal_converter.py` + +`UniversalTypeConverter` gains an optional `logical_type_registry` attribute, injected from +`DataContext` at construction/wiring time: + +```python +class UniversalTypeConverter: + def __init__(self, ..., logical_type_registry: LogicalTypeRegistry | None = None): + ... + self._logical_type_registry = logical_type_registry +``` + +In `_convert_python_to_arrow()`, one new check is inserted **before** the existing +`semantic_registry` check: + +```python +def _convert_python_to_arrow(self, python_type: type) -> pa.DataType: + # ── NEW: check LogicalTypeRegistry first (extension-type identity) ────── + if self._logical_type_registry is not None: + lt = self._logical_type_registry.get_by_python_type(python_type) + if lt is not None: + return lt.get_arrow_extension_type() + + # ── EXISTING: semantic_registry (old shape-based identity) ─────────────── + if self.semantic_registry: + converter = self.semantic_registry.get_converter_for_python_type(python_type) + if converter: + return converter.arrow_struct_type + + # ── EXISTING: dataclass encoding, generic handling, etc. ───────────────── + ... +``` + +This is an **additive, non-breaking change**. The old `semantic_registry` and dataclass encoding +paths remain completely intact and serve as the fallback during the parallel build phase. Once +PLT-1660 removes the old system, those fallback paths are deleted. + +**Why `get_by_python_type()` and not `ensure_logical_type_for_python_class()`** at this call +site: by the time `UniversalTypeConverter` runs (encoding phase), the registration trigger at +pod declaration time has already called `ensure_logical_type_for_python_class` for every leaf +class. The converter therefore only needs a read-only lookup — no synthesis, no side effects. +If a type somehow arrives unregistered at encoding time, it falls through to the old system +rather than raising, preserving parallel-build safety. + +### 3c: `DataContext` wiring + +**File:** `src/orcapod/contexts/` (wherever `DataContext` is constructed) + +When a `DataContext` is constructed, its `logical_type_registry` is passed to its +`type_converter`: + +```python +# In DataContext construction / post-init: +self.type_converter._logical_type_registry = self.logical_type_registry +``` + +This is the only place where the two systems are connected. No other wiring is needed. + +--- + +## Section 4: Trigger point — `_FunctionPodBase.__init__()` + +**File:** `src/orcapod/core/function_pod.py` + +A module-level helper is added and called from `_FunctionPodBase.__init__()` after the data +function is assigned. It uses `_extract_leaf_classes` to handle complex nested annotations +before calling `ensure_logical_type_for_python_class` for each leaf. + +```python +# Types that Arrow handles natively without a LogicalType +_ARROW_NATIVE_TYPES: frozenset[type] = frozenset({ + int, float, str, bytes, bool, type(None), +}) + + +def _trigger_write_side_registration( + input_schema: Schema, + output_schema: Schema, + registry: LogicalTypeRegistry | None, +) -> None: + """Walk pod schemas and ensure a LogicalType is registered for every non-native leaf class. + + Recursively unwraps generic annotations (``list[T]``, ``dict[K,V]``, etc.) to + extract leaf classes, then triggers factory synthesis for any not yet registered. + Called once at pod declaration time. + + Arrow-native types (int, str, etc.) are skipped. Already-registered types are + skipped via a fast O(1) dict check. Unregistered non-native types trigger factory + synthesis via ``ensure_logical_type_for_python_class``. Raises ``TypeError`` if + no factory is found — this is an intentional hard error at declaration time. + + Args: + input_schema: The pod's input data schema (column name → Python type annotation). + output_schema: The pod's output data schema. + registry: The LogicalTypeRegistry from the pod's DataContext. No-op if None. + """ + if registry is None: + return + for schema in (input_schema, output_schema): + for annotation in schema.values(): + for leaf_class in _extract_leaf_classes(annotation): + if leaf_class in _ARROW_NATIVE_TYPES: + continue + if registry.get_by_python_type(leaf_class) is not None: + continue # already registered — O(1) cache hit, skip MRO walk + registry.ensure_logical_type_for_python_class(leaf_class) + # TypeError propagates if no factory matches — intentional +``` + +In `_FunctionPodBase.__init__()`: + +```python +self._data_function = data_function +_trigger_write_side_registration( + data_function.input_data_schema, + data_function.output_data_schema, + self.data_context.logical_type_registry, +) +``` + +**Single chokepoint:** every function pod, whether `FunctionPod` or `FunctionNode`, is +constructed through `_FunctionPodBase.__init__()`. There is no other code path to reach. + +--- + +## Section 5: Failure modes + +**No factory found at pod declaration time:** + +``` +TypeError: No LogicalType or LogicalTypeFactory is registered for type +'myapp.models.Foo'. +To handle this type, register a factory for its base class on the registry: + registry.register_logical_type_factory(factory, python_bases=[]) +Or register a concrete LogicalType directly: + registry.register_logical_type(my_logical_type) +``` + +This error is raised from `ensure_logical_type_for_python_class` and surfaces at the +`_FunctionPodBase.__init__()` call site. There is no fallback, no implicit pickle, no silent +pass-through. The failure is deliberate and loud. + +**Registry is None:** `_trigger_write_side_registration` is a no-op. This handles contexts +without type registration (e.g. test environments that construct pods without a full +DataContext). + +**Unregistered type reaching encoding:** If a type somehow bypasses pod declaration and reaches +`UniversalTypeConverter._convert_python_to_arrow()` without being registered, `get_by_python_type()` +returns `None` and the converter falls through to the old `semantic_registry` / dataclass +encoding path. This is intentional parallel-build safety: the new system is higher-priority but +not exclusive until PLT-1660. + +--- + +## Section 6: Symmetry with the read side + +By protocol contract, `create_for_python_type(T)` must produce a `LogicalType` whose Arrow +extension name and metadata JSON are identical to what `reconstruct_from_arrow` expects to +receive when reading that data back. Concretely for the dataclass factory: + +| Direction | Method | Extension name | Metadata | +|---|---|---|---| +| Write | `create_for_python_type(MyEvent)` | `"myapp.models.MyEvent"` | `{"category": "Dataclass"}` | +| Read | `reconstruct_from_arrow("myapp.models.MyEvent", struct_type, {"category": "Dataclass"})` | same | same | + +The registry routes the read path via `_category_factories["Dataclass"]` and the write path via +`_python_class_factories[DataclassSentinelABC]` — the same factory object, different dispatch +keys. Round-trip consistency is enforced by integration tests (write → Parquet → read), not by +the registry itself. + +--- + +## Section 7: Built-in types (Path, UPath, UUID) — confirmed unaffected + +Built-ins are registered as concrete `LogicalType` instances against their exact Python types +(`pathlib.Path`, `upath.UPath`, `uuid.UUID`) in the `DataContext` at startup. + +At **registration time** (pod declaration): `_extract_leaf_classes` yields `pathlib.Path` from +an annotation like `pathlib.Path`; `registry.get_by_python_type(pathlib.Path)` hits the exact- +match dict immediately → skipped. + +At **encoding time** (UniversalTypeConverter): `get_by_python_type(pathlib.Path)` returns +`LogicalPath` → `lt.get_arrow_extension_type()` returns the extension Arrow type. The old +`semantic_registry` check for Path is never reached. + +Built-ins continue to work, and are now encoded via the new extension type (not the old struct +shape). ✓ + +--- + +## Section 8: What this issue does NOT implement + +The following are explicitly deferred: + +- **Dataclass factory (`orcapod.dataclass`):** PLT-1657 implements the concrete factory and + registers it via `register_logical_type_factory(factory, category="Dataclass", python_bases=[DataclassSentinelABC])`. + PLT-1657 also defines `DataclassSentinelABC` (the ABC with `__subclasshook__` that returns + `is_dataclass(C)`). PLT-1672 defines the slot; PLT-1657 fills it. +- **Pydantic factory:** future issue. The framework accommodates it by design. +- **Picklable factory as fallback:** deferred. The failure-mode section deliberately makes + no-match a hard error for now. +- **Removal of old `semantic_registry` / dataclass encoding paths:** PLT-1660 only. + +--- + +## Implementation scope + +All changes are additive. No existing code is deleted. + +### Source files + +| File | Change | +|---|---| +| `src/orcapod/extension_types/protocols.py` | Rename `create_logical_type` → `reconstruct_from_arrow`; add `create_for_python_type` to `LogicalTypeFactoryProtocol` | +| `src/orcapod/extension_types/registry.py` | Rename `_factories` → `_category_factories`; add `_python_class_factories`; extend `register_logical_type_factory` signature; update `ensure_extension_type` call site; add `ensure_logical_type_for_python_class` | +| `src/orcapod/extension_types/type_utils.py` | New module: `_extract_leaf_classes(annotation)` | +| `src/orcapod/semantic_types/universal_converter.py` | Add optional `logical_type_registry` param; insert `LogicalTypeRegistry.get_by_python_type()` check before `semantic_registry` check in `_convert_python_to_arrow` | +| `src/orcapod/contexts/` | Wire `DataContext.logical_type_registry` into `DataContext.type_converter._logical_type_registry` at construction | +| `src/orcapod/core/function_pod.py` | Add `_ARROW_NATIVE_TYPES`, `_trigger_write_side_registration`; call from `_FunctionPodBase.__init__()` | + +### Test files + +| File | Change | +|---|---| +| `tests/test_extension_types/test_protocols.py` | Update `_StubFactory.create_logical_type` → `reconstruct_from_arrow`; add `create_for_python_type` stub and conformance test | +| `tests/test_extension_types/test_registry.py` | Update `register_logical_type_factory` call sites; add tests for `ensure_logical_type_for_python_class` (MRO walk, factory synthesis, caching, TypeError) | +| `tests/test_extension_types/test_type_utils.py` | New: tests for `_extract_leaf_classes` covering primitives, `list[T]`, `dict[K,V]`, `Optional[T]`, `Union`, deeply nested | +| `tests/test_semantic_types/test_universal_converter.py` | Add tests for `logical_type_registry` priority check: registered types return extension Arrow type, unregistered fall through to old system | +| `tests/test_core/function_pod/test_write_side_registration.py` | New: end-to-end tests for pod declaration triggering factory synthesis; nested types (`list[MyClass]`); hard error when no factory matches | + +--- + +## PLT-1660 cleanup items (deferred) + +- Remove `semantic_registry` fallback path from `UniversalTypeConverter._convert_python_to_arrow()` (replaced entirely by `logical_type_registry`) +- Remove old `semantic_registry` / `dataclass_encoding` integration from `UniversalTypeConverter` diff --git a/superpowers/specs/2026-06-16-plt-1705-type-registration-spine-refactor.md b/superpowers/specs/2026-06-16-plt-1705-type-registration-spine-refactor.md new file mode 100644 index 00000000..38d07231 --- /dev/null +++ b/superpowers/specs/2026-06-16-plt-1705-type-registration-spine-refactor.md @@ -0,0 +1,362 @@ +# PLT-1705: Type Registration Spine Refactor and DataclassHandlerFactory + +**Issue:** PLT-1705 +**Date:** 2026-06-16 +**Project:** Orcapod: Arrow/Polars Extension Type Semantic Type System +**Branch:** `eywalker/plt-1705-refactor-type-registration-spine-and-implement` + +--- + +## Overview + +`UniversalTypeConverter` becomes the **single re-entry point** for all Python ↔ Arrow type +registration and value conversion. `LogicalTypeRegistry` becomes its private implementation +detail. Factories and logical types are thin leaf nodes with no upward dependencies beyond +the `TypeConverterProtocol`. + +This supersedes PLT-1657 and closes PR #174 without merging. `DataclassHandlerFactory` is +implemented from scratch on the refined architecture. + +--- + +## Core design principle + +`UniversalTypeConverter` owns all traversal of Python annotations and Arrow types in both +directions. Everything that used to be split across `LogicalTypeRegistry.ensure_*` methods +moves into two symmetric public methods on the converter: + +| Direction | Method | Input | Output | +|---|---|---|---| +| Write (register) | `register_python_class(annotation)` | Python type annotation | `pa.DataType` | +| Read (register) | `register_storage_type(arrow_type)` | `pa.DataType` | `pa.DataType` | + +Both methods walk their input recursively, register any new logical types encountered as a +side effect, and return the normalised Arrow type with extension types embedded. + +--- + +## Section 1: Protocol changes (`extension_types/protocols.py`) + +### New: `TypeConverterProtocol` + +Minimal protocol exposing what factories and logical types need from the converter. +Placed in `protocols.py` to avoid circular imports. + +```python +class TypeConverterProtocol(Protocol): + def register_python_class(self, annotation: Any) -> pa.DataType: ... + def register_storage_type(self, arrow_type: pa.DataType) -> pa.DataType: ... + def python_to_storage(self, value: Any, annotation: Any) -> Any: ... + def storage_to_python(self, storage_value: Any, annotation: Any) -> Any: ... +``` + +### Updated: `LogicalTypeFactoryProtocol` + +`supports_class` is added (write-side probe). Both factory methods receive `converter` +instead of `registry` and `ResolutionContext`. + +```python +class LogicalTypeFactoryProtocol(Protocol): + def supports_class(self, python_type: type) -> bool: ... + + def create_for_python_type( + self, + python_type: type, + converter: TypeConverterProtocol, + ) -> LogicalTypeProtocol: ... + + def reconstruct_from_arrow( + self, + arrow_extension_name: str, + storage_type: pa.DataType, + metadata: dict[str, Any], + converter: TypeConverterProtocol, + ) -> LogicalTypeProtocol: ... +``` + +### Updated: `LogicalTypeProtocol` + +Value conversion methods receive `converter`. Built-in implementations accept and ignore it; +`DataclassLogicalType` uses it for per-field recursion. + +```python +def python_to_storage(self, value: Any, converter: TypeConverterProtocol) -> Any: ... +def storage_to_python(self, storage_value: Any, converter: TypeConverterProtocol) -> Any: ... +``` + +--- + +## Section 2: Registry becomes a thin data store (`extension_types/registry.py`) + +### Public surface retained + +- `register_logical_type(lt)` +- `register_logical_type_factory(factory, *, category, python_bases)` +- `get_by_python_type`, `get_by_arrow_extension_name`, `get_by_logical_name` + +### Removed + +- `ensure_logical_type_for_python_class` — logic moves into `UniversalTypeConverter.register_python_class` +- `ensure_extension_type` — logic moves into `UniversalTypeConverter.register_storage_type` + +The registry is never passed to factories. It is an internal data structure of the converter. + +--- + +## Section 3: `UniversalTypeConverter` — single re-entry point (`semantic_types/universal_converter.py`) + +### `register_python_class(annotation) -> pa.DataType` + +Write-side re-entry point. Traverses Python annotations recursively. + +- **Primitive** → return from type map directly (no side effects) +- **Registry hit** (concrete type already in `_registry`) → return `lt.get_arrow_extension_type()` +- **Generics** (recurse structurally): + - `list[T]` → `pa.large_list(register_python_class(T))` + - `dict[K, V]` → `pa.large_list(pa.struct([field("key", K), field("value", V)]))` + - `Optional[T]` / `T | None` → `register_python_class(T)` (nullability at field level) + - `set[T]` → `pa.large_list(register_python_class(T))` +- **Registry miss** on concrete type → MRO walk over `_python_class_factories`, call + `factory.supports_class(type)` to find match, call + `factory.create_for_python_type(type, converter=self)`, register result, return extension type +- **Cycle detection** via `_in_progress: set[type]` instance variable: if a type is already + being synthesised, raise `TypeError` + +### `register_storage_type(arrow_type: pa.DataType) -> pa.DataType` + +Read-side re-entry point. Traverses Arrow types recursively, bottom-up. + +- **Primitive** → return as-is +- **`pa.ExtensionType`**: + - Registry hit → return immediately (no-op) + - Registry miss → recurse into `storage_type` first (bottom-up resolution), then parse + metadata, find factory by `"category"` key, call + `factory.reconstruct_from_arrow(name, resolved_storage_type, metadata, converter=self)`, + register result, return extension type +- **`pa.StructType`** → recurse into each field, reassemble with resolved field types +- **`pa.ListType` / `pa.LargeListType`** → recurse into value type, reassemble + +The bottom-up order guarantees that when a factory receives `storage_type`, all nested +extension types within it are already registered and resolved. + +**Example** — `my_data.Dataset` (dataclass) wrapping `struct{a: i32, b: list[orcapod.uuid]}`: + +``` +register_storage_type(my_data.Dataset ext → struct{a:i32, b:list[large_binary w/ orcapod.uuid]}) + recurse into storage: + register_storage_type(struct{a:i32, b:list[orcapod.uuid ext]}) + field a: i32 → i32 + field b: register_storage_type(list[orcapod.uuid ext]) + register_storage_type(orcapod.uuid ext) → registry hit → orcapod.uuid ext + → list[orcapod.uuid ext] + → struct{a:i32, b:list[orcapod.uuid ext]} ← resolved storage type + my_data.Dataset: registry miss + → factory.reconstruct_from_arrow("my_data.Dataset", + struct{a:i32, b:list[orcapod.uuid ext]}, ← resolved, not raw + {"category":"orcapod.dataclass"}, converter=self) + → register → return my_data.Dataset ext type +``` + +### Value conversion methods + +```python +def python_to_storage(self, value: Any, annotation: Any) -> Any: ... +def storage_to_python(self, storage_value: Any, annotation: Any) -> Any: ... +``` + +Thin wrappers over the existing `get_python_to_arrow_converter` / +`get_arrow_to_python_converter` machinery. For extension types, the generated converter +calls `lt.python_to_storage(value, converter=self)` / `lt.storage_to_python(value, converter=self)`. +These are needed by `DataclassLogicalType` for per-field delegation back to the converter. + +### Registration pass-throughs + +```python +def register_logical_type(self, lt: LogicalTypeProtocol) -> None: + self._registry.register_logical_type(lt) + +def register_logical_type_factory( + self, factory: LogicalTypeFactoryProtocol, + *, category: str | None = None, + python_bases: Iterable[type] = (), +) -> None: + self._registry.register_logical_type_factory(factory, category=category, python_bases=python_bases) +``` + +External code that previously used `context.logical_type_registry.register_*` now uses +`context.type_converter.register_*`. + +### `ensure_types_registered_for_schemas` (simplified) + +```python +def ensure_types_registered_for_schemas(self, *schemas: Schema) -> None: + for schema in schemas: + for annotation in schema.values(): + self.register_python_class(annotation) +``` + +### Removals + +- `semantic_registry` constructor parameter and all its usage in `_convert_python_to_arrow` + / `_convert_arrow_to_python` — removed +- All `dataclass_encoding` imports and the old sentinel-based dataclass struct path — removed; + `dataclass_encoding.py` is deleted + +--- + +## Section 4: `DataclassHandlerFactory` (`extension_types/dataclass_handler.py` — new file) + +### `DataclassLogicalType` + +Thin holder of identity, schema, and field annotations. No pre-baked converters. + +```python +def python_to_storage(self, value, converter): + return { + name: converter.python_to_storage(getattr(value, name), annotation) + for name, annotation in self._field_annotations + } + +def storage_to_python(self, storage_value, converter): + return self._python_type(**{ + name: converter.storage_to_python(storage_value[name], annotation) + for name, annotation in self._field_annotations + }) +``` + +`_field_annotations: list[tuple[str, type]]` stores `(field_name, python_annotation)` pairs. +No Arrow types stored in the logical type — the converter owns all Arrow-level reasoning. + +### `DataclassHandlerFactory` + +Stateless. Approximately 30 lines of logic. + +**`supports_class(python_type)`**: `return dataclasses.is_dataclass(python_type)` + +**`create_for_python_type(python_type, converter)`** (write path): +1. Reject local / unnamed classes (no stable FQCN) with hard `ValueError` +2. `get_type_hints(python_type)` to obtain field annotations +3. Iterate `dataclasses.fields(python_type)`; for each field: + `arrow_type = converter.register_python_class(annotation)` — all traversal delegated to converter +4. Assemble `pa.struct([pa.field(name, arrow_type), ...])` and `field_annotations` list +5. Return `DataclassLogicalType(fqcn, python_type, storage_type, field_annotations)` + +`dict[K, V]` fields encode as `list[struct{key:K, value:V}]` — owned entirely by +`converter.register_python_class`, no special handling in the factory. + +**`reconstruct_from_arrow(name, storage_type, metadata, converter)`** (read path): +1. Import class from `name` (FQCN) using longest-prefix module walk — hard `ImportError` if not found +2. `get_type_hints(imported_cls)` → build `field_annotations` matched against `storage_type`'s fields +3. `storage_type` is already resolved (sub-extension types embedded, bottom-up by `register_storage_type`) +4. Factory does **not** call `converter.register_storage_type` for sub-fields — already done +5. Return `DataclassLogicalType(name, imported_cls, storage_type, field_annotations)` + +--- + +## Section 5: `DataContext` and context wiring + +**`contexts/core.py`**: `logical_type_registry: LogicalTypeRegistry` field removed. +`type_converter` is the sole public API for type operations. + +**`contexts/__init__.py`**: `get_default_logical_type_registry()` removed. + +**`contexts/registry.py`**: `_create_context_from_spec` no longer passes `logical_type_registry` +to `DataContext`. The `LogicalTypeRegistry` is constructed as a nested object inside +`type_converter`'s config — it never appears as a top-level `ref_lut` entry. + +**`contexts/data/v0.1.json`**: +- Top-level `logical_type_registry` key removed +- Registry construction (with built-in `logical_types` list) moves into `type_converter`'s `_config` +- `semantic_registry` reference removed from `type_converter`'s `_config` + +**`contexts/data/schemas/context_schema.json`**: +- Remove `logical_type_registry` from required fields and properties + +--- + +## Section 6: `database_hooks.py` and `ExtensionAwareDatabase` + +**`register_discovered_extensions`** simplifies to: + +```python +def register_discovered_extensions(converter: TypeConverterProtocol, schema: pa.Schema) -> None: + for field in schema: + converter.register_storage_type(field.type) +``` + +The schema walker's depth-first extension-field extraction is no longer needed here — +`register_storage_type` owns that traversal. `schema_walker.py` itself is retained (other +callers may use it). + +**`databases/extension_aware_database.py`**: takes `converter: TypeConverterProtocol` +(was `registry: LogicalTypeRegistry`). Internal call sites updated accordingly. + +--- + +## Section 7: Deletions, built-in updates, and testing + +### Deleted files + +| File | Reason | +|---|---| +| `semantic_types/dataclass_encoding.py` | Superseded by `DataclassHandlerFactory` + converter | + +### Files with removed usages + +| File | What is removed | +|---|---| +| `semantic_types/universal_converter.py` | `semantic_registry` usage, `dataclass_encoding` imports | +| `extension_types/type_utils.py` | `extract_leaf_classes` made private (`_extract_leaf_classes`) or removed; traversal lives in converter | + +### Built-in logical types (`builtin_logical_types.py`) + +`LogicalPath`, `LogicalUUID`, `LogicalUPath` — add `converter` param (accepted, ignored) to +`python_to_storage` and `storage_to_python` on all three, for protocol conformance. + +### Test files + +| File | Change | +|---|---| +| `tests/test_extension_types/test_protocols.py` | Add `TypeConverterProtocol` conformance; update factory/logical-type stubs for new signatures | +| `tests/test_extension_types/test_registry.py` | Remove `ensure_*` tests; add converter pass-through tests | +| `tests/test_extension_types/test_builtin_logical_types.py` | Update `python_to_storage` / `storage_to_python` call sites to pass a converter stub | +| `tests/test_extension_types/test_dataclass_handler.py` | **New**: `DataclassLogicalType` unit tests; factory write path (flat, list, dict, nested); read path; local-class rejection; cycle detection; `supports_class`; Arrow round-trips | +| `tests/test_semantic_types/test_universal_converter.py` | Add `register_python_class` tests (primitives, generics, factory dispatch, cycle detection); `register_storage_type` tests (primitives, extension types, struct/list recursion); `python_to_storage` / `storage_to_python` for logical type dispatch | + +--- + +## File-by-file change summary + +| File | Change | +|---|---| +| `extension_types/protocols.py` | Add `TypeConverterProtocol`; update `LogicalTypeFactoryProtocol` (add `supports_class`, `converter` param); update `LogicalTypeProtocol` (`converter` param on conversion methods) | +| `extension_types/registry.py` | Remove `ensure_logical_type_for_python_class`, `ensure_extension_type` | +| `extension_types/builtin_logical_types.py` | Add `converter` param (ignored) to `python_to_storage` / `storage_to_python` | +| `extension_types/type_utils.py` | `extract_leaf_classes` made private or removed | +| `extension_types/dataclass_handler.py` | **New**: `DataclassLogicalType` + `DataclassHandlerFactory` | +| `semantic_types/universal_converter.py` | Add `register_python_class`, `register_storage_type`, `python_to_storage`, `storage_to_python`, `register_logical_type`, `register_logical_type_factory`; remove `semantic_registry` usage; remove `dataclass_encoding` usage; simplify `ensure_types_registered_for_schemas` | +| `semantic_types/dataclass_encoding.py` | **Deleted** | +| `extension_types/database_hooks.py` | `register_discovered_extensions` takes converter, calls `register_storage_type` per field | +| `databases/extension_aware_database.py` | Takes `converter` instead of `registry` | +| `contexts/core.py` | Remove `logical_type_registry` field | +| `contexts/__init__.py` | Remove `get_default_logical_type_registry` | +| `contexts/registry.py` | Stop passing `logical_type_registry` to `DataContext` | +| `contexts/data/v0.1.json` | Move registry construction inside `type_converter` config; remove `semantic_registry` from `type_converter` config | + +--- + +## Out of scope + +- Wiring `DataclassHandlerFactory` into the default context — PLT-1701 +- Nested extension types inside struct sub-fields (self-describing nesting) — PLT-1700 (v0.2) +- `dict[K, V]` as `list[struct{key, value}]` — **in scope** (owned by converter, zero factory logic) + +## Note: registered logical types as dataclass field types work naturally + +Because `DataclassHandlerFactory` delegates all per-field type resolution to +`converter.register_python_class`, dataclass fields typed as registered logical types +(e.g. `pathlib.Path`, `uuid.UUID`, `upath.UPath`) work without any special handling. +`register_python_class` hits the registry immediately for pre-registered types and returns +their Arrow extension type. Value conversion dispatches through the logical type's +`python_to_storage` / `storage_to_python` methods. This was listed as a follow-up gap in +PLT-1657, but is resolved by the new architecture at no extra cost. diff --git a/superpowers/specs/2026-06-17-plt-1720-register-python-class-storage-type-cleanup.md b/superpowers/specs/2026-06-17-plt-1720-register-python-class-storage-type-cleanup.md new file mode 100644 index 00000000..26114eb9 --- /dev/null +++ b/superpowers/specs/2026-06-17-plt-1720-register-python-class-storage-type-cleanup.md @@ -0,0 +1,253 @@ +# PLT-1720: register_python_class storage-type cleanup + registration completeness fix + +**Date:** 2026-06-17 +**Issue:** PLT-1720 — Cleanup: register_python_class should return plain storage type, not extension type +**Branch:** `eywalker/plt-1720-cleanup-register_python_class-should-return-plain-storage` + +--- + +## Problem + +`register_python_class(annotation)` currently returns a `pa.ExtensionType` for annotations +that have a registered logical type. Callers that build Arrow struct fields must immediately +strip that extension type back to plain storage via `_strip_ext_to_storage`, because Arrow and +Polars cannot construct arrays with `pa.ExtensionType` nodes inside struct fields (ET1 in +`DESIGN_ISSUES.md`). + +This creates an API impedance mismatch: the return value of `register_python_class` cannot +be used where struct fields are needed, which is its primary call site. + +A second, related problem: `DataclassLogicalTypeFactory.reconstruct_from_arrow` (the Parquet +read path) does not call `converter.register_python_class` for its field annotations. This +means that in a fresh process, a nested dataclass (e.g. `Inner` inside `Outer`) is never +registered when reading `Outer` from Parquet. Value conversion for `Inner` then fails with +`ValueError("Unsupported Python type: Inner.")`. + +--- + +## Design invariant + +**Registration completeness**: when a logical type is registered by any path, all nested +logical types it depends on must also be registered as a consequence. This is a contract on +`LogicalTypeFactoryProtocol`: both `create_for_python_type` and `reconstruct_from_arrow` +must leave the converter in a state where every logical type the returned `LogicalTypeProtocol` +depends on is also registered before the method returns. + +How a factory satisfies this invariant is an implementation detail and is not prescribed here. +A future factory could, for example, embed enough information in its Arrow extension metadata +to reconstruct and register all inner types directly from the metadata, without ever importing +the Python class. That would be equally valid. + +For `DataclassLogicalTypeFactory` specifically, the current implementation satisfies the +invariant by calling `converter.register_python_class(annotation)` for each field annotation +in both `create_for_python_type` (which already did this to build struct fields) and the +newly updated `reconstruct_from_arrow` (which discards the return value and uses only the +registration side effect). This is the natural choice because the dataclass field annotations +are available via `typing.get_type_hints` and `register_python_class` already handles +recursive registration correctly. + +--- + +## Contract changes + +| Function | Before | After | +|---|---|---| +| `register_python_class(annotation)` | Returns `pa.ExtensionType` for registered classes; may return extension types nested inside struct/list fields | Returns storage-safe `pa.DataType`: may be extension type at the top level for registered types, but struct/list fields always contain plain (non-extension) types at every depth | +| `register_storage_type(arrow_type)` | Returns `pa.DataType`; may return extension types nested inside struct/list fields | Returns storage-safe `pa.DataType`: may be extension type at the top level, but struct/list fields always contain plain (non-extension) types at every depth | +| `reconstruct_from_arrow(...)` | Does not register nested types | Must ensure all nested types are registered before returning (mechanism is factory-specific) | + +`python_type_to_arrow_type(annotation)` is **unchanged** — it still returns `pa.ExtensionType` +for registered classes, used for top-level column schema via `python_schema_to_arrow_schema`. + +--- + +## Changes + +### 1. `extension_types/protocols.py` — TypeConverterProtocol + +- `register_python_class`: update docstring — "return the storage-safe Arrow type: may be + extension type at the top level for registered types, but struct/list fields are always plain" +- `register_storage_type`: update docstring — "traverse an Arrow type bottom-up, + registering any extension types encountered; return a storage-safe ``pa.DataType`` + (may be extension type at the top level, but struct/list fields contain only plain types)" + +### 2. `semantic_types/universal_converter.py` — UniversalTypeConverter + +**`_register_python_class_impl`**: the two return sites that previously returned the extension +type now return it unchanged (no `.storage_type` strip). The storage-safe guarantee is satisfied +at the top level because `DataclassLogicalType` and other factories always build their struct +storage with plain field types: + +```python +# Registry hit — return ext type directly (already storage-safe by factory invariant) +lt = self._logical_type_registry.get_by_python_type(annotation) +if lt is not None: + return lt.get_arrow_extension_type() # unchanged from current behaviour + +# After factory dispatch — same +lt = factory.create_for_python_type(annotation, converter=self) +self._logical_type_registry.register_logical_type(lt) +return lt.get_arrow_extension_type() # unchanged from current behaviour +``` + +The container branches (`list[T]`, `set[T]`, `dict[K,V]`, `Optional[T]`) recurse through +`self.register_python_class(...)` and receive a potentially extension-typed result. +For `Optional[T]` the result is returned unchanged (nullability is a field-level concern). +For `list[T]`, `set[T]`, and `dict[K,V]`, if the element/key/value resolves to an extension +type, a `ValueError` is raised rather than silently stripping the extension type — this is +the ET2 policy (fail loudly at schema-construction time). See ET2 in `DESIGN_ISSUES.md` and +PLT-1732 for the planned `ListLogicalType` fix. + +End-to-end examples: +- `list[UUID]` → raises `ValueError` (ET2: UUID is a logical type; use a direct UUID column or wrap in a dataclass field) +- `dict[str, UUID]` → raises `ValueError` (ET2: same reason) +- `list[int]` → `pa.large_list(pa.int64())` (plain types are fine) +- `Optional[UUID]` → `orcapod.uuid` extension type (same as `UUID` directly; `Optional[T]` is a nullability wrapper that delegates to `register_python_class(T)` unchanged) +- `UUID` directly → `orcapod.uuid` extension type (top-level; storage is `pa.large_binary()`) + +`_convert_python_to_arrow` (used by `python_type_to_arrow_type`) is not touched. + +**`register_storage_type`**: updated from "traverse + rebuild (may preserve nested extension types)" to "traverse + rebuild with storage-safe guarantee (strip extension types from struct/list fields)": + +```python +def register_storage_type(self, arrow_type: "pa.DataType") -> "pa.DataType": + if isinstance(arrow_type, pa.ExtensionType): + ext_name = arrow_type.extension_name + if self._logical_type_registry is not None: + if self._logical_type_registry.get_by_arrow_extension_name(ext_name) is not None: + lt = self._logical_type_registry.get_by_arrow_extension_name(ext_name) + return lt.get_arrow_extension_type() # already registered, return ext type + self.register_storage_type(arrow_type.storage_type) # bottom-up first + raw_meta = arrow_type.__arrow_ext_serialize__() + return self.register_arrow_extension(ext_name, raw_meta or None, arrow_type.storage_type) + if pa.types.is_struct(arrow_type): + resolved_fields = [] + for i in range(arrow_type.num_fields): + field = arrow_type.field(i) + resolved = self.register_storage_type(field.type) + if isinstance(resolved, pa.ExtensionType): + resolved = resolved.storage_type # strip: ET1 forbids ext inside struct fields + resolved_fields.append(pa.field(field.name, resolved, nullable=field.nullable, metadata=field.metadata)) + return pa.struct(resolved_fields) + if pa.types.is_large_list(arrow_type) or pa.types.is_list(arrow_type): + vf = arrow_type.value_field + resolved = self.register_storage_type(vf.type) + if isinstance(resolved, pa.ExtensionType): + resolved = resolved.storage_type # strip: ET1 forbids ext inside list value type + return pa.large_list(pa.field(vf.name, resolved, nullable=vf.nullable, metadata=vf.metadata)) + return arrow_type # primitives: return unchanged +``` + +The storage-safe guarantee: a top-level extension type may be returned (the caller can use it as a column type), but any struct or list the returned type contains will never have extension type nodes in their fields/value types. + +### 3. `extension_types/dataclass_logical_type_factory.py` + +**`_strip_ext_to_storage`**: deleted entirely (private, not exported, no longer called). + +**`create_for_python_type`**: replace the recursive `_strip_ext_to_storage` call with a trivial +one-liner that strips only the top-level extension type (the storage-safe guarantee from +`register_python_class` ensures `.storage_type` is always clean — no further recursion needed): + +```python +arrow_type = converter.register_python_class(annotation) +if isinstance(arrow_type, pa.ExtensionType): + arrow_type = arrow_type.storage_type # strip top-level ext for struct field (ET1) +arrow_fields.append(pa.field(field.name, arrow_type)) +``` + +**`reconstruct_from_arrow`** (`DataclassLogicalTypeFactory` implementation): satisfies the +registration completeness invariant by calling `converter.register_python_class(annotation)` +for each field annotation — the same mechanism the write path already uses. The return value +is discarded; only the registration side effect is needed here. This is the implementation +choice for the dataclass factory; other factories may satisfy the invariant differently. + +```python +for field in dataclasses.fields(cls): + if not field.init: + continue + annotation = hints.get(field.name, Any) + converter.register_python_class(annotation) # ← NEW: registers nested types + field_annotations.append((field.name, annotation)) +``` + +Trigger chain on read path (for the dataclass factory): +``` +register_discovered_extensions + → converter.register_arrow_extension("mymod.Outer", ...) + → DataclassLogicalTypeFactory.reconstruct_from_arrow(...) + → converter.register_python_class(Inner) ← registers Inner + → DataclassLogicalTypeFactory.create_for_python_type(Inner, ...) +``` + +### 4. `extension_types/database_hooks.py` + +**No change.** `register_storage_type` still returns a meaningful `pa.DataType`, and +`database_hooks.py` already passes that resolved value into `register_arrow_extension`: + +```python +resolved_storage = converter.register_storage_type(info.storage_type) +converter.register_arrow_extension( + info.extension_name, + info.extension_metadata, + resolved_storage, +) +``` + +The only behavioral difference is that `resolved_storage` is now guaranteed to be +storage-safe (no nested extension types in struct/list fields), which is precisely what +`register_arrow_extension` needs. + +### 5. `DESIGN_ISSUES.md` + +Check whether the nested-dataclass read-path breakage is logged. If so, mark it resolved; +if not, it was an untracked bug — no new entry needed since the fix is delivered here. + +--- + +## Test changes + +### `tests/test_semantic_types/test_universal_converter.py` + +**`register_python_class` tests** (0 updates): the existing assertions check +`isinstance(result, pa.ExtensionType)` and `result.extension_name == "..."`. Under the new +storage-safe contract `register_python_class` still returns an extension type for registered +classes — these tests are already correct and need no changes. + +**`register_storage_type` tests** (1 update): only the test that currently asserts an +extension type is *preserved* inside a struct field needs to change. Under the new +storage-safe contract, that extension type must be stripped to its storage type before +being placed into the rebuilt struct. + +All other `register_storage_type` tests — including those that check the returned struct +or list shape — continue to pass with only the assertion on the inner field type updated. + +### `tests/test_extension_types/test_dataclass_logical_type_factory.py` + +- `test_register_python_class_dispatches_to_dataclass_factory`: **no change** — already + asserts `isinstance(result, pa.ExtensionType)` and `result.extension_name == "orcapod.uuid"`, + which is correct under the new storage-safe contract +- New test `test_reconstruct_from_arrow_registers_nested_types`: creates a two-level + dataclass hierarchy, calls `reconstruct_from_arrow` for the outer type only, then + asserts that the inner type is also present in the registry +- New test `test_nested_dataclass_parquet_roundtrip`: end-to-end Parquet round-trip for + a two-level dataclass (`_Inner` nested inside `_Outer`). Write path: build a converter, + register `_Outer`, write an Arrow table with an `_Outer` instance to a Parquet file. + Read path: create a **fresh converter** (only built-in types + `DataclassLogicalTypeFactory`, + neither `_Inner` nor `_Outer` pre-registered), read the Parquet file back, call + `register_discovered_extensions` on the schema — this should trigger the chain that + registers `_Outer` which in turn registers `_Inner`. Assert that converting the Arrow + struct storage back to a Python `_Outer` value produces the original object. + +--- + +## What does not change + +- `python_type_to_arrow_type` — still returns extension type +- `python_schema_to_arrow_schema` — already calls `python_type_to_arrow_type` (correct) +- `register_arrow_extension` — unchanged +- `extension_types/database_hooks.py` — unchanged (continues to use `register_storage_type` return value as before) +- All write-path value conversion (`python_to_storage`, `get_python_to_arrow_converter`) +- All read-path value conversion (`storage_to_python`, `get_arrow_to_python_converter`) +- `DataclassLogicalType` itself +- `apply_extension_types` / `database_hooks.apply_extension_types` +- All existing round-trip tests (behavior is unchanged; they continue to pass) diff --git a/superpowers/specs/2026-06-17-pydantic-logical-type-factory-design.md b/superpowers/specs/2026-06-17-pydantic-logical-type-factory-design.md new file mode 100644 index 00000000..1dab54d0 --- /dev/null +++ b/superpowers/specs/2026-06-17-pydantic-logical-type-factory-design.md @@ -0,0 +1,231 @@ +# Pydantic Logical Type Factory Design + +**Issue:** PLT-1731 +**Date:** 2026-06-17 +**Branch:** `eywalker/plt-1731-implement-pydantic-logical-type-factory-on-refined` + +--- + +## Overview + +Implement `PydanticLogicalType` and `PydanticLogicalTypeFactory` for pydantic v2 `BaseModel` +subclasses. The factory follows the same thin-leaf pattern established by +`DataclassLogicalTypeFactory` (PLT-1705): it synthesises one logical type per supported class, +delegates all field-type resolution to the converter via `register_python_class`, and stores +field annotations so that value conversion flows back through the converter at runtime. + +The two factories are fully independent — `PydanticLogicalTypeFactory` has no dependency on +`dataclass_logical_type_factory.py` and vice versa. + +--- + +## Goals & Success Criteria + +- `PydanticLogicalTypeFactory` implements `LogicalTypeFactoryProtocol` (write path + + read path). +- `PydanticLogicalType` implements `LogicalTypeProtocol`. +- For each model field, schema derivation and value conversion flow through the converter + re-entry points — no annotation traversal inside the factory. +- No coupling to `LogicalTypeRegistry` or cycle-detection internals. +- Pydantic is an optional dependency; the factory is importable and gracefully returns + `False` from `supports_class` when pydantic is not installed. +- All tests pass; a full Parquet round-trip test demonstrates end-to-end correctness. + +--- + +## Scope & Boundaries + +In scope: +- `PydanticLogicalType` and `PydanticLogicalTypeFactory` in a new + `src/orcapod/extension_types/pydantic_logical_type_factory.py`. +- Refactoring the FQCN walk loop into `type_utils._walk_fqcn` to avoid duplication. +- `pyproject.toml`: add `pydantic = ["pydantic>=2.0"]` optional extra; add to `all`. +- `extension_types/__init__.py`: export the new symbols. +- Test file `tests/test_extension_types/test_pydantic_logical_type_factory.py`. + +Out of scope: +- Wiring `PydanticLogicalTypeFactory` into the default `DataContext` / context JSON + (separate issue). +- Pydantic v1 support. +- Pydantic computed fields (`model_computed_fields`) — these are derived and not stored. +- Pydantic private attributes (`PrivateAttr`) — always have defaults; not stored. +- Pydantic model validators or field validators affecting storage values. +- Nested extension types inside list value fields (ET2 gap, tracked separately). + +--- + +## Architecture + +### New file: `pydantic_logical_type_factory.py` + +``` +src/orcapod/extension_types/pydantic_logical_type_factory.py +``` + +Contains: + +- `PYDANTIC_CATEGORY = "orcapod.pydantic"` — category tag embedded in Arrow extension + metadata; used as the factory dispatch key on the read path. +- `PydanticLogicalType` — logical type binding a pydantic `BaseModel` subclass to its + Arrow extension type representation. +- `PydanticLogicalTypeFactory` — stateless factory that synthesises and reconstructs + `PydanticLogicalType` instances. +- `_import_pydantic_model_from_fqcn(fqcn)` — private import helper; calls + `type_utils._walk_fqcn` then validates the resolved object is a `BaseModel` subclass. + +### `PydanticLogicalType` + +Constructor arguments: + +| Parameter | Type | Description | +|---|---|---| +| `logical_name` | `str` | FQCN; used as logical type name and Arrow extension name | +| `python_type` | `type` | The `BaseModel` subclass | +| `storage_type` | `pa.StructType` | Arrow struct of model fields | +| `field_annotations` | `list[tuple[str, Any]]` | Ordered `(field_name, annotation)` pairs | + +**`python_to_storage(value, converter)`** + +```python +{name: converter.python_to_storage(getattr(value, name), annotation) + for name, annotation in self._field_annotations} +``` + +**`storage_to_python(storage_value, converter)`** + +```python +kwargs = {name: converter.storage_to_python(storage_value[name], annotation) + for name, annotation in self._field_annotations} +return self._python_type(**kwargs) +``` + +Calling `python_type(**kwargs)` triggers full pydantic validation on reconstruction, +ensuring the model is always in a valid state. + +Arrow/Polars extension types are created via `make_arrow_extension_type` / +`make_polars_extension_type` with +`metadata = json.dumps({"category": PYDANTIC_CATEGORY}).encode("utf-8")`. + +Both conversion methods raise `ValueError` when `converter is None`. + +### `PydanticLogicalTypeFactory` + +**`supports_class(python_type)`** + +```python +try: + from pydantic import BaseModel +except ImportError: + return False +return isinstance(python_type, type) and issubclass(python_type, BaseModel) +``` + +Gracefully returns `False` if pydantic is not installed. The `try/except` is inside the +method rather than at module level so the factory module is importable regardless. + +**`create_for_python_type(python_type, converter)` — write path** + +1. Derive FQCN as `f"{python_type.__module__}.{python_type.__qualname__}"`. +2. Reject local classes (`""` in FQCN) with `ValueError`. +3. Call `typing.get_type_hints(python_type)` to resolve annotations (handles forward refs). +4. Iterate `python_type.model_fields` (pydantic v2 API) — this is the authoritative set of + stored fields. Computed fields and private attributes are automatically excluded. +5. For each field: `arrow_type = converter.register_python_class(annotation)`. Strip any + top-level `pa.ExtensionType` before inserting into the struct (ET1 constraint: struct + fields must never contain nested extension types). +6. Return `PydanticLogicalType(fqcn, python_type, pa.struct(arrow_fields), field_annotations)`. + +**`reconstruct_from_arrow(arrow_extension_name, storage_type, metadata, converter)` — read path** + +1. Validate `storage_type` is a struct; raise `ValueError` otherwise. +2. Import class from FQCN via `_import_pydantic_model_from_fqcn`; raises `ImportError` if + not found or not a `BaseModel` subclass. +3. Call `typing.get_type_hints(cls)` and iterate `cls.model_fields` to recover + `field_annotations`. +4. Call `converter.register_python_class(annotation)` per field — registration completeness + invariant: all nested logical types must be registered when the outer type is registered. +5. Return `PydanticLogicalType(arrow_extension_name, cls, storage_type, field_annotations)`. + +### FQCN import refactoring + +`type_utils._walk_fqcn(fqcn: str) -> Any` performs the module-prefix walk and attribute +chain traversal, returning the raw resolved object without type validation. Both +`dataclass_logical_type_factory._import_from_fqcn` and +`pydantic_logical_type_factory._import_pydantic_model_from_fqcn` call `_walk_fqcn` and +apply their own type validation on top. The ~25-line walk loop is written once. + +### Registration + +```python +from pydantic import BaseModel +from orcapod.extension_types.pydantic_logical_type_factory import ( + PydanticLogicalTypeFactory, PYDANTIC_CATEGORY +) +converter.register_logical_type_factory( + PydanticLogicalTypeFactory(), + category=PYDANTIC_CATEGORY, + python_bases=[BaseModel], +) +``` + +`python_bases=[BaseModel]` ensures MRO dispatch only probes this factory for classes that +actually inherit from `BaseModel`, rather than every class in the system. + +--- + +## Dependency changes + +**`pyproject.toml`:** + +```toml +[project.optional-dependencies] +# existing entries ... +pydantic = ["pydantic>=2.0"] +all = ["orcapod[redis]", "orcapod[ray]", "orcapod[postgresql]", "orcapod[spiraldb]", "orcapod[pydantic]"] +``` + +--- + +## Files changed + +| File | Change | +|---|---| +| `src/orcapod/extension_types/pydantic_logical_type_factory.py` | New — `PYDANTIC_CATEGORY`, `PydanticLogicalType`, `PydanticLogicalTypeFactory`, `_import_pydantic_model_from_fqcn` | +| `src/orcapod/extension_types/__init__.py` | Export `PYDANTIC_CATEGORY`, `PydanticLogicalType`, `PydanticLogicalTypeFactory` | +| `src/orcapod/extension_types/type_utils.py` | Add `_walk_fqcn` shared helper | +| `src/orcapod/extension_types/dataclass_logical_type_factory.py` | `_import_from_fqcn` delegates to `type_utils._walk_fqcn` | +| `pyproject.toml` | Add `pydantic` optional extra; add to `all` | +| `tests/test_extension_types/test_pydantic_logical_type_factory.py` | New — full test suite | + +--- + +## Test plan + +All module-level pydantic models used in tests that require FQCN reconstruction are +defined at module scope (not inside test functions), consistent with +`test_dataclass_logical_type_factory.py`. + +| Test | What it checks | +|---|---| +| `test_pydantic_logical_type_is_importable` | Module-level smoke test | +| `test_pydantic_logical_type_protocol_conformance` | `isinstance(lt, LogicalTypeProtocol)` | +| `test_pydantic_logical_type_python_to_storage` | `getattr`-based dict output | +| `test_pydantic_logical_type_storage_to_python` | `python_type(**kwargs)` reconstruction | +| `test_pydantic_logical_type_logical_type_name` | FQCN stored correctly | +| `test_pydantic_logical_type_python_type` | `.python_type` property | +| `test_factory_supports_class_pydantic_model` | Returns `True` for `BaseModel` subclass | +| `test_factory_supports_class_non_pydantic` | Returns `False` for `str`, `int`, plain dataclass | +| `test_factory_create_flat_model` | Arrow struct with correct field types | +| `test_factory_create_model_with_uuid_field` | UUID field stripped to `large_binary` in struct (ET1) | +| `test_factory_create_model_with_list_field` | `list[str]` → `pa.large_list(pa.large_string())` | +| `test_factory_create_model_with_dict_field` | `dict[str, int]` → `list[struct{key, value}]` | +| `test_factory_rejects_local_class` | `ValueError` with `"local"` in message | +| `test_factory_reconstruct_from_arrow` | Read path rebuilds correct `PydanticLogicalType` | +| `test_factory_reconstruct_from_arrow_invalid_fqcn` | `ImportError` on bad FQCN | +| `test_reconstruct_from_arrow_registers_nested_types` | Nested model registered as side effect | +| `test_pydantic_python_to_storage_round_trip` | `python_to_storage` → `storage_to_python` → equivalent model | +| `test_pydantic_with_uuid_round_trip` | UUID field survives round-trip | +| `test_python_to_storage_raises_when_converter_none` | `ValueError` guard | +| `test_storage_to_python_raises_when_converter_none` | `ValueError` guard | +| `test_nested_pydantic_model_parquet_roundtrip` | Full Parquet write → fresh-converter read | +| `test_private_fields_not_stored` | Model with `PrivateAttr` — private field absent from Arrow struct | diff --git a/superpowers/specs/2026-06-18-plt-1701-wire-factories-into-default-registry-design.md b/superpowers/specs/2026-06-18-plt-1701-wire-factories-into-default-registry-design.md new file mode 100644 index 00000000..e3a94aee --- /dev/null +++ b/superpowers/specs/2026-06-18-plt-1701-wire-factories-into-default-registry-design.md @@ -0,0 +1,182 @@ +# PLT-1701: Wire DataclassLogicalTypeFactory and PydanticLogicalTypeFactory into the Default LogicalTypeRegistry + +**Date:** 2026-06-18 +**Issue:** PLT-1701 +**Branch:** eywalker/plt-1701-wire-dataclasshandlerfactory-into-the-default + +--- + +## Overview + +`DataclassLogicalTypeFactory` and `PydanticLogicalTypeFactory` are fully implemented but must +be manually registered by users on their `LogicalTypeRegistry` instances. Until they are +wired into the default context, dataclass- and pydantic-annotated pod fields are not +auto-handled out of the box. + +Wiring them in requires one registry change: `LogicalTypeRegistry.__init__` must accept a +`factories` parameter so the JSON object-spec config can specify factory registrations +alongside the existing `logical_types` list. + +Pydantic is promoted to an explicit (non-optional) orcapod dependency. This removes all +graceful-import logic and makes missing pydantic a hard failure — both at context-load time +and inside the factory itself. + +--- + +## Goals & Success Criteria + +- `LogicalTypeRegistry.__init__` accepts an optional `factories` parameter: a list of dicts, + each with keys `factory` (instance), `category` (string), `python_bases` (list of `type`). +- `v0.1.json` wires in both factories under `logical_type_registry._config.factories`, using + `{"_type": "..."}` object-specs for `python_bases` — resolved by `parse_objectspec` at + context-load time exactly as other types are today. +- Pydantic is listed as a required dependency in `pyproject.toml`. +- `PydanticLogicalTypeFactory.supports_class` drops its `try/except ImportError` guard and + imports pydantic directly. +- The default context automatically handles dataclass- and pydantic-annotated pod fields + (write path) and reconstructs such columns from Parquet/Delta (read path) with zero + user-side setup. +- Existing tests pass. New tests explicitly verify factory registration and end-to-end use. + +--- + +## Scope & Boundaries + +**In scope:** +- `pyproject.toml` — pydantic added as explicit required dependency +- `LogicalTypeRegistry.__init__` `factories` parameter +- `PydanticLogicalTypeFactory.supports_class` — remove try/except, direct pydantic import +- `v0.1.json` update (both factories) +- Unit tests: registry `factories` param, pydantic import directness +- Integration tests: default context factory registration + converter end-to-end with + dataclass and pydantic types, Parquet round-trip + +**Out of scope:** +- String-FQCN support in `register_logical_type_factory` — not needed; `parse_objectspec` + resolves `{"_type": "pydantic.BaseModel"}` directly +- Changes to `DataclassLogicalTypeFactory` logic +- Changes to `parse_objectspec`, `contexts/core.py`, or `contexts/registry.py` +- Picklable or other factory types — those wire in separately +- `context_schema.json` — `_config` already uses `"additionalProperties": true` + +--- + +## Design + +### 1. `pyproject.toml` + +Add `pydantic>=2.0` (or the currently pinned version) to the `[project.dependencies]` list. +Remove any `[project.optional-dependencies]` entry for pydantic if one exists. + +### 2. `extension_types/pydantic_logical_type_factory.py` + +`supports_class` currently wraps its `from pydantic import BaseModel` in a `try/except +ImportError` that returns `False` when pydantic is absent. Drop the guard: + +```python +def supports_class(self, python_type: type) -> bool: + from pydantic import BaseModel + return isinstance(python_type, type) and issubclass(python_type, BaseModel) +``` + +No other changes to this file. + +### 3. `extension_types/registry.py` + +#### `__init__` — add `factories` parameter + +```python +def __init__( + self, + logical_types: list[LogicalTypeProtocol] | None = None, + factories: list[dict] | None = None, +) -> None: +``` + +After registering `logical_types`, iterate `factories` (if any). Each dict has: + +| Key | Type | Required | Description | +|---|---|---|---| +| `factory` | `LogicalTypeFactoryProtocol` | yes | Factory instance | +| `category` | `str` | no | Category key for read-path dispatch | +| `python_bases` | `list[type]` | no | Base classes for write-path dispatch | + +Call `self.register_logical_type_factory(factory, category=..., python_bases=...)` for each. + +No changes to `register_logical_type_factory` itself — it already accepts `Iterable[type]`. + +### 4. `contexts/data/v0.1.json` + +Add a `"factories"` list to `type_converter → _config → logical_type_registry → _config`: + +```json +"factories": [ + { + "factory": { + "_class": "orcapod.extension_types.dataclass_logical_type_factory.DataclassLogicalTypeFactory", + "_config": {} + }, + "category": "orcapod.dataclass", + "python_bases": [{"_type": "builtins.object"}] + }, + { + "factory": { + "_class": "orcapod.extension_types.pydantic_logical_type_factory.PydanticLogicalTypeFactory", + "_config": {} + }, + "category": "orcapod.pydantic", + "python_bases": [{"_type": "pydantic.BaseModel"}] + } +] +``` + +`parse_objectspec` resolves `{"_type": "builtins.object"}` → `object` and +`{"_type": "pydantic.BaseModel"}` → `BaseModel`. Both arrive in `__init__` as actual +`type` objects — no special handling needed in the registry. + +--- + +## Test Plan + +### `tests/test_extension_types/test_default_context_factories.py` (new file) + +**Registry constructor unit tests:** + +- `test_registry_factories_param_registers_category_factory` — construct + `LogicalTypeRegistry(factories=[{"factory": DataclassLogicalTypeFactory(), "category": "orcapod.dataclass", "python_bases": [object]}])` + and assert the category factory is accessible via `_category_factories`. +- `test_registry_factories_param_registers_python_base_factory` — same shape, verify + `_python_class_factories[object]` is set. +- `test_registry_factories_param_empty_list_is_noop` — `LogicalTypeRegistry(factories=[])` + succeeds without error. + +**Default context integration tests:** + +- `test_default_context_has_dataclass_factory_registered` — create a fresh registry via + `create_registry()` and verify `_category_factories["orcapod.dataclass"]` is an instance + of `DataclassLogicalTypeFactory`. +- `test_default_context_has_pydantic_factory_registered` — same for `"orcapod.pydantic"` / + `PydanticLogicalTypeFactory`. +- `test_default_context_dataclass_auto_registered_on_use` — call + `create_registry().get_context().type_converter.register_python_class(SomeModuleLevelDataclass)` + with no prior manual setup; verify the returned Arrow type is an extension type with the + correct extension name (the dataclass FQCN). +- `test_default_context_pydantic_model_auto_registered_on_use` — same for a pydantic + `BaseModel` subclass. +- `test_default_context_dataclass_parquet_roundtrip` — full end-to-end: write a dataclass + column via a fresh default-context converter to Parquet; read it back with another fresh + default-context converter using `register_discovered_extensions` + `apply_extension_types`; + verify the reconstructed Python object matches the original, with no manual factory + registration calls anywhere in the test. + +**Note on context freshness:** All integration tests use `create_registry().get_context()` +rather than `get_default_context()` to avoid cross-test contamination via the global +singleton cache. + +--- + +## Dependencies + +- `DataclassLogicalTypeFactory` (PLT-1705, already on `extension-type-system`) +- `PydanticLogicalTypeFactory` (PLT-1731, already on `extension-type-system`) +- `parse_objectspec` already handles `{"_type": "..."}` → no changes needed there diff --git a/superpowers/specs/2026-06-23-plt-1659-extension-type-roundtrip-integration-tests-design.md b/superpowers/specs/2026-06-23-plt-1659-extension-type-roundtrip-integration-tests-design.md new file mode 100644 index 00000000..d44a3ff9 --- /dev/null +++ b/superpowers/specs/2026-06-23-plt-1659-extension-type-roundtrip-integration-tests-design.md @@ -0,0 +1,228 @@ +# PLT-1659: End-to-End Extension Type Round-Trip Integration Tests — Design Spec + +**Date:** 2026-06-23 +**Linear issue:** PLT-1659 +**Branch:** `eywalker/plt-1659-integration-tests-end-to-end-semantic-type-round-trips` +**PR target:** `extension-type-system` + +--- + +## Overview + +This spec covers the design of end-to-end integration tests for the Arrow/Polars extension type +system introduced in the `extension-type-system` branch. The tests validate the complete pipeline: + +``` +Python object → write → storage → peek-schema → register → read → Python object +``` + +These are *integration* tests only. Existing unit tests in `tests/test_extension_types/` (registry, +schema walker, database hooks, built-in logical types, protocols) are not duplicated. + +--- + +## What Is Tested + +### Built-in types: `Path`, `UPath`, `UUID` + +Round-trip through two storage backends (Parquet and Delta — SQLite excluded, see +`test_roundtrips.py` note). Assertions: +- Python object is faithfully reconstructed after read. +- Arrow extension names are in the `orcapod.*` namespace (`orcapod.path`, `orcapod.upath`, + `orcapod.uuid`). + +### Dataclass types + +- **Simple dataclass** (scalar fields only): write → read → verify field values. +- **Two dataclasses with identical struct shape, different class names** (`_PointA` vs `_PointB`): + verify they are stored and recovered as distinct extension types (distinct Arrow extension names). +- **Nested dataclass** (outer contains inner as a field): write → read → verify recursive + reconstruction; assert both inner and outer types are registered after the read. + +### Delta Lake direct read + +Write a dataclass column to Delta Lake. Read back via `pl.read_delta` (Polars native Delta +reader). Assert the column dtype carries the correct extension type. + +### Schema compatibility + +Two sub-areas: + +- **Arrow-level identity**: `converter.python_schema_to_arrow_schema` for `_PointA` and `_PointB` + produces distinct Arrow extension names, even though the underlying struct shapes are identical. +- **Python-type-level compatibility**: `check_schema_compatibility` from `schema_utils` correctly + passes when types match and rejects when the same-shaped-but-different-named types are used. + +### Per-process cache behavior + +- **Cache populated on read**: fresh converter + Parquet file containing a registered dataclass → + after `converter.load_extension_types(...)`, the type is present in the registry. +- **Factory skipped on second read**: patching `factory.reconstruct_from_arrow` confirms it is + called exactly once on first read and zero times on second read (registry hit short-circuits + factory dispatch). + +--- + +## What Is Explicitly Out of Scope + +| Excluded | Reason | Tracked in | +|---|---|---| +| `list[MyDataclass]` round-trip | Known limitation (ET2); requires `ListLogicalType` infrastructure | PLT-1732 | +| Picklable types | `PicklableLogicalTypeFactory` (PLT-1658) not yet implemented | PLT-1658 | +| Pydantic round-trips | Already covered in `test_default_context_factories.py` | — | +| Duplicate unit tests | Existing unit tests in `test_extension_types/` are not repeated | — | + +--- + +## File Organisation + +Three new files, all in `tests/test_extension_types/`: + +``` +tests/test_extension_types/ +├── test_roundtrips.py # Write/read round-trips across backends +├── test_schema_compatibility.py # Arrow-level + Python-type-level compatibility +└── test_cache_behavior.py # Per-process cache: populated / skipped on second read +``` + +--- + +## Backend Parameterisation + +`test_roundtrips.py` parameterises over **two** storage backends via a `_StorageBackend` +dataclass with two callables. SQLite (`ConnectorArrowDatabase` + `SQLiteConnector`) is +excluded because `SQLiteConnector` discards `ARROW:extension:*` field metadata during type +mapping — see `DESIGN_ISSUES.md` CA1 and PLT-1795. + +```python +@dataclasses.dataclass +class _StorageBackend: + name: str + write: Callable[[pa.Table, Path], None] + read: Callable[[Path, UniversalTypeConverter], pa.Table] +``` + +| `name` | `write` | `read` | +|---|---|---| +| `"parquet"` | `pq.write_table(table, path / "data.parquet")` | `converter.load_extension_types(pq.read_table(path / "data.parquet"))` | +| `"delta"` | `deltalake.write_deltalake(str(path / "delta"), table)` | `converter.load_extension_types(DeltaTable(str(path / "delta")).to_pyarrow_dataset(as_large_types=True).to_table())` | + +`as_large_types=True` is required for the Delta backend: without it, Delta Lake normalises +`large_string` → `string` and `large_binary` → `binary`, which causes the extension type +deserializer to reject the storage type mismatch. + +The `read` callable always returns a `pa.Table` containing only the original user data columns. + +A `@pytest.fixture(params=[...])` named `storage_backend` yields one `_StorageBackend` per run. + +--- + +## Module-Level Test Fixtures + +All test dataclasses must be defined at module level — `DataclassLogicalTypeFactory` rejects +local classes because they have no stable FQCN for reconstruction on read. + +```python +# test_roundtrips.py, test_schema_compatibility.py, and test_cache_behavior.py +# Each file defines its own module-level dataclasses — no sharing across files. +@dataclasses.dataclass +class _PointA: + x: int + y: int + +@dataclasses.dataclass +class _PointB: # same shape as _PointA, different class name + x: int + y: int + +@dataclasses.dataclass +class _Inner: + value: int + +@dataclasses.dataclass +class _Outer: + inner: _Inner + label: str +``` + +Each test creates its own converter via `create_registry().get_context().type_converter` (not +`get_default_context()`) to prevent cross-test contamination through the global singleton cache. + +--- + +## Test Descriptions + +### `test_roundtrips.py` + +#### Parameterised over both backends + +**`test_builtin_path_round_trip[backend]`** +Write a `Path` column, read back, assert `pathlib.Path` values are reconstructed and the Arrow +field extension name is `"orcapod.path"`. + +**`test_builtin_upath_round_trip[backend]`** +Same for `UPath` / `"orcapod.upath"`. + +**`test_builtin_uuid_round_trip[backend]`** +Same for `uuid.UUID` / `"orcapod.uuid"`. + +**`test_simple_dataclass_round_trip[backend]`** +Write a `_PointA` column, read back, assert field values match and the Arrow field is an +`pa.ExtensionType` with extension name equal to the FQCN of `_PointA`. + +**`test_nested_dataclass_round_trip[backend]`** +Write an `_Outer` column. Read back. Assert: +- `_Outer` and `_Inner` are both in the registry after read. +- Reconstructed value is an `_Outer` with an `_Inner` field; all values correct. + +#### Delta Lake only + +**`test_delta_polars_read_delta`** +Write a `_PointA` column to Delta via `deltalake.write_deltalake`. Read back via +`pl.read_delta(str(delta_path))`. Assert the resulting Polars DataFrame column has dtype +that is a Polars extension type (i.e. the extension type survived the Delta round-trip). + +### `test_schema_compatibility.py` + +**`test_arrow_schema_distinct_extension_names_for_same_shape`** +Register `_PointA` and `_PointB` with a fresh converter. Assert: +```python +schema_a.field("value").type.extension_name != schema_b.field("value").type.extension_name +``` + +**`test_arrow_schema_same_extension_name_idempotent`** +Register `_PointA` twice. Assert the extension name is the same both times. + +**`test_python_schema_compatibility_passes_same_type`** +`check_schema_compatibility({"value": _PointA}, Schema({"value": _PointA}))` → `True`. + +**`test_python_schema_compatibility_rejects_different_type_same_shape`** +`check_schema_compatibility({"value": _PointA}, Schema({"value": _PointB}))` → `False`. +This is the core guarantee: the extension type system prevents same-shape-different-class +confusion that would have been silently accepted by the old shape-based system. + +### `test_cache_behavior.py` + +**`test_cache_populated_after_first_read`** +1. Write a Parquet file with a `_PointA` column (fresh converter, type registered for write). +2. Create a second fresh converter (type *not* pre-registered). +3. Call `read_converter.load_extension_types(pq.read_table(path))`. +4. Assert `read_converter._logical_type_registry.get_by_arrow_extension_name(fqcn)` is not `None`. + +**`test_factory_not_called_on_second_read`** +1. Write Parquet as above. +2. Fresh converter. Patch `DataclassLogicalTypeFactory.reconstruct_from_arrow` with a spy. +3. First `load_extension_types` call → spy called exactly once. +4. Second `load_extension_types` call on the same file → spy call count unchanged (registry hit). + +--- + +## Key Implementation Notes + +- Use `uv run pytest` (never bare `pytest`) per CLAUDE.md. +- No `POLARS_UNKNOWN_EXTENSION_TYPE_BEHAVIOR` env var needed — tests rely on registration. +- All tests use `tmp_path` (pytest built-in) for temp dirs; no external cluster required. +- SQLite backend uses `SQLiteConnector(str(tmp_path / "db.sqlite"))` — not `:memory:`, because + the `ConnectorArrowDatabase` instance is recreated between write and read to simulate + the separate-process scenario. +- Delta backend requires `deltalake` package (already a project dependency). diff --git a/superpowers/specs/2026-06-24-plt-1660-hard-cut-extension-type-hashing.md b/superpowers/specs/2026-06-24-plt-1660-hard-cut-extension-type-hashing.md new file mode 100644 index 00000000..f6f5b009 --- /dev/null +++ b/superpowers/specs/2026-06-24-plt-1660-hard-cut-extension-type-hashing.md @@ -0,0 +1,628 @@ +# PLT-1660: Hard cut — delete old semantic type system and wire in extension type system + +**Date:** 2026-06-24 +**Issue:** PLT-1660 +**Branch:** `eywalker/plt-1660-hard-cut-delete-old-semantic-type-system-and-wire-in` +**Target:** `extension-type-system` + +--- + +## Overview + +The codebase currently has two parallel "semantic type" systems: + +1. **Old system** (shape-based identity): `SemanticTypeRegistry` / `SemanticStructConverterProtocol` — identifies + extension types by matching Arrow struct field signatures. Lives in `src/orcapod/semantic_types/`. +2. **New system** (extension type identity): `LogicalTypeRegistry` / `LogicalTypeProtocol` — identifies types by + `ARROW:extension:name` metadata embedded in the Arrow field. Lives in `src/orcapod/extension_types/`. + +`UniversalTypeConverter` already uses only the new system. This issue performs a "hard cut": delete the old +system entirely and wire the new system into the remaining production call sites — primarily the Arrow hashing +visitors. + +This issue also folds in a protocol tightening: `TypeHandlerProtocol.handle()` currently has a mixed return +type (`Any`) — some handlers return `ContentHash` directly (Path, ArrowTable), while others return intermediate +values (UUID returns `bytes`, BytesHandler returns `str`, etc.). Since all handlers receive the full hasher +reference and the only purpose of a handler is to produce a hash, the protocol is tightened so every handler +returns `ContentHash` directly. This makes the naming accurate and the interface uniform. + +--- + +## Scope + +### In scope +- Rewrite `SemanticHashingVisitor` in `visitors.py` to dispatch on extension types instead of struct signatures +- Update `StarfixArrowHasher` (and delete `SemanticArrowHasher`) to accept `type_converter + semantic_hasher` + instead of `semantic_registry` +- **Protocol tightening**: change `TypeHandlerProtocol.handle() -> Any` to + `PythonTypeSemanticHasherProtocol.hash() -> ContentHash`; update all builtin handlers accordingly +- **Renames** (full list in §Design §5): + - `BaseSemanticHasher` → `SemanticAwarePythonHasher` + - `TypeHandlerRegistry` → `PythonTypeSemanticHasherRegistry` + - `BuiltinTypeHandlerRegistry` → `BuiltinPythonTypeSemanticHasherRegistry` + - `TypeHandlerProtocol` → `PythonTypeSemanticHasherProtocol` + - All builtin handler classes renamed (e.g. `PathContentHandler` → `PathSemanticHasher`) + - `register_builtin_handlers` → `register_builtin_python_type_semantic_hashers` + - `get_default_type_handler_registry` → `get_default_python_type_semantic_hasher_registry` +- Update `v0.1.json` to remove `semantic_registry` component and update all class names / cross-refs +- Update `context_schema.json` to match +- Delete `semantic_struct_converters.py`, `semantic_registry.py`, `SemanticStructConverterProtocol`, and + `tests/test_semantic_types/` +- Update all imports and references across the codebase + +### Out of scope +- PLT-1798 (making `extension_name == logical_type_name` invariant explicit in code) +- Any changes to `UniversalTypeConverter` — already fully migrated + +--- + +## Design + +### 1. Extension-type dispatch in `ArrowTypeDataVisitor` + +**File:** `src/orcapod/hashing/visitors.py` + +Add `visit_extension` as a non-abstract method on the base class. Update `visit()` to check +`isinstance(arrow_type, pa.ExtensionType)` **before** the struct check — otherwise extension types with +struct storage would be swallowed by `visit_struct`. + +```python +def visit_extension( + self, extension_type: "pa.ExtensionType", storage_value: Any +) -> tuple["pa.DataType", Any]: + """Handle an Arrow extension type. + + Default implementation: passthrough — preserves the extension type and its storage + value unchanged so that the downstream StarfixArrowHasher / ArrowDigester sees the + full extension metadata when it receives the pre-processed table. + + Subclasses may override to convert recognised extension types to a hashed + pa.large_binary() value. + """ + return extension_type, storage_value + +def visit(self, arrow_type: "pa.DataType", data: Any) -> tuple["pa.DataType", Any]: + # Extension types must be checked FIRST. A Path column has storage type + # large_string, and its field type is an ExtensionType wrapping that storage. + # Checking is_struct first would incorrectly route extension types with struct + # storage into visit_struct. + if isinstance(arrow_type, pa.ExtensionType): + new_type, new_data = self.visit_extension(arrow_type, data) + # Re-visit if visit_extension transformed to a non-extension type. + # This enables composability (e.g. a list-of-extension-type handler returning + # pa.large_list(pa.large_binary())) and avoids infinite recursion: we only + # re-enter when the type changed AND is no longer an extension type. + if new_type is not arrow_type and not isinstance(new_type, pa.ExtensionType): + return self.visit(new_type, new_data) + return new_type, new_data + if pa.types.is_struct(arrow_type): + return self.visit_struct(arrow_type, data) + elif pa.types.is_list(arrow_type) or pa.types.is_large_list(arrow_type): + return self.visit_list(arrow_type, data) + elif pa.types.is_fixed_size_list(arrow_type): + return self.visit_list(arrow_type, data) + elif pa.types.is_map(arrow_type): + return self.visit_map(arrow_type, data) + else: + return self.visit_primitive(arrow_type, data) +``` + +### 2. `SemanticHashingVisitor` rewrite + +**File:** `src/orcapod/hashing/visitors.py` + +Constructor changes from `(semantic_registry: SemanticTypeRegistry)` to +`(type_converter: UniversalTypeConverter, python_hasher: SemanticAwarePythonHasher)`. + +Core logic moves from `visit_struct` into `visit_extension`: + +```python +class SemanticHashingVisitor(ArrowTypeDataVisitor): + """Visitor that replaces extension-typed columns with their content hashes. + + For each Arrow column whose type is a ``pa.ExtensionType``: + 1. Look up the corresponding Python type via ``type_converter``. + 2. If the Python type has a semantic hasher registered in ``python_hasher``, + convert the storage value to a Python object and hash it, replacing the + column with a ``pa.large_binary()`` value of the form:: + + type_name_bytes + b"::" + content_hash.to_prefixed_digest() + + where ``type_name`` is the extension name with dots replaced by colons + (e.g. ``"orcapod.path"`` → ``"orcapod:path"``), and + ``content_hash.to_prefixed_digest()`` = ``method_bytes + b":" + digest``. + The ``::`` separator is unambiguous because ``to_prefixed_digest()`` only + uses single ``:``. Splitting on ``b"::"`` recovers both parts cleanly. + 3. If no hasher is registered (or if ``type_converter`` does not know the + extension type), return the extension type and storage value unchanged. + The downstream ``StarfixArrowHasher`` / ``ArrowDigester`` will see the + full extension metadata intact and hash it in a type-aware way. + """ + + def __init__( + self, + type_converter: "UniversalTypeConverter", + python_hasher: "SemanticAwarePythonHasher", + ) -> None: + self._type_converter = type_converter + self._python_hasher = python_hasher + self._current_field_path: list[str] = [] + + def visit_extension( + self, extension_type: "pa.ExtensionType", storage_value: Any + ) -> tuple["pa.DataType", Any]: + if storage_value is None: + return extension_type, None + + # Resolve extension type → Python type. + python_type = self._type_converter.arrow_type_to_python_type(extension_type) + + # If the converter couldn't resolve to a concrete class, passthrough. + if python_type is Any or not isinstance(python_type, type): + return extension_type, storage_value + + # Only hash if the python hasher has a semantic hasher for this type. + if not self._python_hasher.type_semantic_hasher_registry.has_semantic_hasher(python_type): + return extension_type, storage_value + + # Convert storage value → Python object and hash it. + python_obj = self._type_converter.storage_to_python(storage_value, python_type) + content_hash = self._python_hasher.hash_object(python_obj) + + # Encode as binary: ":::" + # Dots in the extension name are replaced with colons so the type prefix + # uses a consistent namespace separator (e.g. "orcapod:path"). + # The "::" separator is unambiguous — to_prefixed_digest() only uses ":". + type_name = extension_type.extension_name.replace(".", ":") + hash_bytes = ( + type_name.encode("ascii") + + b"::" + + content_hash.to_prefixed_digest() + ) + return pa.large_binary(), hash_bytes + + def visit_struct(self, struct_type, data): + """Regular struct (no extension identity) — recurse into fields.""" + if data is None: + return struct_type, None + return self._visit_struct_fields(struct_type, data) + + def visit_list(self, list_type, data): + if data is None: + return list_type, None + self._current_field_path.append("[*]") + try: + return self._visit_list_elements(list_type, data) + finally: + self._current_field_path.pop() + + def visit_map(self, map_type, data): + return map_type, data + + def visit_primitive(self, primitive_type, data): + return primitive_type, data +``` + +**Passthrough invariant:** when `visit_extension` returns the original `(extension_type, storage_value)`, +the column's field type remains a `pa.ExtensionType`. `schema_cleaner.clean_schema_for_hashing` retains +all `ARROW:extension:*` metadata, so `ArrowDigester.hash_table(..., include_metadata=True)` sees the full +extension identity. Extension types without a registered Python semantic hasher are still hashed in a +type-aware way by the underlying starfix algorithm. + +### 3. `StarfixArrowHasher` constructor update + +**File:** `src/orcapod/hashing/arrow_hashers.py` + +```python +# Before +def __init__(self, semantic_registry: SemanticTypeRegistry, hasher_id: str) -> None: + +# After +def __init__( + self, + type_converter: "UniversalTypeConverter", + semantic_hasher: "SemanticAwarePythonHasher", + hasher_id: str, +) -> None: + self._type_converter = type_converter + self._semantic_hasher = semantic_hasher + self._hasher_id = hasher_id +``` + +`_process_table_columns` constructs `SemanticHashingVisitor(self._type_converter, self._semantic_hasher)` +instead of `SemanticHashingVisitor(self.semantic_registry)`. + +The short-circuit in `_process_table_columns` that skips non-struct/non-list columns must also allow +extension type columns through — otherwise Path columns (storage: `large_string`) would be silently skipped +before the visitor sees them: + +```python +if not ( + isinstance(field.type, pa.ExtensionType) # ← add this + or pa.types.is_struct(field.type) + or pa.types.is_list(field.type) + or pa.types.is_large_list(field.type) + or pa.types.is_fixed_size_list(field.type) + or pa.types.is_map(field.type) +): + new_columns.append(table.column(i)) + new_fields.append(field) + continue +``` + +### 4. `SemanticArrowHasher` (legacy hasher) + +**File:** `src/orcapod/hashing/arrow_hashers.py` + +`SemanticArrowHasher` predates `StarfixArrowHasher` and is not referenced in `v0.1.json`. **Delete** it as +part of the hard cut. If any test depends on it directly, delete the test — these tests are superseded by the +extension type integration tests. + +### 5. Renames + +#### Classes and protocols + +| Old name | New name | File | +|----------|----------|------| +| `BaseSemanticHasher` | `SemanticAwarePythonHasher` | `semantic_hashing/semantic_hasher.py` | +| `TypeHandlerRegistry` | `PythonTypeSemanticHasherRegistry` | `semantic_hashing/type_handler_registry.py` | +| `BuiltinTypeHandlerRegistry` | `BuiltinPythonTypeSemanticHasherRegistry` | `semantic_hashing/type_handler_registry.py` | +| `TypeHandlerProtocol` | `PythonTypeSemanticHasherProtocol` | `protocols/hashing_protocols.py` | + +#### Builtin handler classes (in `semantic_hashing/builtin_handlers.py`) + +| Old name | New name | +|----------|----------| +| `PathContentHandler` | `PathSemanticHasher` | +| `UPathContentHandler` | `UPathSemanticHasher` | +| `UUIDHandler` | `UUIDSemanticHasher` | +| `BytesHandler` | `BytesSemanticHasher` | +| `FunctionHandler` | `FunctionSemanticHasher` | +| `TypeObjectHandler` | `TypeObjectSemanticHasher` | +| `SpecialFormHandler` | `SpecialFormSemanticHasher` | +| `GenericAliasHandler` | `GenericAliasSemanticHasher` | +| `UnionTypeHandler` | `UnionTypeSemanticHasher` | +| `ArrowTableHandler` | `ArrowTableSemanticHasher` | +| `SchemaHandler` | `SchemaSemanticHasher` | + +#### Functions and properties + +| Old name | New name | Location | +|----------|----------|----------| +| `register_builtin_handlers(registry)` | `register_builtin_python_type_semantic_hashers(registry)` | `builtin_handlers.py` | +| `get_default_type_handler_registry()` | `get_default_python_type_semantic_hasher_registry()` | `type_handler_registry.py` and `defaults.py` | +| `BaseSemanticHasher.type_handler_registry` property | `SemanticAwarePythonHasher.type_semantic_hasher_registry` | `semantic_hasher.py` | + +#### Registry methods + +| Old name | New name | +|----------|----------| +| `get_handler(obj)` | `get_semantic_hasher(obj)` | +| `get_handler_for_type(target_type)` | `get_semantic_hasher_for_type(target_type)` | +| `has_handler(target_type)` | `has_semantic_hasher(target_type)` | + +The `register(target_type, handler)` method name is unchanged — "register" is generic enough. + +All references across the codebase (imports, JSON specs, tests, docs) must be updated in the same PR. +Per the project's no-backward-compatibility policy: no re-export aliases or deprecation wrappers. + +### 6. Protocol tightening — `PythonTypeSemanticHasherProtocol` + +**File:** `src/orcapod/protocols/hashing_protocols.py` + +The `handle(obj, hasher) -> Any` method is replaced by `hash(obj, hasher) -> ContentHash`: + +```python +class PythonTypeSemanticHasherProtocol(Protocol): + """Protocol for type-specific semantic hashers used by SemanticAwarePythonHasher. + + A PythonTypeSemanticHasherProtocol hashes a specific Python type to a ContentHash. + Implementations are registered with a PythonTypeSemanticHasherRegistry and looked + up via MRO-aware resolution. + + Each implementation receives the full SemanticAwarePythonHasher so it can delegate + hashing of sub-values (e.g. hashing a dict of function metadata) back to the outer + hasher without coupling to a specific hasher instance. + """ + + def hash(self, obj: Any, hasher: "SemanticAwarePythonHasher") -> ContentHash: + """Hash *obj* to a ContentHash. + + Args: + obj: The object to hash. Always matches the registered type. + hasher: The active SemanticAwarePythonHasher. Use + ``hasher.hash_object(sub_value)`` to hash sub-values. + + Returns: + ContentHash: The content-addressed hash of *obj*. + """ + ... +``` + +#### `hash_object()` simplification + +Because every semantic hasher now returns `ContentHash` directly, the dispatch in `hash_object()` simplifies +from a double call to a single call: + +```python +# Before +semantic_hasher = self._registry.get_semantic_hasher(obj) +if semantic_hasher is not None: + return self.hash_object(semantic_hasher.handle(obj, self), resolver=resolver) + # ^^^ recursive wrap ^^^ + +# After +semantic_hasher = self._registry.get_semantic_hasher(obj) +if semantic_hasher is not None: + return semantic_hasher.hash(obj, self) # always ContentHash — no wrap +``` + +#### Updated builtin implementations + +Each builtin class returns `ContentHash` directly by delegating sub-values back to `hasher.hash_object()`: + +```python +class PathSemanticHasher: + def __init__(self, file_hasher: FileContentHasherProtocol) -> None: + self.file_hasher = file_hasher + + def hash(self, obj: PathLike, hasher: SemanticAwarePythonHasher) -> ContentHash: + path = Path(obj) + # (existence / is_dir checks unchanged) + return self.file_hasher.hash_file(path) # already returns ContentHash + + +class UUIDSemanticHasher: + def hash(self, obj: Any, hasher: SemanticAwarePythonHasher) -> ContentHash: + return hasher.hash_object(obj.bytes) # bytes → ContentHash via hasher + + +class BytesSemanticHasher: + def hash(self, obj: Any, hasher: SemanticAwarePythonHasher) -> ContentHash: + if isinstance(obj, (bytes, bytearray)): + return hasher.hash_object(obj.hex()) # hex str → ContentHash via hasher + raise TypeError(...) + + +class FunctionSemanticHasher: + def __init__(self, function_info_extractor: Any) -> None: + self.function_info_extractor = function_info_extractor + + def hash(self, obj: Any, hasher: SemanticAwarePythonHasher) -> ContentHash: + info = self.function_info_extractor.extract_function_info(obj) + return hasher.hash_object(info) # dict → ContentHash via hasher + + +class TypeObjectSemanticHasher: + def hash(self, obj: Any, hasher: SemanticAwarePythonHasher) -> ContentHash: + module = obj.__module__ or "" + return hasher.hash_object(f"type:{module}.{obj.__qualname__}") + + +class ArrowTableSemanticHasher: + def __init__(self, arrow_hasher: ArrowHasherProtocol) -> None: + self.arrow_hasher = arrow_hasher + + def hash(self, obj: Any, hasher: SemanticAwarePythonHasher) -> ContentHash: + if isinstance(obj, pa.RecordBatch): + obj = pa.Table.from_batches([obj]) + return self.arrow_hasher.hash_table(obj) # already returns ContentHash + + +class SpecialFormSemanticHasher: + def hash(self, obj: Any, hasher: SemanticAwarePythonHasher) -> ContentHash: + name = getattr(obj, "_name", None) or repr(obj) + return hasher.hash_object(f"special_form:typing.{name}") + + +class GenericAliasSemanticHasher: + def hash(self, obj: Any, hasher: SemanticAwarePythonHasher) -> ContentHash: + import typing + origin = getattr(obj, "__origin__", None) + args = getattr(obj, "__args__", None) or () + if origin is None: + return hasher.hash_object(f"generic_alias:{obj!r}") + if origin is typing.Union: + hashed_args = sorted(hasher.hash_object(arg).to_string() for arg in args) + return hasher.hash_object({"__type__": "union", "args": hashed_args}) + return hasher.hash_object({ + "__type__": "generic_alias", + "origin": hasher.hash_object(origin).to_string(), + "args": [hasher.hash_object(arg).to_string() for arg in args], + }) + + +class UnionTypeSemanticHasher: + def hash(self, obj: Any, hasher: SemanticAwarePythonHasher) -> ContentHash: + args = getattr(obj, "__args__", None) or () + hashed_args = sorted(hasher.hash_object(arg).to_string() for arg in args) + return hasher.hash_object({"__type__": "union", "args": hashed_args}) +``` + +### 7. `v0.1.json` changes + +**File:** `src/orcapod/contexts/data/v0.1.json` + +- Remove the `semantic_registry` top-level component entirely. +- In `arrow_hasher._config`, replace: + ```json + "semantic_registry": {"_ref": "semantic_registry"} + ``` + with: + ```json + "type_converter": {"_ref": "type_converter"}, + "semantic_hasher": {"_ref": "semantic_hasher"} + ``` +- Rename component key `type_handler_registry` → `python_type_semantic_hasher_registry`. +- Update `semantic_hasher._config` ref: + ```json + "type_handler_registry": {"_ref": "python_type_semantic_hasher_registry"} + ``` +- Update `semantic_hasher._class`: + `orcapod.hashing.semantic_hashing.semantic_hasher.BaseSemanticHasher` + → `orcapod.hashing.semantic_hashing.semantic_hasher.SemanticAwarePythonHasher` +- Update `python_type_semantic_hasher_registry._class`: + `orcapod.hashing.semantic_hashing.type_handler_registry.TypeHandlerRegistry` + → `orcapod.hashing.semantic_hashing.type_handler_registry.PythonTypeSemanticHasherRegistry` +- Update all handler `_class` entries in `python_type_semantic_hasher_registry._config.handlers` + to use the new class names (e.g. `PathContentHandler` → `PathSemanticHasher`, etc.) + +Full updated component list in file order: +``` +file_hasher (unchanged) +semantic_registry ← DELETE +arrow_hasher (class unchanged; _config: + type_converter ref, + semantic_hasher ref, - semantic_registry ref) +type_converter (unchanged) +function_info_extractor (unchanged) +python_type_semantic_hasher_registry ← renamed from type_handler_registry; class + handler entries updated +semantic_hasher (class → SemanticAwarePythonHasher; ref updated) +``` + +### 8. `context_schema.json` changes + +**File:** `src/orcapod/contexts/data/schemas/context_schema.json` + +- Remove the `semantic_registry` property from `properties`. +- Rename `type_handler_registry` property to `python_type_semantic_hasher_registry`. + +### 9. `DataContext` core + +**File:** `src/orcapod/contexts/core.py` + +`DataContext` is a dataclass with `type_converter`, `arrow_hasher`, and `semantic_hasher` fields. +`type_handler_registry` is not a field on `DataContext` — it is an implementation detail of `semantic_hasher`. +No changes needed to `core.py`. + +### 10. `versioned_hashers.py` + +**File:** `src/orcapod/hashing/versioned_hashers.py` + +Update `get_versioned_semantic_arrow_hasher()`: +- Remove inline `SemanticTypeRegistry` / `PythonPathStructConverter` / `UUIDStructConverter` construction. +- Source `type_converter` and `semantic_hasher` from the default `DataContext`: + +```python +def get_versioned_semantic_arrow_hasher( + hasher_id: str = _CURRENT_ARROW_HASHER_ID, +) -> hp.ArrowHasherProtocol: + from orcapod.hashing.arrow_hashers import StarfixArrowHasher + from orcapod.contexts import resolve_context + + ctx = resolve_context(None) # default context + return StarfixArrowHasher( + hasher_id=hasher_id, + type_converter=ctx.type_converter, + semantic_hasher=ctx.semantic_hasher, + ) +``` + +Update `get_versioned_semantic_hasher()` to import `SemanticAwarePythonHasher` instead of `BaseSemanticHasher`. + +--- + +## Files to delete + +| File | Reason | +|------|--------| +| `src/orcapod/semantic_types/semantic_struct_converters.py` | Old shape-based converters | +| `src/orcapod/semantic_types/semantic_registry.py` | Old `SemanticTypeRegistry` | +| `SemanticStructConverterProtocol` class in `src/orcapod/protocols/semantic_types_protocols.py` | Protocol for old system | +| `tests/test_semantic_types/` (all 9 files) | Tests for old system | + +After deletion, verify `src/orcapod/semantic_types/__init__.py` no longer re-exports deleted names. + +--- + +## Files to update (beyond the core changes) + +These files import from the deleted or renamed modules and must be updated: + +- `src/orcapod/hashing/__init__.py` — re-exports `BaseSemanticHasher`, `TypeHandlerRegistry`, `TypeHandlerProtocol` +- `src/orcapod/hashing/semantic_hashing/__init__.py` — re-exports all renamed classes +- `src/orcapod/hashing/defaults.py` — `get_default_type_handler_registry` → `get_default_python_type_semantic_hasher_registry` +- `src/orcapod/hashing/semantic_hashing/content_identifiable_mixin.py` — references `BaseSemanticHasher` +- `src/orcapod/hashing/versioned_hashers.py` — inline registry construction, old class names +- `src/orcapod/protocols/hashing_protocols.py` — `TypeHandlerProtocol` docstring references +- `src/orcapod/contexts/core.py` — `TYPE_CHECKING` import of `BaseSemanticHasher` (if any) +- `tests/test_hashing/` — update imports and any direct registry/handler references + +Run this sweep after implementation to catch any remaining references: + +```bash +grep -rn "SemanticTypeRegistry\|semantic_registry\|SemanticStructConverter\ +\|BaseSemanticHasher\|TypeHandlerRegistry\|BuiltinTypeHandlerRegistry\ +\|TypeHandlerProtocol\|PathContentHandler\|UPathContentHandler\ +\|UUIDHandler\|BytesHandler\|FunctionHandler\|TypeObjectHandler\ +\|SpecialFormHandler\|GenericAliasHandler\|UnionTypeHandler\|ArrowTableHandler\ +\|SchemaHandler\|register_builtin_handlers\|get_default_type_handler_registry\ +\|type_handler_registry\|get_handler\|has_handler" src/ tests/ +``` + +--- + +## Binary encoding format + +Hash values produced by `visit_extension` are stored as `pa.large_binary()` with the layout: + +``` + "::" +``` + +where: +- `type_name` = `extension_type.extension_name.replace(".", ":")` — dots in the Arrow extension + name are replaced with colons so the prefix uses a uniform namespace separator + (e.g. `"orcapod.path"` → `"orcapod:path"`, `"my.module.MyClass"` → `"my:module:MyClass"`) +- `"::"` is the separator between type prefix and hash — unambiguous because + `to_prefixed_digest()` only uses single `":"` +- `content_hash.to_prefixed_digest()` = `method.encode("ascii") + b":" + digest_bytes` + +Full example for a `pathlib.Path` column whose file is hashed by the semantic hasher: +``` +b"orcapod:path::semantic_v0.1:\xab\xcd\xef..." + ^^^^^^^^^^^ ^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^ + type prefix hasher_id raw SHA-256 digest + (dots→colons) +``` + +Parsing: `value.split(b"::", 1)` → `(b"orcapod:path", b"semantic_v0.1:\xab...")`. + +This is consistent with the existing pattern in `function_node.py`: +```python +self.data_context.arrow_hasher.hash_table(tag_with_hash).to_prefixed_digest() +``` + +--- + +## Test strategy + +1. Existing tests in `tests/test_hashing/` must all pass after renames, protocol changes, and wiring. +2. `tests/test_extension_types/` round-trip tests verify the full conversion chain; these must pass. +3. The deleted `tests/test_semantic_types/` tests are superseded by the extension type integration tests. +4. Run: `uv run pytest tests/test_hashing/ tests/test_extension_types/ tests/test_core/ -x` + +--- + +## Implementation order + +1. **Rename `TypeHandlerProtocol` → `PythonTypeSemanticHasherProtocol`**, change `handle() -> Any` to + `hash() -> ContentHash` in `protocols/hashing_protocols.py`. Update docstring. +2. **Rename `TypeHandlerRegistry` → `PythonTypeSemanticHasherRegistry`**, rename all registry methods + (`get_handler` → `get_semantic_hasher`, `has_handler` → `has_semantic_hasher`, etc.), + rename `BuiltinTypeHandlerRegistry` → `BuiltinPythonTypeSemanticHasherRegistry`. +3. **Update all builtin handler classes** in `builtin_handlers.py`: rename each class, change `handle()` → + `hash()`, update return type from `Any` → `ContentHash`, update implementations to return `ContentHash` + directly. Rename `register_builtin_handlers` → `register_builtin_python_type_semantic_hashers`. +4. **Rename `BaseSemanticHasher` → `SemanticAwarePythonHasher`** in `semantic_hasher.py`: simplify + `hash_object()` dispatch (remove double-wrap), rename `type_handler_registry` property → + `type_semantic_hasher_registry`, rename `get_default_type_handler_registry` → + `get_default_python_type_semantic_hasher_registry`. +5. **Update `__init__.py` exports** in `hashing/` and `hashing/semantic_hashing/` to use new names. +6. **Add `visit_extension` to `ArrowTypeDataVisitor`**; update `visit()` dispatch. +7. **Rewrite `SemanticHashingVisitor`** constructor and `visit_extension` implementation. +8. **Update `StarfixArrowHasher`**: new constructor signature, `_process_table_columns` short-circuit fix, + delete `SemanticArrowHasher`. +9. **Update `v0.1.json`** and **`context_schema.json`**. +10. **Update `versioned_hashers.py`** to source from `DataContext`. +11. **Delete** old semantic type files and their tests. +12. **Run grep sweep** for stale references; fix any found. +13. **Run full test suite**: `uv run pytest tests/test_hashing/ tests/test_extension_types/ tests/test_core/ -x` diff --git a/test-objective/unit/test_hashing.py b/test-objective/unit/test_hashing.py index c2083c21..a6928ef3 100644 --- a/test-objective/unit/test_hashing.py +++ b/test-objective/unit/test_hashing.py @@ -1,4 +1,4 @@ -"""Tests for BaseSemanticHasher and TypeHandlerRegistry. +"""Tests for SemanticAwarePythonHasher and PythonTypeHandlerRegistry. Specification-derived tests covering deterministic hashing of primitives, structures, ContentHash pass-through, identity_structure resolution, @@ -13,10 +13,10 @@ import pytest -from orcapod.hashing.semantic_hashing.semantic_hasher import BaseSemanticHasher +from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher from orcapod.hashing.semantic_hashing.type_handler_registry import ( - BuiltinTypeHandlerRegistry, - TypeHandlerRegistry, + BuiltinPythonTypeHandlerRegistry, + PythonTypeHandlerRegistry, ) from orcapod.types import ContentHash @@ -27,15 +27,15 @@ @pytest.fixture -def registry() -> TypeHandlerRegistry: - """An empty TypeHandlerRegistry.""" - return TypeHandlerRegistry() +def registry() -> PythonTypeHandlerRegistry: + """An empty PythonTypeHandlerRegistry.""" + return PythonTypeHandlerRegistry() @pytest.fixture -def hasher(registry: TypeHandlerRegistry) -> BaseSemanticHasher: - """A strict BaseSemanticHasher backed by an empty registry.""" - return BaseSemanticHasher( +def hasher(registry: PythonTypeHandlerRegistry) -> SemanticAwarePythonHasher: + """A strict SemanticAwarePythonHasher backed by an empty registry.""" + return SemanticAwarePythonHasher( hasher_id="test_v1", type_handler_registry=registry, strict=True, @@ -43,9 +43,9 @@ def hasher(registry: TypeHandlerRegistry) -> BaseSemanticHasher: @pytest.fixture -def lenient_hasher(registry: TypeHandlerRegistry) -> BaseSemanticHasher: - """A non-strict BaseSemanticHasher backed by an empty registry.""" - return BaseSemanticHasher( +def lenient_hasher(registry: PythonTypeHandlerRegistry) -> SemanticAwarePythonHasher: + """A non-strict SemanticAwarePythonHasher backed by an empty registry.""" + return SemanticAwarePythonHasher( hasher_id="test_v1", type_handler_registry=registry, strict=False, @@ -58,13 +58,13 @@ def lenient_hasher(registry: TypeHandlerRegistry) -> BaseSemanticHasher: class _FakeHandler: - """Minimal object satisfying TypeHandlerProtocol for testing.""" + """Minimal object satisfying PythonTypeHandlerProtocol for testing.""" def __init__(self, return_value: Any = "handled") -> None: self._return_value = return_value - def handle(self, obj: Any, hasher: BaseSemanticHasher) -> Any: - return self._return_value + def handle(self, obj: Any, hasher: SemanticAwarePythonHasher) -> Any: + return str(self._return_value) class _IdentityObj: @@ -79,18 +79,18 @@ def identity_structure(self) -> Any: def content_hash(self, hasher: Any = None) -> ContentHash: if hasher is not None: return hasher.hash_object(self.identity_structure()) - h = BaseSemanticHasher( - "test_v1", type_handler_registry=TypeHandlerRegistry(), strict=False + h = SemanticAwarePythonHasher( + "test_v1", type_handler_registry=PythonTypeHandlerRegistry(), strict=False ) return h.hash_object(self.identity_structure()) # =================================================================== -# BaseSemanticHasher -- primitive hashing +# SemanticAwarePythonHasher -- primitive hashing # =================================================================== -class TestBaseSemanticHasherPrimitives: +class TestSemanticAwarePythonHasherPrimitives: """Primitives (int, str, float, bool, None) are hashed deterministically.""" @pytest.mark.parametrize( @@ -99,21 +99,21 @@ class TestBaseSemanticHasherPrimitives: ids=lambda v: f"{type(v).__name__}({v!r})", ) def test_primitive_produces_content_hash( - self, hasher: BaseSemanticHasher, value: Any + self, hasher: SemanticAwarePythonHasher, value: Any ) -> None: result = hasher.hash_object(value) assert isinstance(result, ContentHash) @pytest.mark.parametrize("value", [42, "hello", 3.14, True, None]) def test_primitive_deterministic( - self, hasher: BaseSemanticHasher, value: Any + self, hasher: SemanticAwarePythonHasher, value: Any ) -> None: """Same input always produces the same hash.""" h1 = hasher.hash_object(value) h2 = hasher.hash_object(value) assert h1 == h2 - def test_different_primitives_differ(self, hasher: BaseSemanticHasher) -> None: + def test_different_primitives_differ(self, hasher: SemanticAwarePythonHasher) -> None: """Different inputs produce different hashes (collision resistance).""" h_int = hasher.hash_object(42) h_str = hasher.hash_object("42") @@ -121,48 +121,48 @@ def test_different_primitives_differ(self, hasher: BaseSemanticHasher) -> None: # =================================================================== -# BaseSemanticHasher -- structures +# SemanticAwarePythonHasher -- structures # =================================================================== -class TestBaseSemanticHasherStructures: +class TestSemanticAwarePythonHasherStructures: """Structures (list, dict, tuple, set) are expanded and hashed.""" - def test_list_hashed(self, hasher: BaseSemanticHasher) -> None: + def test_list_hashed(self, hasher: SemanticAwarePythonHasher) -> None: result = hasher.hash_object([1, 2, 3]) assert isinstance(result, ContentHash) - def test_dict_hashed(self, hasher: BaseSemanticHasher) -> None: + def test_dict_hashed(self, hasher: SemanticAwarePythonHasher) -> None: result = hasher.hash_object({"a": 1, "b": 2}) assert isinstance(result, ContentHash) - def test_tuple_hashed(self, hasher: BaseSemanticHasher) -> None: + def test_tuple_hashed(self, hasher: SemanticAwarePythonHasher) -> None: result = hasher.hash_object((1, 2, 3)) assert isinstance(result, ContentHash) - def test_set_hashed(self, hasher: BaseSemanticHasher) -> None: + def test_set_hashed(self, hasher: SemanticAwarePythonHasher) -> None: result = hasher.hash_object({1, 2, 3}) assert isinstance(result, ContentHash) - def test_list_and_tuple_differ(self, hasher: BaseSemanticHasher) -> None: + def test_list_and_tuple_differ(self, hasher: SemanticAwarePythonHasher) -> None: """list and tuple with same elements produce different hashes.""" h_list = hasher.hash_object([1, 2, 3]) h_tuple = hasher.hash_object((1, 2, 3)) assert h_list != h_tuple - def test_set_order_independent(self, hasher: BaseSemanticHasher) -> None: + def test_set_order_independent(self, hasher: SemanticAwarePythonHasher) -> None: """Sets with the same elements hash identically regardless of insertion order.""" h1 = hasher.hash_object({3, 1, 2}) h2 = hasher.hash_object({1, 2, 3}) assert h1 == h2 - def test_dict_key_order_independent(self, hasher: BaseSemanticHasher) -> None: + def test_dict_key_order_independent(self, hasher: SemanticAwarePythonHasher) -> None: """Dicts with the same key-value pairs hash identically regardless of order.""" h1 = hasher.hash_object({"b": 2, "a": 1}) h2 = hasher.hash_object({"a": 1, "b": 2}) assert h1 == h2 - def test_nested_structures(self, hasher: BaseSemanticHasher) -> None: + def test_nested_structures(self, hasher: SemanticAwarePythonHasher) -> None: """Nested structures are hashed correctly.""" nested = {"key": [1, (2, 3)], "other": {"inner": True}} result = hasher.hash_object(nested) @@ -170,48 +170,48 @@ def test_nested_structures(self, hasher: BaseSemanticHasher) -> None: # Determinism assert result == hasher.hash_object(nested) - def test_different_structures_differ(self, hasher: BaseSemanticHasher) -> None: + def test_different_structures_differ(self, hasher: SemanticAwarePythonHasher) -> None: h1 = hasher.hash_object([1, 2]) h2 = hasher.hash_object([1, 2, 3]) assert h1 != h2 # =================================================================== -# BaseSemanticHasher -- ContentHash passthrough +# SemanticAwarePythonHasher -- ContentHash passthrough # =================================================================== -class TestBaseSemanticHasherContentHash: +class TestSemanticAwarePythonHasherContentHash: """ContentHash inputs are returned as-is (terminal).""" - def test_content_hash_passthrough(self, hasher: BaseSemanticHasher) -> None: + def test_content_hash_passthrough(self, hasher: SemanticAwarePythonHasher) -> None: ch = ContentHash(method="sha256", digest=b"\x00" * 32) result = hasher.hash_object(ch) assert result is ch # =================================================================== -# BaseSemanticHasher -- identity_structure resolution +# SemanticAwarePythonHasher -- identity_structure resolution # =================================================================== -class TestBaseSemanticHasherIdentityStructure: +class TestSemanticAwarePythonHasherIdentityStructure: """Objects implementing identity_structure() are resolved via it.""" - def test_identity_structure_object(self, hasher: BaseSemanticHasher) -> None: + def test_identity_structure_object(self, hasher: SemanticAwarePythonHasher) -> None: obj = _IdentityObj(structure={"name": "test", "version": 1}) result = hasher.hash_object(obj) assert isinstance(result, ContentHash) def test_identity_structure_deterministic( - self, hasher: BaseSemanticHasher + self, hasher: SemanticAwarePythonHasher ) -> None: obj1 = _IdentityObj(structure=[1, 2, 3]) obj2 = _IdentityObj(structure=[1, 2, 3]) assert hasher.hash_object(obj1) == hasher.hash_object(obj2) def test_different_identity_structures_differ( - self, hasher: BaseSemanticHasher + self, hasher: SemanticAwarePythonHasher ) -> None: obj1 = _IdentityObj(structure="alpha") obj2 = _IdentityObj(structure="beta") @@ -219,22 +219,22 @@ def test_different_identity_structures_differ( # =================================================================== -# BaseSemanticHasher -- strict mode +# SemanticAwarePythonHasher -- strict mode # =================================================================== -class TestBaseSemanticHasherStrictMode: +class TestSemanticAwarePythonHasherStrictMode: """Unknown type in strict mode raises TypeError.""" - def test_unknown_type_strict_raises(self, hasher: BaseSemanticHasher) -> None: + def test_unknown_type_strict_raises(self, hasher: SemanticAwarePythonHasher) -> None: class Unknown: pass - with pytest.raises(TypeError, match="no TypeHandlerProtocol registered"): + with pytest.raises(TypeError, match="no implementation of PythonTypeHandlerProtocol registered"): hasher.hash_object(Unknown()) def test_unknown_type_lenient_succeeds( - self, lenient_hasher: BaseSemanticHasher + self, lenient_hasher: SemanticAwarePythonHasher ) -> None: class Unknown: pass @@ -244,26 +244,26 @@ class Unknown: # =================================================================== -# BaseSemanticHasher -- collision resistance +# SemanticAwarePythonHasher -- collision resistance # =================================================================== -class TestBaseSemanticHasherCollisionResistance: +class TestSemanticAwarePythonHasherCollisionResistance: """Different inputs produce different hashes.""" - def test_int_vs_string(self, hasher: BaseSemanticHasher) -> None: + def test_int_vs_string(self, hasher: SemanticAwarePythonHasher) -> None: assert hasher.hash_object(1) != hasher.hash_object("1") - def test_empty_list_vs_empty_tuple(self, hasher: BaseSemanticHasher) -> None: + def test_empty_list_vs_empty_tuple(self, hasher: SemanticAwarePythonHasher) -> None: assert hasher.hash_object([]) != hasher.hash_object(()) - def test_empty_dict_vs_empty_list(self, hasher: BaseSemanticHasher) -> None: + def test_empty_dict_vs_empty_list(self, hasher: SemanticAwarePythonHasher) -> None: assert hasher.hash_object({}) != hasher.hash_object([]) - def test_none_vs_string_none(self, hasher: BaseSemanticHasher) -> None: + def test_none_vs_string_none(self, hasher: SemanticAwarePythonHasher) -> None: assert hasher.hash_object(None) != hasher.hash_object("None") - def test_true_vs_one(self, hasher: BaseSemanticHasher) -> None: + def test_true_vs_one(self, hasher: SemanticAwarePythonHasher) -> None: """bool True and int 1 produce different hashes due to JSON encoding.""" h_true = hasher.hash_object(True) h_one = hasher.hash_object(1) @@ -271,34 +271,34 @@ def test_true_vs_one(self, hasher: BaseSemanticHasher) -> None: # =================================================================== -# TypeHandlerRegistry -- register/get_handler roundtrip +# PythonTypeHandlerRegistry -- register/get_handler roundtrip # =================================================================== -class TestTypeHandlerRegistryBasics: +class TestPythonTypeHandlerRegistryBasics: """register() + get_handler() roundtrip.""" - def test_register_and_get_handler(self, registry: TypeHandlerRegistry) -> None: + def test_register_and_get_handler(self, registry: PythonTypeHandlerRegistry) -> None: handler = _FakeHandler() registry.register(int, handler) assert registry.get_handler(42) is handler def test_get_handler_returns_none_for_unregistered( - self, registry: TypeHandlerRegistry + self, registry: PythonTypeHandlerRegistry ) -> None: assert registry.get_handler("hello") is None # =================================================================== -# TypeHandlerRegistry -- MRO-aware lookup +# PythonTypeHandlerRegistry -- MRO-aware lookup # =================================================================== -class TestTypeHandlerRegistryMRO: +class TestPythonTypeHandlerRegistryMRO: """MRO-aware lookup: handler for parent class matches subclass.""" def test_subclass_inherits_parent_handler( - self, registry: TypeHandlerRegistry + self, registry: PythonTypeHandlerRegistry ) -> None: class Base: pass @@ -311,7 +311,7 @@ class Child(Base): assert registry.get_handler(Child()) is handler def test_specific_handler_overrides_parent( - self, registry: TypeHandlerRegistry + self, registry: PythonTypeHandlerRegistry ) -> None: class Base: pass @@ -328,41 +328,41 @@ class Child(Base): # =================================================================== -# TypeHandlerRegistry -- unregister +# PythonTypeHandlerRegistry -- unregister # =================================================================== -class TestTypeHandlerRegistryUnregister: +class TestPythonTypeHandlerRegistryUnregister: """unregister() removes handler.""" - def test_unregister_existing(self, registry: TypeHandlerRegistry) -> None: + def test_unregister_existing(self, registry: PythonTypeHandlerRegistry) -> None: handler = _FakeHandler() registry.register(int, handler) result = registry.unregister(int) assert result is True assert registry.get_handler(42) is None - def test_unregister_nonexistent(self, registry: TypeHandlerRegistry) -> None: + def test_unregister_nonexistent(self, registry: PythonTypeHandlerRegistry) -> None: result = registry.unregister(float) assert result is False # =================================================================== -# TypeHandlerRegistry -- has_handler +# PythonTypeHandlerRegistry -- has_handler # =================================================================== -class TestTypeHandlerRegistryHasHandler: +class TestPythonTypeHandlerRegistryHasHandler: """has_handler() boolean check.""" - def test_has_handler_true(self, registry: TypeHandlerRegistry) -> None: + def test_has_handler_true(self, registry: PythonTypeHandlerRegistry) -> None: registry.register(int, _FakeHandler()) assert registry.has_handler(int) is True - def test_has_handler_false(self, registry: TypeHandlerRegistry) -> None: + def test_has_handler_false(self, registry: PythonTypeHandlerRegistry) -> None: assert registry.has_handler(str) is False - def test_has_handler_via_mro(self, registry: TypeHandlerRegistry) -> None: + def test_has_handler_via_mro(self, registry: PythonTypeHandlerRegistry) -> None: class Base: pass @@ -374,17 +374,17 @@ class Child(Base): # =================================================================== -# TypeHandlerRegistry -- registered_types +# PythonTypeHandlerRegistry -- registered_types # =================================================================== -class TestTypeHandlerRegistryRegisteredTypes: +class TestPythonTypeHandlerRegistryRegisteredTypes: """registered_types() lists types.""" - def test_registered_types_empty(self, registry: TypeHandlerRegistry) -> None: + def test_registered_types_empty(self, registry: PythonTypeHandlerRegistry) -> None: assert registry.registered_types() == [] - def test_registered_types_populated(self, registry: TypeHandlerRegistry) -> None: + def test_registered_types_populated(self, registry: PythonTypeHandlerRegistry) -> None: registry.register(int, _FakeHandler()) registry.register(str, _FakeHandler()) types = registry.registered_types() @@ -392,14 +392,14 @@ def test_registered_types_populated(self, registry: TypeHandlerRegistry) -> None # =================================================================== -# TypeHandlerRegistry -- thread safety +# PythonTypeHandlerRegistry -- thread safety # =================================================================== -class TestTypeHandlerRegistryThreadSafety: +class TestPythonTypeHandlerRegistryThreadSafety: """Concurrent register/lookup doesn't crash.""" - def test_concurrent_register_lookup(self, registry: TypeHandlerRegistry) -> None: + def test_concurrent_register_lookup(self, registry: PythonTypeHandlerRegistry) -> None: errors: list[Exception] = [] def register_types(start: int, count: int) -> None: @@ -435,13 +435,13 @@ def lookup_types() -> None: # =================================================================== -# BuiltinTypeHandlerRegistry +# BuiltinPythonTypeHandlerRegistry # =================================================================== -class TestBuiltinTypeHandlerRegistry: - """BuiltinTypeHandlerRegistry is pre-populated with built-in handlers.""" +class TestBuiltinPythonTypeHandlerRegistry: + """BuiltinPythonTypeHandlerRegistry is pre-populated with built-in handlers.""" def test_construction(self) -> None: - reg = BuiltinTypeHandlerRegistry() + reg = BuiltinPythonTypeHandlerRegistry() assert len(reg.registered_types()) > 0 diff --git a/tests/test_core/function_pod/test_write_side_registration.py b/tests/test_core/function_pod/test_write_side_registration.py new file mode 100644 index 00000000..db1f5de5 --- /dev/null +++ b/tests/test_core/function_pod/test_write_side_registration.py @@ -0,0 +1,430 @@ +"""Tests for write-side LogicalType auto-registration at function pod declaration. + +These tests verify that _FunctionPodBase.__init__ triggers factory synthesis for +any non-native Python types in the pod's input/output schemas, and raises TypeError +at declaration time when no factory is registered. +""" + +from __future__ import annotations + +import uuid as _uuid_module +from typing import Optional + +import pyarrow as pa +import pytest + +from orcapod.contexts import get_default_context +from orcapod.contexts.core import DataContext +from orcapod.core.data_function import PythonDataFunction +from orcapod.core.function_pod import FunctionPod +from orcapod.extension_types.protocols import LogicalTypeProtocol +from orcapod.extension_types.registry import ( + LogicalTypeRegistry, + make_arrow_extension_type, + make_polars_extension_type, +) +from orcapod.semantic_types.universal_converter import UniversalTypeConverter + + +# ── Helpers ────────────────────────────────────────────────────────────────── + +def _make_test_context(registry: LogicalTypeRegistry) -> DataContext: + """Create a DataContext with a fresh converter bound to the given registry. + + A fresh ``UniversalTypeConverter`` is constructed with ``logical_type_registry`` + set at construction time, which is the canonical way to bind the two objects. + """ + base_ctx = get_default_context() + fresh_converter = UniversalTypeConverter( + logical_type_registry=registry, + ) + return DataContext( + context_key="test", + version="test", + description="test", + type_converter=fresh_converter, + arrow_hasher=base_ctx.arrow_hasher, + semantic_hasher=base_ctx.semantic_hasher, + ) + + +def _make_logical_type(py_type: type) -> LogicalTypeProtocol: + """Synthesize a minimal LogicalType for py_type.""" + arrow_name = f"{py_type.__module__}.{py_type.__qualname__}.{_uuid_module.uuid4().hex[:6]}" + ArrowExt = make_arrow_extension_type(arrow_name, pa.large_string()) + PolarsExt = make_polars_extension_type(arrow_name, pa.large_string()) + + class _LT: + logical_type_name = arrow_name + python_type = py_type + def get_arrow_extension_type(self): return ArrowExt() + def get_polars_extension_type(self): return PolarsExt() + def python_to_storage(self, v, converter=None): return str(v) + def storage_to_python(self, v, converter=None): return v + + return _LT() + + +def _make_registry_with_factory(*target_bases: type) -> tuple[LogicalTypeRegistry, list[type]]: + """Return a registry with a factory covering all target_bases and a call log.""" + call_log: list[type] = [] + + class _Factory: + def supports_class(self, python_type): + return any(issubclass(python_type, base) for base in target_bases) + + def reconstruct_from_arrow(self, name, storage, meta, converter): + return _make_logical_type(object) + + def create_for_python_type(self, python_type, converter): + call_log.append(python_type) + return _make_logical_type(python_type) + + registry = LogicalTypeRegistry() + registry.register_logical_type_factory(_Factory(), python_bases=list(target_bases)) + return registry, call_log + + +# ── Custom classes used in tests ───────────────────────────────────────────── + +class _MyBase: + pass + + +class _MyChild(_MyBase): + pass + + +class _MyOtherBase: + pass + + +class _MyOtherChild(_MyOtherBase): + pass + + +class _ThirdBase: + pass + + +class _ThirdChild(_ThirdBase): + pass + + +# ── Basic triggering tests ──────────────────────────────────────────────────── + +def test_pod_declaration_triggers_factory_for_input_type(): + """Declaring a FunctionPod with a custom input type causes factory synthesis.""" + registry, call_log = _make_registry_with_factory(_MyBase) + ctx = _make_test_context(registry) + + def my_func(x: _MyChild) -> str: + return str(x) + + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + assert _MyChild in call_log + assert registry.get_by_python_type(_MyChild) is not None + + +def test_pod_declaration_triggers_factory_for_output_type(): + """Declaring a FunctionPod with a custom output type causes factory synthesis.""" + registry, call_log = _make_registry_with_factory(_MyBase) + ctx = _make_test_context(registry) + + def my_func(x: int) -> _MyChild: + return _MyChild() + + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + assert _MyChild in call_log + assert registry.get_by_python_type(_MyChild) is not None + + +# ── Complex / nested type tests ─────────────────────────────────────────────── + +@pytest.mark.xfail( + reason="list[T] where T is a logical type not yet supported (PLT-1732: ListLogicalType)", + raises=ValueError, + strict=True, +) +def test_pod_declaration_with_nested_list_input(): + """list[_MyChild] in a function input causes factory synthesis for _MyChild.""" + registry, call_log = _make_registry_with_factory(_MyBase) + ctx = _make_test_context(registry) + + def my_func(items: list[_MyChild]) -> str: + return "" + + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + assert _MyChild in call_log + + +@pytest.mark.xfail( + reason="list[T] / dict[K, list[T]] where T is a logical type not yet supported (PLT-1732)", + raises=ValueError, + strict=True, +) +def test_pod_declaration_with_doubly_nested_input(): + """dict[str, list[_MyChild]] causes factory synthesis for _MyChild.""" + registry, call_log = _make_registry_with_factory(_MyBase) + ctx = _make_test_context(registry) + + def my_func(mapping: dict[str, list[_MyChild]]) -> str: + return "" + + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + assert _MyChild in call_log + + +def test_pod_declaration_with_optional_input(): + """Optional[_MyChild] causes factory synthesis for _MyChild.""" + registry, call_log = _make_registry_with_factory(_MyBase) + ctx = _make_test_context(registry) + + def my_func(x: Optional[_MyChild]) -> str: + return "" + + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + assert _MyChild in call_log + + +@pytest.mark.xfail( + reason="list[T] where T is a logical type not yet supported (PLT-1732: ListLogicalType)", + raises=ValueError, + strict=True, +) +def test_pod_declaration_with_complex_output(): + """list[_MyChild] in the output schema causes factory synthesis.""" + registry, call_log = _make_registry_with_factory(_MyBase) + ctx = _make_test_context(registry) + + def my_func(x: str) -> list[_MyChild]: + return [] + + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + assert _MyChild in call_log + + +@pytest.mark.xfail( + reason="list[T] / dict[K, list[T]] where T is a logical type not yet supported (PLT-1732)", + raises=ValueError, + strict=True, +) +def test_pod_declaration_with_doubly_nested_output(): + """dict[str, list[_MyChild]] in the output causes factory synthesis for _MyChild.""" + registry, call_log = _make_registry_with_factory(_MyBase) + ctx = _make_test_context(registry) + + def my_func(x: int) -> dict[str, list[_MyChild]]: + return {} + + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + assert _MyChild in call_log + + +# ── Multi-class tests ───────────────────────────────────────────────────────── + +def test_pod_declaration_two_classes_one_in_input_one_in_output(): + """Two different custom classes — one in input, one in output — each gets synthesized.""" + registry, call_log = _make_registry_with_factory(_MyBase, _MyOtherBase) + ctx = _make_test_context(registry) + + def my_func(x: _MyChild) -> _MyOtherChild: + return _MyOtherChild() + + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + assert _MyChild in call_log + assert _MyOtherChild in call_log + + +def test_pod_declaration_two_classes_both_in_input(): + """Two different custom classes both as inputs each get synthesized.""" + registry, call_log = _make_registry_with_factory(_MyBase, _MyOtherBase) + ctx = _make_test_context(registry) + + def my_func(x: _MyChild, y: _MyOtherChild) -> str: + return "" + + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + assert _MyChild in call_log + assert _MyOtherChild in call_log + + +def test_pod_declaration_two_classes_both_in_output(): + """Two different custom classes both as outputs each get synthesized.""" + registry, call_log = _make_registry_with_factory(_MyBase, _MyOtherBase) + ctx = _make_test_context(registry) + + def my_func(x: int) -> tuple[_MyChild, _MyOtherChild]: + return _MyChild(), _MyOtherChild() + + FunctionPod( + data_function=PythonDataFunction( + my_func, + output_keys=["first", "second"], + ), + data_context=ctx, + ) + assert _MyChild in call_log + assert _MyOtherChild in call_log + + +@pytest.mark.xfail( + reason="list[T] where T is a logical type not yet supported (PLT-1732: ListLogicalType)", + raises=ValueError, + strict=True, +) +def test_pod_declaration_three_classes_mixed(): + """Three custom classes spread across input and output each get synthesized.""" + registry, call_log = _make_registry_with_factory(_MyBase, _MyOtherBase, _ThirdBase) + ctx = _make_test_context(registry) + + def my_func(a: _MyChild, b: list[_MyOtherChild]) -> _ThirdChild: + return _ThirdChild() + + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + assert _MyChild in call_log + assert _MyOtherChild in call_log + assert _ThirdChild in call_log + + +def test_pod_declaration_three_classes_all_in_input(): + """Three custom classes all in input parameters each get synthesized.""" + registry, call_log = _make_registry_with_factory(_MyBase, _MyOtherBase, _ThirdBase) + ctx = _make_test_context(registry) + + def my_func(a: _MyChild, b: _MyOtherChild, c: _ThirdChild) -> str: + return "" + + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + assert _MyChild in call_log + assert _MyOtherChild in call_log + assert _ThirdChild in call_log + + +# ── Skip / guard tests ──────────────────────────────────────────────────────── + +def test_pod_declaration_native_types_no_factory_call(): + """Pods using only native types (int, str, etc.) never trigger factory lookup.""" + + class _NeverCalledFactory: + def supports_class(self, python_type): + return True + def reconstruct_from_arrow(self, name, storage, meta, converter): ... + def create_for_python_type(self, pt, converter): + raise AssertionError(f"factory called for {pt!r}") + + registry = LogicalTypeRegistry() + registry.register_logical_type_factory(_NeverCalledFactory(), python_bases=[object]) + ctx = _make_test_context(registry) + + def my_func(x: int, y: str) -> float: + return 0.0 + + # Should not raise — int, str, float are native + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + + +def test_pod_declaration_raises_type_error_for_unhandled_class(): + """Pod with a type that has no registered factory raises TypeError at declaration.""" + registry = LogicalTypeRegistry() # empty — no factories + ctx = _make_test_context(registry) + + def my_func(x: _MyChild) -> str: + return "" + + with pytest.raises(TypeError, match="No LogicalType or LogicalTypeFactory"): + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + + +def test_pod_declaration_raises_for_nested_unhandled_class(): + """TypeError is raised even when the custom type is nested inside list[T].""" + registry = LogicalTypeRegistry() # empty — no factories + ctx = _make_test_context(registry) + + def my_func(items: list[_MyChild]) -> str: + return "" + + with pytest.raises(TypeError, match="No LogicalType or LogicalTypeFactory"): + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + + +def test_pod_declaration_already_registered_type_no_factory_call(): + """Pre-registered types are not passed to the factory.""" + registry, call_log = _make_registry_with_factory(_MyBase) + # Pre-register _MyChild directly + registry.register_logical_type(_make_logical_type(_MyChild)) + ctx = _make_test_context(registry) + + def my_func(x: _MyChild) -> str: + return "" + + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + # Factory was NOT called — _MyChild was already registered + assert _MyChild not in call_log + + +def test_pod_declaration_with_union_none_syntax(): + """``_MyChild | None`` (new-style union) causes factory synthesis for _MyChild. + + Python 3.10+ ``X | Y`` produces a ``types.UnionType``, which is a different + runtime object from ``typing.Union[X, Y]``. This test confirms that + ``extract_leaf_classes`` correctly unwraps both union forms and that + ``NoneType`` is skipped in both cases. + """ + registry, call_log = _make_registry_with_factory(_MyBase) + ctx = _make_test_context(registry) + + def my_func(x: _MyChild | None) -> str: + return "" + + FunctionPod( + data_function=PythonDataFunction(my_func, output_keys=["result"]), + data_context=ctx, + ) + assert _MyChild in call_log + assert registry.get_by_python_type(_MyChild) is not None diff --git a/tests/test_databases/test_connector_arrow_database.py b/tests/test_databases/test_connector_arrow_database.py index d87701b3..71125ef2 100644 --- a/tests/test_databases/test_connector_arrow_database.py +++ b/tests/test_databases/test_connector_arrow_database.py @@ -18,6 +18,7 @@ 11. Flush behaviour (pending cleared, connector receives data) 12. Config (to_config shape, from_config raises NotImplementedError) 13. at() method and base_path attribute +14. Extension-type write guard """ from __future__ import annotations @@ -783,3 +784,107 @@ def test_at_rejects_null_in_component(self, db): def test_at_rejects_empty_component(self, db): with pytest.raises(ValueError): db.at("") + + +# --------------------------------------------------------------------------- +# 14. Extension-type write guard +# --------------------------------------------------------------------------- + + +class TestExtensionTypeWriteGuard: + """add_records() rejects extension-typed columns. + + SQL connectors do not preserve ``ARROW:extension:*`` field metadata. + Writing extension-typed columns would cause silent type loss on read. + The guard fires at write time so the problem is surfaced immediately + rather than discovered when reading back corrupted data. + + Two representations are tested: + - In-memory ``pa.ExtensionType`` (the type is registered in this process). + - Metadata-only columns (plain storage type + ``ARROW:extension:name`` + field metadata, as produced when reading Parquet from a process that + had the type registered). + """ + + @pytest.fixture + def db(self): + return ConnectorArrowDatabase(MockDBConnector()) + + def test_rejects_in_memory_extension_type_column(self, db): + """add_records raises ValueError when a column carries a pa.ExtensionType.""" + import pyarrow as pa + + # Build a minimal custom extension type for testing. + class _DummyExt(pa.ExtensionType): + def __init__(self): + super().__init__(pa.large_string(), "test.dummy") + + def __arrow_ext_serialize__(self): + return b"" + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + return cls() + + pa.register_extension_type(_DummyExt()) + try: + ext_array = pa.array(["hello"], type=_DummyExt()) + rid_array = pa.array([b"id1"], type=pa.large_binary()) + table = pa.table( + {"__record_id": rid_array, "payload": ext_array}, + ) + with pytest.raises(ValueError, match="extension"): + db.add_records( + ("results",), + table, + record_id_column="__record_id", + ) + finally: + pa.unregister_extension_type("test.dummy") + + def test_rejects_metadata_only_extension_column(self, db): + """add_records raises ValueError when a column has ARROW:extension:name field metadata. + + This is the "unregistered read" representation: the column type is a plain + storage type (e.g. large_string) but the field metadata contains the + ``b"ARROW:extension:name"`` key, as happens when reading a Parquet file that + was written with an extension type that is not registered in the current process. + """ + import pyarrow as pa + + ext_field = pa.field( + "payload", + pa.large_string(), + metadata={ + b"ARROW:extension:name": b"orcapod.path", + b"ARROW:extension:metadata": b"", + }, + ) + rid_field = pa.field("__record_id", pa.large_binary()) + schema = pa.schema([rid_field, ext_field]) + table = pa.table( + { + "__record_id": pa.array([b"id1"], type=pa.large_binary()), + "payload": pa.array(["/tmp/test"], type=pa.large_string()), + }, + schema=schema, + ) + with pytest.raises(ValueError, match="extension"): + db.add_records( + ("results",), + table, + record_id_column="__record_id", + ) + + def test_plain_column_not_rejected(self, db): + """add_records accepts tables with no extension-typed columns.""" + import pyarrow as pa + + table = pa.table( + { + "__record_id": pa.array([b"id1"], type=pa.large_binary()), + "value": pa.array([42], type=pa.int64()), + } + ) + # Should not raise + db.add_records(("results",), table, record_id_column="__record_id") diff --git a/tests/test_databases/test_extension_aware_database.py b/tests/test_databases/test_extension_aware_database.py new file mode 100644 index 00000000..259d6a8c --- /dev/null +++ b/tests/test_databases/test_extension_aware_database.py @@ -0,0 +1,195 @@ +"""Tests for ExtensionAwareDatabase.""" + +from __future__ import annotations + +import uuid + +import pyarrow as pa +import pytest + +from orcapod.databases.extension_aware_database import ExtensionAwareDatabase +from orcapod.databases.in_memory_databases import InMemoryArrowDatabase +from orcapod.extension_types.registry import LogicalTypeRegistry, make_arrow_extension_type +from orcapod.semantic_types.universal_converter import UniversalTypeConverter + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _unique_name() -> str: + return f"test.eadb.{uuid.uuid4().hex[:8]}" + + +def _make_converter_with_type( + arrow_name: str, + storage: pa.DataType = pa.large_utf8(), +): + """Return a (converter, ext_type_instance) pair with one registered type.""" + import polars as pl + + ExtCls = make_arrow_extension_type(arrow_name, storage) + ext_type = ExtCls() + pl_storage = pl.from_arrow(pa.array([], type=storage)).dtype + + class _PolarsExt(pl.BaseExtension): + def __init__(self): + super().__init__(arrow_name, pl_storage, None) + @classmethod + def ext_from_params(cls, ext_name, storage_dtype, metadata_str): + return cls() + + class _LT: + @property + def logical_type_name(self): + return arrow_name + @property + def python_type(self): + return str + def get_arrow_extension_type(self): + return ext_type + def get_polars_extension_type(self): + return _PolarsExt() + def python_to_storage(self, v, converter=None): + return str(v) + def storage_to_python(self, v, converter=None): + return v + + registry = LogicalTypeRegistry() + registry.register_logical_type(_LT()) + converter = UniversalTypeConverter(logical_type_registry=registry) + return converter, ext_type + + +def _make_converter(): + """Make a converter with an empty registry.""" + registry = LogicalTypeRegistry() + return UniversalTypeConverter(logical_type_registry=registry) + + +def _degraded_table(arrow_name: str, storage: pa.DataType, values: list) -> pa.Table: + """Arrow table with extension field metadata but storage type (simulates unregistered read).""" + col = pa.array(values, type=storage) + field = pa.field("col", storage).with_metadata({ + b"ARROW:extension:name": arrow_name.encode(), + b"ARROW:extension:metadata": b"", + }) + return pa.table({"col": col}, schema=pa.schema([field])) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +def test_get_all_records_applies_extension_types(): + """get_all_records returns table with extension types applied.""" + name = _unique_name() + converter, ext_type = _make_converter_with_type(name) + + inner_db = InMemoryArrowDatabase() + # Add two separate records (distinct record_ids) so both rows survive deduplication. + r1 = _degraded_table(name, pa.large_utf8(), ["hello"]) + r2 = _degraded_table(name, pa.large_utf8(), ["world"]) + inner_db.add_record(("test",), record_id=b"r1", record=r1, flush=False) + inner_db.add_record(("test",), record_id=b"r2", record=r2, flush=True) + + db = ExtensionAwareDatabase(inner_db, converter) + result = db.get_all_records(("test",)) + + assert result is not None + assert result.schema.field("col").type == ext_type + assert sorted(result.column("col").to_pylist()) == ["hello", "world"] + + +def test_get_record_by_id_applies_extension_types(): + """get_record_by_id returns table with extension types applied.""" + name = _unique_name() + converter, ext_type = _make_converter_with_type(name) + + inner_db = InMemoryArrowDatabase() + degraded = _degraded_table(name, pa.large_utf8(), ["x"]) + inner_db.add_record(("p",), record_id=b"r1", record=degraded, flush=True) + + db = ExtensionAwareDatabase(inner_db, converter) + result = db.get_record_by_id(("p",), b"r1") + + assert result is not None + assert result.schema.field("col").type == ext_type + + +def test_get_records_by_ids_applies_extension_types(): + """get_records_by_ids returns table with extension types applied.""" + name = _unique_name() + converter, ext_type = _make_converter_with_type(name) + + inner_db = InMemoryArrowDatabase() + degraded = _degraded_table(name, pa.large_utf8(), ["a"]) + inner_db.add_record(("p",), record_id=b"r1", record=degraded, flush=True) + + db = ExtensionAwareDatabase(inner_db, converter) + result = db.get_records_by_ids(("p",), [b"r1"]) + + assert result is not None + assert result.schema.field("col").type == ext_type + + +def test_get_all_records_returns_none_when_no_records(): + """Returns None when the underlying database has no records for the path.""" + converter = _make_converter() + inner_db = InMemoryArrowDatabase() + db = ExtensionAwareDatabase(inner_db, converter) + + assert db.get_all_records(("nonexistent",)) is None + + +def test_write_methods_passthrough(): + """add_record and add_records write correctly through the wrapper.""" + converter = _make_converter() + inner_db = InMemoryArrowDatabase() + db = ExtensionAwareDatabase(inner_db, converter) + + t1 = pa.table({"x": pa.array([1], type=pa.int32())}) + t2 = pa.table({"x": pa.array([2], type=pa.int32())}) + db.add_record(("p",), record_id=b"r1", record=t1, flush=False) + db.add_record(("p",), record_id=b"r2", record=t2, flush=True) + + result = inner_db.get_all_records(("p",)) + assert result is not None + assert sorted(result.column("x").to_pylist()) == [1, 2] + + +def test_at_returns_extension_aware_database(): + """at() returns an ExtensionAwareDatabase with the same converter.""" + converter = _make_converter() + inner_db = InMemoryArrowDatabase() + db = ExtensionAwareDatabase(inner_db, converter) + + scoped = db.at("sub", "path") + + assert isinstance(scoped, ExtensionAwareDatabase) + assert scoped._converter is converter + assert scoped.base_path == ("sub", "path") + + +def test_base_path_delegates_to_inner(): + """base_path reflects the inner database's base_path.""" + converter = _make_converter() + inner_db = InMemoryArrowDatabase() + db = ExtensionAwareDatabase(inner_db, converter) + + assert db.base_path == () + assert db.at("a").base_path == ("a",) + + +def test_plain_table_passthrough_unchanged(): + """Tables with no extension type metadata are returned as-is (no wrapping overhead).""" + converter = _make_converter() + inner_db = InMemoryArrowDatabase() + db = ExtensionAwareDatabase(inner_db, converter) + + table = pa.table({"n": pa.array([10, 20], type=pa.int64())}) + inner_db.add_record(("p",), record_id=b"r1", record=table, flush=True) + + result = db.get_all_records(("p",)) + assert result is not None + assert result.schema.field("n").type == pa.int64() diff --git a/tests/test_extension_types/test_apply_extension_types.py b/tests/test_extension_types/test_apply_extension_types.py new file mode 100644 index 00000000..bf2e6016 --- /dev/null +++ b/tests/test_extension_types/test_apply_extension_types.py @@ -0,0 +1,300 @@ +"""Tests for apply_extension_types in database_hooks.""" + +from __future__ import annotations + +import uuid + +import pyarrow as pa +import pytest + +from orcapod.extension_types.registry import LogicalTypeRegistry, make_arrow_extension_type + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _unique_name() -> str: + return f"test.apply.{uuid.uuid4().hex[:8]}" + + +def _make_registry_with_type( + arrow_name: str, + storage: pa.DataType = pa.large_utf8(), +) -> tuple[LogicalTypeRegistry, pa.ExtensionType]: + """Return a registry with one registered extension type and the type instance.""" + import polars as pl + + ExtCls = make_arrow_extension_type(arrow_name, storage) + ext_type = ExtCls() + pl_storage = pl.from_arrow(pa.array([], type=storage)).dtype + + class _PolarsExt(pl.BaseExtension): + def __init__(self): + super().__init__(arrow_name, pl_storage, None) + @classmethod + def ext_from_params(cls, ext_name, storage_dtype, metadata_str): + return cls() + + class _LT: + @property + def logical_type_name(self): + return arrow_name + @property + def python_type(self): + return str + def get_arrow_extension_type(self): + return ext_type + def get_polars_extension_type(self): + return _PolarsExt() + def python_to_storage(self, v): + return str(v) + def storage_to_python(self, v): + return v + + registry = LogicalTypeRegistry() + registry.register_logical_type(_LT()) + return registry, ext_type + + +def _degraded_table_with_metadata( + arrow_name: str, + storage: pa.DataType, + values: list, +) -> pa.Table: + """Build a table that carries extension field metadata but uses storage type. + + Simulates what you get when Arrow reads a Parquet/IPC file whose extension + type was not registered at read time. + """ + col = pa.array(values, type=storage) + field = pa.field("col", storage).with_metadata({ + b"ARROW:extension:name": arrow_name.encode(), + b"ARROW:extension:metadata": b"", + }) + schema = pa.schema([field]) + return pa.table({"col": col}, schema=schema) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +def test_noop_when_no_extension_metadata(): + """Table with plain Arrow types is returned unchanged.""" + from orcapod.extension_types.database_hooks import apply_extension_types + + registry = LogicalTypeRegistry() + table = pa.table({"x": pa.array([1, 2, 3], type=pa.int32())}) + result = apply_extension_types(table, registry) + assert result is table # same object — nothing to do + + +def test_wraps_storage_column_into_extension_type(): + """A column with extension field metadata is re-wrapped into the registered type.""" + from orcapod.extension_types.database_hooks import apply_extension_types + + name = _unique_name() + registry, ext_type = _make_registry_with_type(name, pa.large_utf8()) + table = _degraded_table_with_metadata(name, pa.large_utf8(), ["hello", "world"]) + + result = apply_extension_types(table, registry) + + assert result.schema.field("col").type == ext_type + assert result.column("col").to_pylist() == ["hello", "world"] + + +def test_zero_copy_single_chunk(): + """from_storage wrapping shares the underlying buffer — no data copy.""" + from orcapod.extension_types.database_hooks import apply_extension_types + + name = _unique_name() + registry, _ = _make_registry_with_type(name, pa.large_utf8()) + table = _degraded_table_with_metadata(name, pa.large_utf8(), ["a", "b"]) + + result = apply_extension_types(table, registry) + + orig_buf = table.column("col").chunk(0).buffers()[2] + new_buf = result.column("col").chunk(0).buffers()[2] + assert orig_buf == new_buf + + +def test_zero_copy_multiple_chunks(): + """Multi-chunk columns are wrapped per-chunk, all buffers shared.""" + from orcapod.extension_types.database_hooks import apply_extension_types + + name = _unique_name() + registry, ext_type = _make_registry_with_type(name, pa.large_utf8()) + + # Build a multi-chunk ChunkedArray with extension metadata on the field + c1 = pa.array(["x"], type=pa.large_utf8()) + c2 = pa.array(["y", "z"], type=pa.large_utf8()) + chunked = pa.chunked_array([c1, c2], type=pa.large_utf8()) + field = pa.field("col", pa.large_utf8()).with_metadata({ + b"ARROW:extension:name": name.encode(), + b"ARROW:extension:metadata": b"", + }) + schema = pa.schema([field]) + table = pa.table({"col": chunked}, schema=schema) + + result = apply_extension_types(table, registry) + result_col = result.column("col") + + assert result.schema.field("col").type == ext_type + assert result_col.num_chunks == 2 + assert result_col.to_pylist() == ["x", "y", "z"] + # Buffer identity per chunk + for i, (orig, wrapped) in enumerate(zip(chunked.chunks, result_col.chunks)): + assert orig.buffers()[2] == wrapped.buffers()[2], f"chunk {i} buffer differs" + + +def test_already_extension_type_passthrough(): + """Column already carrying an extension type is returned as-is.""" + from orcapod.extension_types.database_hooks import apply_extension_types + + name = _unique_name() + registry, ext_type = _make_registry_with_type(name, pa.large_utf8()) + # Build a table with a properly typed extension column (already registered) + arr = pa.ExtensionArray.from_storage(ext_type, pa.array(["a"], type=pa.large_utf8())) + table = pa.table({"col": arr}) + + result = apply_extension_types(table, registry) + assert result is table + + +def test_unregistered_extension_metadata_left_as_storage(): + """A column whose extension type is not in the registry stays as storage type.""" + from orcapod.extension_types.database_hooks import apply_extension_types + + name = _unique_name() + registry = LogicalTypeRegistry() # no types registered + table = _degraded_table_with_metadata(name, pa.large_utf8(), ["v"]) + + result = apply_extension_types(table, registry) + + # Column stays as large_utf8 — registry has nothing to apply + assert result.schema.field("col").type == pa.large_utf8() + + +def test_nested_struct_extension_type(): + """Extension type inside a struct child field is reconstructed recursively.""" + from orcapod.extension_types.database_hooks import apply_extension_types + + name = _unique_name() + registry, ext_type = _make_registry_with_type(name, pa.large_utf8()) + + # Build degraded struct: inner field has extension metadata but storage type + inner_field = pa.field("inner", pa.large_utf8()).with_metadata({ + b"ARROW:extension:name": name.encode(), + b"ARROW:extension:metadata": b"", + }) + struct_type = pa.struct([inner_field]) + inner_data = pa.array(["p", "q"], type=pa.large_utf8()) + struct_col = pa.StructArray.from_arrays([inner_data], fields=[inner_field]) + schema = pa.schema([pa.field("s", struct_type)]) + table = pa.table({"s": struct_col}, schema=schema) + + result = apply_extension_types(table, registry) + + result_struct_type = result.schema.field("s").type + assert pa.types.is_struct(result_struct_type) + result_inner_field = result_struct_type.field("inner") + assert result_inner_field.type == ext_type + assert result.column("s").to_pylist() == [{"inner": "p"}, {"inner": "q"}] + + +def test_mixed_columns_only_ext_columns_changed(): + """Plain columns are left untouched when an extension column is processed.""" + from orcapod.extension_types.database_hooks import apply_extension_types + + name = _unique_name() + registry, ext_type = _make_registry_with_type(name, pa.large_utf8()) + + ext_field = pa.field("ext_col", pa.large_utf8()).with_metadata({ + b"ARROW:extension:name": name.encode(), + b"ARROW:extension:metadata": b"", + }) + plain_field = pa.field("plain_col", pa.int32()) + schema = pa.schema([ext_field, plain_field]) + table = pa.table( + {"ext_col": pa.array(["a"], type=pa.large_utf8()), "plain_col": pa.array([1], type=pa.int32())}, + schema=schema, + ) + + result = apply_extension_types(table, registry) + + assert result.schema.field("ext_col").type == ext_type + assert result.schema.field("plain_col").type == pa.int32() + assert result.column("plain_col").to_pylist() == [1] + + +def test_schema_level_metadata_preserved(): + """Schema-level metadata (e.g. pandas metadata) is preserved when rebuilding schema.""" + from orcapod.extension_types.database_hooks import apply_extension_types + + name = _unique_name() + registry, ext_type = _make_registry_with_type(name, pa.large_utf8()) + + ext_field = pa.field("col", pa.large_utf8()).with_metadata({ + b"ARROW:extension:name": name.encode(), + b"ARROW:extension:metadata": b"", + }) + schema_meta = {b"pandas": b'{"some": "pandas_metadata"}', b"custom": b"value"} + schema = pa.schema([ext_field], metadata=schema_meta) + table = pa.table({"col": pa.array(["x"], type=pa.large_utf8())}, schema=schema) + + result = apply_extension_types(table, registry) + + assert result.schema.field("col").type == ext_type + assert result.schema.metadata == schema_meta + + +def test_plain_struct_not_rebuilt(): + """A struct column with no extension children is returned as-is without rebuilding.""" + from orcapod.extension_types.database_hooks import apply_extension_types + + registry = LogicalTypeRegistry() # empty — nothing registered + inner_field = pa.field("x", pa.int32()) + struct_type = pa.struct([inner_field]) + struct_col = pa.StructArray.from_arrays( + [pa.array([1, 2], type=pa.int32())], fields=[inner_field] + ) + schema = pa.schema([pa.field("s", struct_type)]) + table = pa.table({"s": struct_col}, schema=schema) + + result = apply_extension_types(table, registry) + + # Nothing changed — same object returned + assert result is table + + +def test_struct_null_bitmap_preserved(): + """Null struct rows retain their null status after extension type wrapping.""" + from orcapod.extension_types.database_hooks import apply_extension_types + + name = _unique_name() + registry, ext_type = _make_registry_with_type(name, pa.large_utf8()) + + inner_field = pa.field("inner", pa.large_utf8()).with_metadata({ + b"ARROW:extension:name": name.encode(), + b"ARROW:extension:metadata": b"", + }) + struct_type = pa.struct([inner_field]) + inner_data = pa.array(["a", "b", "c"], type=pa.large_utf8()) + # Build struct with a null at position 1 + struct_col = pa.StructArray.from_arrays( + [inner_data], + fields=[inner_field], + mask=pa.array([False, True, False]), # True = null + ) + schema = pa.schema([pa.field("s", struct_type)]) + table = pa.table({"s": struct_col}, schema=schema) + + result = apply_extension_types(table, registry) + + result_col = result.column("s") + assert result_col.null_count == 1 + rows = result_col.to_pylist() + assert rows[0] is not None + assert rows[1] is None + assert rows[2] is not None diff --git a/tests/test_extension_types/test_builtin_logical_types.py b/tests/test_extension_types/test_builtin_logical_types.py new file mode 100644 index 00000000..5526486a --- /dev/null +++ b/tests/test_extension_types/test_builtin_logical_types.py @@ -0,0 +1,685 @@ +"""Tests for built-in LogicalType implementations (LogicalPath, LogicalUPath, LogicalUUID).""" + +from __future__ import annotations + +import pathlib +import uuid as uuid_module + +import polars as pl +import pyarrow as pa +from upath import UPath + +import orcapod + +from orcapod.extension_types.protocols import LogicalTypeProtocol +from orcapod.extension_types.registry import LogicalTypeRegistry + + +# --------------------------------------------------------------------------- +# LogicalPath tests +# --------------------------------------------------------------------------- + + +def test_logical_path_isinstance_logical_type(): + """LogicalPath() satisfies the LogicalType runtime-checkable protocol.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + + assert isinstance(LogicalPath(), LogicalTypeProtocol) + + +def test_logical_path_logical_type_name(): + from orcapod.extension_types.builtin_logical_types import LogicalPath + + assert LogicalPath().logical_type_name == "orcapod.path" + + +def test_logical_path_python_type(): + from orcapod.extension_types.builtin_logical_types import LogicalPath + + assert LogicalPath().python_type is pathlib.Path + + +def test_logical_path_arrow_ext_name(): + """get_arrow_extension_type().extension_name is 'orcapod.path'.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + + assert LogicalPath().get_arrow_extension_type().extension_name == "orcapod.path" + + +def test_logical_path_arrow_ext_storage_type(): + """Arrow extension storage type is pa.large_string().""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + + assert LogicalPath().get_arrow_extension_type().storage_type == pa.large_string() + + +def test_logical_path_get_arrow_extension_type_is_cached(): + """get_arrow_extension_type() returns the same object on repeated calls.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + + lt = LogicalPath() + assert lt.get_arrow_extension_type() is lt.get_arrow_extension_type() + + +def test_logical_path_get_polars_extension_type_is_cached(): + """get_polars_extension_type() returns the same object on repeated calls.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + + lt = LogicalPath() + assert lt.get_polars_extension_type() is lt.get_polars_extension_type() + + +def test_logical_path_round_trip(): + """Path -> python_to_storage -> storage_to_python -> Path is identity.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + + lt = LogicalPath() + p = pathlib.Path("/tmp/foo/bar.txt") + assert lt.storage_to_python(lt.python_to_storage(p)) == p + + +def test_logical_path_python_to_storage_returns_string(): + from orcapod.extension_types.builtin_logical_types import LogicalPath + + lt = LogicalPath() + result = lt.python_to_storage(pathlib.Path("/tmp/test")) + assert isinstance(result, str) + assert result == "/tmp/test" + + +# --------------------------------------------------------------------------- +# LogicalUPath tests +# --------------------------------------------------------------------------- + + +def test_logical_upath_isinstance_logical_type(): + """LogicalUPath() satisfies the LogicalType runtime-checkable protocol.""" + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + assert isinstance(LogicalUPath(), LogicalTypeProtocol) + + +def test_logical_upath_logical_type_name(): + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + assert LogicalUPath().logical_type_name == "orcapod.upath" + + +def test_logical_upath_python_type(): + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + assert LogicalUPath().python_type is UPath + + +def test_logical_upath_arrow_ext_name(): + """get_arrow_extension_type().extension_name is 'orcapod.upath'.""" + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + assert LogicalUPath().get_arrow_extension_type().extension_name == "orcapod.upath" + + +def test_logical_upath_arrow_ext_storage_type(): + """Arrow extension storage type is pa.large_string().""" + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + assert LogicalUPath().get_arrow_extension_type().storage_type == pa.large_string() + + +def test_logical_upath_get_arrow_extension_type_is_cached(): + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + lt = LogicalUPath() + assert lt.get_arrow_extension_type() is lt.get_arrow_extension_type() + + +def test_logical_upath_get_polars_extension_type_is_cached(): + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + lt = LogicalUPath() + assert lt.get_polars_extension_type() is lt.get_polars_extension_type() + + +def test_logical_upath_round_trip(): + """UPath -> python_to_storage -> storage_to_python -> UPath is identity.""" + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + lt = LogicalUPath() + up = UPath("s3://bucket/key/file.txt") + assert lt.storage_to_python(lt.python_to_storage(up)) == up + + +def test_logical_upath_python_to_storage_returns_string(): + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + lt = LogicalUPath() + result = lt.python_to_storage(UPath("s3://bucket/key")) + assert isinstance(result, str) + assert result == "s3://bucket/key" + + +# --------------------------------------------------------------------------- +# LogicalUUID tests +# --------------------------------------------------------------------------- + + +def test_logical_uuid_isinstance_logical_type(): + """LogicalUUID() satisfies the LogicalType runtime-checkable protocol.""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + assert isinstance(LogicalUUID(), LogicalTypeProtocol) + + +def test_logical_uuid_logical_type_name(): + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + assert LogicalUUID().logical_type_name == "orcapod.uuid" + + +def test_logical_uuid_python_type(): + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + assert LogicalUUID().python_type is uuid_module.UUID + + +def test_logical_uuid_arrow_ext_name(): + """Arrow extension name is 'orcapod.uuid', matching logical_type_name.""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + assert lt.get_arrow_extension_type().extension_name == "orcapod.uuid" + assert lt.get_arrow_extension_type().extension_name == lt.logical_type_name + + +def test_logical_uuid_arrow_ext_storage_type(): + """Arrow extension storage type is pa.large_binary().""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + assert LogicalUUID().get_arrow_extension_type().storage_type == pa.large_binary() + + +def test_logical_uuid_get_arrow_extension_type_is_cached(): + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + assert lt.get_arrow_extension_type() is lt.get_arrow_extension_type() + + +def test_logical_uuid_get_polars_extension_type_is_cached(): + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + assert lt.get_polars_extension_type() is lt.get_polars_extension_type() + + +def test_logical_uuid_round_trip(): + """UUID -> python_to_storage -> storage_to_python -> UUID is identity.""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + u = uuid_module.uuid4() + assert lt.storage_to_python(lt.python_to_storage(u)) == u + + +def test_logical_uuid_python_to_storage_returns_bytes(): + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + u = uuid_module.UUID("12345678-1234-5678-1234-567812345678") + result = lt.python_to_storage(u) + assert isinstance(result, bytes) + assert len(result) == 16 + + +def test_logical_uuid_storage_to_python_accepts_bytes(): + """storage_to_python works when storage_value is plain bytes.""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + u = uuid_module.UUID("12345678-1234-5678-1234-567812345678") + recovered = lt.storage_to_python(u.bytes) + assert recovered == u + + +def test_logical_uuid_registration_does_not_raise(): + """Registering LogicalUUID succeeds and is reachable by both logical and arrow names.""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + registry = LogicalTypeRegistry() + lt = LogicalUUID() + registry.register_logical_type(lt) # should NOT raise + assert registry.get_by_logical_name("orcapod.uuid") is lt + assert registry.get_by_arrow_extension_name("orcapod.uuid") is lt + + +# --------------------------------------------------------------------------- +# Arrow and Polars end-to-end round-trip tests +# --------------------------------------------------------------------------- + + +def test_logical_path_arrow_round_trip(): + """Python -> Arrow extension array -> Python via LogicalPath.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + + lt = LogicalPath() + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + originals = [pathlib.Path("/tmp/foo"), pathlib.Path("/home/user/bar.txt")] + storage_vals = [lt.python_to_storage(p) for p in originals] + arrow_ext = lt.get_arrow_extension_type() + ext_arr = pa.ExtensionArray.from_storage(arrow_ext, pa.array(storage_vals, type=arrow_ext.storage_type)) + + recovered = [lt.storage_to_python(v.as_py()) for v in ext_arr.storage] + assert recovered == originals + + +def test_logical_path_polars_round_trip(): + """Python -> Arrow extension array -> Polars series -> Arrow -> Python via LogicalPath.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + + lt = LogicalPath() + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + originals = [pathlib.Path("/tmp/foo"), pathlib.Path("/home/user/bar.txt")] + storage_vals = [lt.python_to_storage(p) for p in originals] + arrow_ext = lt.get_arrow_extension_type() + ext_arr = pa.ExtensionArray.from_storage(arrow_ext, pa.array(storage_vals, type=arrow_ext.storage_type)) + + pl_series = pl.from_arrow(ext_arr) + arr_back = pl_series.to_arrow() + recovered = [lt.storage_to_python(v.as_py()) for v in arr_back.storage] + assert recovered == originals + + +def test_logical_upath_arrow_round_trip(): + """Python -> Arrow extension array -> Python via LogicalUPath.""" + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + lt = LogicalUPath() + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + originals = [UPath("s3://bucket/key"), UPath("gs://other/path/file.txt")] + storage_vals = [lt.python_to_storage(p) for p in originals] + arrow_ext = lt.get_arrow_extension_type() + ext_arr = pa.ExtensionArray.from_storage(arrow_ext, pa.array(storage_vals, type=arrow_ext.storage_type)) + + recovered = [lt.storage_to_python(v.as_py()) for v in ext_arr.storage] + assert recovered == originals + + +def test_logical_upath_polars_round_trip(): + """Python -> Arrow extension array -> Polars series -> Arrow -> Python via LogicalUPath.""" + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + lt = LogicalUPath() + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + originals = [UPath("s3://bucket/key"), UPath("gs://other/path/file.txt")] + storage_vals = [lt.python_to_storage(p) for p in originals] + arrow_ext = lt.get_arrow_extension_type() + ext_arr = pa.ExtensionArray.from_storage(arrow_ext, pa.array(storage_vals, type=arrow_ext.storage_type)) + + pl_series = pl.from_arrow(ext_arr) + arr_back = pl_series.to_arrow() + recovered = [lt.storage_to_python(v.as_py()) for v in arr_back.storage] + assert recovered == originals + + +def test_logical_uuid_arrow_round_trip(): + """Python -> Arrow extension array -> Python via LogicalUUID.""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + originals = [uuid_module.UUID("12345678-1234-5678-1234-567812345678"), uuid_module.uuid4()] + storage_vals = [lt.python_to_storage(u) for u in originals] + arrow_ext = lt.get_arrow_extension_type() + ext_arr = pa.ExtensionArray.from_storage(arrow_ext, pa.array(storage_vals, type=arrow_ext.storage_type)) + + recovered = [lt.storage_to_python(v.as_py()) for v in ext_arr.storage] + assert recovered == originals + + +def test_logical_uuid_polars_round_trip(): + """Python -> Arrow extension array -> Polars series -> Arrow -> Python via LogicalUUID.""" + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + lt = LogicalUUID() + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + originals = [uuid_module.UUID("12345678-1234-5678-1234-567812345678"), uuid_module.uuid4()] + storage_vals = [lt.python_to_storage(u) for u in originals] + arrow_ext = lt.get_arrow_extension_type() + ext_arr = pa.ExtensionArray.from_storage(arrow_ext, pa.array(storage_vals, type=arrow_ext.storage_type)) + + pl_series = pl.from_arrow(ext_arr) + arr_back = pl_series.to_arrow() + recovered = [lt.storage_to_python(v.as_py()) for v in arr_back.storage] + assert recovered == originals + + +# --------------------------------------------------------------------------- +# Default context integration tests +# --------------------------------------------------------------------------- + + +def test_default_context_has_logical_type_registry(): + """DataContext's type_converter has a _logical_type_registry attribute.""" + from orcapod.contexts import get_default_context + + ctx = get_default_context() + assert hasattr(ctx.type_converter, "_logical_type_registry") + assert ctx.type_converter._logical_type_registry is not None + + +def test_default_context_registry_has_logical_path(): + """Default registry returns LogicalPath for 'pathlib.Path'.""" + from orcapod.contexts import get_default_context + from orcapod.extension_types.builtin_logical_types import LogicalPath + + registry = get_default_context().type_converter._logical_type_registry + lt = registry.get_by_logical_name("orcapod.path") + assert isinstance(lt, LogicalPath) + + +def test_default_context_registry_lookup_by_python_type_path(): + """Default registry routes pathlib.Path to LogicalPath.""" + from orcapod.contexts import get_default_context + from orcapod.extension_types.builtin_logical_types import LogicalPath + + registry = get_default_context().type_converter._logical_type_registry + lt = registry.get_by_python_type(pathlib.Path) + assert isinstance(lt, LogicalPath) + + +def test_default_context_registry_lookup_by_arrow_name_path(): + """Default registry routes 'pathlib.Path' arrow ext name to LogicalPath.""" + from orcapod.contexts import get_default_context + from orcapod.extension_types.builtin_logical_types import LogicalPath + + registry = get_default_context().type_converter._logical_type_registry + lt = registry.get_by_arrow_extension_name("orcapod.path") + assert isinstance(lt, LogicalPath) + + +def test_default_context_registry_has_logical_upath(): + """Default registry returns LogicalUPath for 'upath.UPath'.""" + from orcapod.contexts import get_default_context + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + registry = get_default_context().type_converter._logical_type_registry + lt = registry.get_by_logical_name("orcapod.upath") + assert isinstance(lt, LogicalUPath) + + +def test_default_context_registry_lookup_by_python_type_upath(): + """Default registry routes UPath to LogicalUPath.""" + from orcapod.contexts import get_default_context + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + registry = get_default_context().type_converter._logical_type_registry + lt = registry.get_by_python_type(UPath) + assert isinstance(lt, LogicalUPath) + + +def test_default_context_registry_has_logical_uuid(): + """Default registry returns LogicalUUID for 'uuid.UUID'.""" + from orcapod.contexts import get_default_context + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + registry = get_default_context().type_converter._logical_type_registry + lt = registry.get_by_logical_name("orcapod.uuid") + assert isinstance(lt, LogicalUUID) + + +def test_default_context_registry_lookup_by_arrow_name_uuid(): + """Default registry routes 'uuid.UUID' arrow ext name to LogicalUUID.""" + from orcapod.contexts import get_default_context + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + registry = get_default_context().type_converter._logical_type_registry + lt = registry.get_by_arrow_extension_name("orcapod.uuid") + assert isinstance(lt, LogicalUUID) + + +def test_default_type_converter_logical_registry_is_not_none(): + """The default context's type_converter has a non-None _logical_type_registry.""" + from orcapod.contexts import get_default_context + + ctx = get_default_context() + assert ctx.type_converter._logical_type_registry is not None + + +def test_default_context_idempotent_registry(): + """Calling get_default_context() twice returns the same LogicalTypeRegistry instance.""" + from orcapod.contexts import get_default_context + + r1 = get_default_context().type_converter._logical_type_registry + r2 = get_default_context().type_converter._logical_type_registry + assert r1 is r2 + + +# --------------------------------------------------------------------------- +# Top-level orcapod namespace alias tests +# --------------------------------------------------------------------------- + + +def test_orcapod_path_alias_is_pathlib_path(): + """orcapod.Path is the same object as pathlib.Path.""" + import pathlib + + assert orcapod.Path is pathlib.Path + + +def test_orcapod_upath_alias_is_upath_upath(): + """orcapod.UPath is the same object as upath.UPath.""" + from upath import UPath + + assert orcapod.UPath is UPath + + +def test_orcapod_uuid_alias_is_uuid_uuid(): + """orcapod.UUID is the same object as uuid.UUID.""" + import uuid + + assert orcapod.UUID is uuid.UUID + + +def test_orcapod_path_alias_in_all(): + """orcapod.Path appears in orcapod.__all__.""" + assert "Path" in orcapod.__all__ + + +def test_orcapod_upath_alias_in_all(): + """orcapod.UPath appears in orcapod.__all__.""" + assert "UPath" in orcapod.__all__ + + +def test_orcapod_uuid_alias_in_all(): + """orcapod.UUID appears in orcapod.__all__.""" + assert "UUID" in orcapod.__all__ + + +# --------------------------------------------------------------------------- +# Alias round-trip tests: using the stdlib types directly still works +# --------------------------------------------------------------------------- +# These tests verify that orcapod.Path / orcapod.UPath / orcapod.UUID are true +# aliases, not wrappers. Because e.g. orcapod.UUID is uuid.UUID, using +# uuid.UUID directly produces the same orcapod.uuid Arrow extension type, and +# the value recovered from Arrow is a uuid.UUID (i.e. also an orcapod.UUID). +# Each test asserts the identity precondition first so the contract is clear. +# --------------------------------------------------------------------------- + + +def test_pathlib_path_works_via_orcapod_path_alias_arrow_round_trip(): + """pathlib.Path values round-trip through Arrow with the orcapod.path extension type. + + This test is only valid because orcapod.Path is pathlib.Path — they are the same + object. Using pathlib.Path directly (rather than orcapod.Path) produces the same + Arrow extension type (``"orcapod.path"``), and the recovered value is a + pathlib.Path (i.e. orcapod.Path). + """ + from orcapod.extension_types.builtin_logical_types import LogicalPath + + # Precondition: test is only meaningful if orcapod.Path is pathlib.Path + assert orcapod.Path is pathlib.Path + + lt = LogicalPath() + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + # Create value using stdlib pathlib directly (not orcapod.Path) + p = pathlib.Path("/tmp/alias_test/foo.txt") + + # Registry can find LogicalPath via pathlib.Path since orcapod.Path is pathlib.Path + found = registry.get_by_python_type(pathlib.Path) + assert found is lt + + # Saving to Arrow produces "orcapod.path" extension type + storage_val = lt.python_to_storage(p) + arrow_ext = lt.get_arrow_extension_type() + assert arrow_ext.extension_name == "orcapod.path" + ext_arr = pa.ExtensionArray.from_storage( + arrow_ext, pa.array([storage_val], type=arrow_ext.storage_type) + ) + + # Recovered value is a pathlib.Path (which is orcapod.Path) + recovered = lt.storage_to_python(ext_arr.storage[0].as_py()) + assert recovered == p + assert isinstance(recovered, orcapod.Path) # valid because orcapod.Path is pathlib.Path + assert isinstance(recovered, pathlib.Path) + + +def test_upath_upath_works_via_orcapod_upath_alias_arrow_round_trip(): + """upath.UPath values round-trip through Arrow with the orcapod.upath extension type. + + This test is only valid because orcapod.UPath is upath.UPath — they are the same + object. Using upath.UPath directly (rather than orcapod.UPath) produces the same + Arrow extension type (``"orcapod.upath"``), and the recovered value is a + upath.UPath (i.e. orcapod.UPath). + """ + from orcapod.extension_types.builtin_logical_types import LogicalUPath + + # Precondition: test is only meaningful if orcapod.UPath is upath.UPath + assert orcapod.UPath is UPath + + lt = LogicalUPath() + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + # Create value using upath directly (not orcapod.UPath) + up = UPath("s3://bucket/alias_test/key.txt") + + # Registry can find LogicalUPath via UPath since orcapod.UPath is upath.UPath + found = registry.get_by_python_type(UPath) + assert found is lt + + # Saving to Arrow produces "orcapod.upath" extension type + storage_val = lt.python_to_storage(up) + arrow_ext = lt.get_arrow_extension_type() + assert arrow_ext.extension_name == "orcapod.upath" + ext_arr = pa.ExtensionArray.from_storage( + arrow_ext, pa.array([storage_val], type=arrow_ext.storage_type) + ) + + # Recovered value is a upath.UPath (which is orcapod.UPath) + recovered = lt.storage_to_python(ext_arr.storage[0].as_py()) + assert recovered == up + assert isinstance(recovered, orcapod.UPath) # valid because orcapod.UPath is upath.UPath + assert isinstance(recovered, UPath) + + +def test_uuid_uuid_works_via_orcapod_uuid_alias_arrow_round_trip(): + """uuid.UUID values round-trip through Arrow with the orcapod.uuid extension type. + + This test is only valid because orcapod.UUID is uuid.UUID — they are the same + object. Using uuid.UUID directly (rather than orcapod.UUID) produces the same + Arrow extension type (``"orcapod.uuid"``), and the recovered value is a + uuid.UUID (i.e. orcapod.UUID). + """ + from orcapod.extension_types.builtin_logical_types import LogicalUUID + + # Precondition: test is only meaningful if orcapod.UUID is uuid.UUID + assert orcapod.UUID is uuid_module.UUID + + lt = LogicalUUID() + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + # Create value using stdlib uuid directly (not orcapod.UUID) + u = uuid_module.UUID("12345678-1234-5678-1234-567812345678") + + # Registry can find LogicalUUID via uuid.UUID since orcapod.UUID is uuid.UUID + found = registry.get_by_python_type(uuid_module.UUID) + assert found is lt + + # Saving to Arrow produces "orcapod.uuid" extension type + storage_val = lt.python_to_storage(u) + arrow_ext = lt.get_arrow_extension_type() + assert arrow_ext.extension_name == "orcapod.uuid" + ext_arr = pa.ExtensionArray.from_storage( + arrow_ext, pa.array([storage_val], type=arrow_ext.storage_type) + ) + + # Recovered value is a uuid.UUID (which is orcapod.UUID) + recovered = lt.storage_to_python(ext_arr.storage[0].as_py()) + assert recovered == u + assert isinstance(recovered, orcapod.UUID) # valid because orcapod.UUID is uuid.UUID + assert isinstance(recovered, uuid_module.UUID) + + +# --------------------------------------------------------------------------- +# Converter param acceptance tests (Task 2 — PLT-1705) +# --------------------------------------------------------------------------- + + +def test_logical_path_python_to_storage_accepts_converter(): + """python_to_storage now accepts a converter param (ignored).""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + lt = LogicalPath() + result = lt.python_to_storage(pathlib.Path("/tmp/foo"), converter=None) + assert result == "/tmp/foo" + + +def test_logical_path_storage_to_python_accepts_converter(): + """storage_to_python now accepts a converter param (ignored).""" + from orcapod.extension_types.builtin_logical_types import LogicalPath + lt = LogicalPath() + result = lt.storage_to_python("/tmp/foo", converter=None) + assert result == pathlib.Path("/tmp/foo") + + +def test_logical_uuid_python_to_storage_accepts_converter(): + from orcapod.extension_types.builtin_logical_types import LogicalUUID + lt = LogicalUUID() + u = uuid_module.UUID("12345678-1234-5678-1234-567812345678") + result = lt.python_to_storage(u, converter=None) + assert result == u.bytes + + +def test_logical_uuid_storage_to_python_accepts_converter(): + from orcapod.extension_types.builtin_logical_types import LogicalUUID + lt = LogicalUUID() + u = uuid_module.UUID("12345678-1234-5678-1234-567812345678") + result = lt.storage_to_python(u.bytes, converter=None) + assert result == u + + +def test_logical_upath_python_to_storage_accepts_converter(): + from orcapod.extension_types.builtin_logical_types import LogicalUPath + lt = LogicalUPath() + result = lt.python_to_storage(UPath("s3://bucket/key"), converter=None) + assert result == "s3://bucket/key" + + +def test_logical_upath_storage_to_python_accepts_converter(): + from orcapod.extension_types.builtin_logical_types import LogicalUPath + lt = LogicalUPath() + result = lt.storage_to_python("s3://bucket/key", converter=None) + assert isinstance(result, UPath) diff --git a/tests/test_extension_types/test_cache_behavior.py b/tests/test_extension_types/test_cache_behavior.py new file mode 100644 index 00000000..efbb77e2 --- /dev/null +++ b/tests/test_extension_types/test_cache_behavior.py @@ -0,0 +1,114 @@ +"""Integration tests for per-process extension type cache behaviour. + +The ``LogicalTypeRegistry`` stores registered types in an in-memory dict keyed +by Arrow extension name. ``register_discovered_extensions`` skips the factory +call (``reconstruct_from_arrow``) when the extension name is already present in +the registry — this is the "cache hit" path. + +Two tests: + +1. ``test_cache_populated_after_first_read`` — verifies the type is absent from + a fresh converter's registry before reading a Parquet file, and present after. + +2. ``test_factory_not_called_on_second_read`` — verifies that ``reconstruct_from_arrow`` + is called exactly once (first read) and zero additional times on the second + read of the same file. +""" +from __future__ import annotations + +import dataclasses +from unittest.mock import patch + +import pyarrow.parquet as pq + +from orcapod.contexts import create_registry +from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory + + +# Module-level dataclass — local classes cannot be reconstructed from FQCN. + +@dataclasses.dataclass +class _CachePoint: + x: int + y: int + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + + +def _fresh_converter(): + """Return a fresh UniversalTypeConverter from a new registry instance. + + Uses ``create_registry()`` instead of ``get_default_context()`` to avoid + cross-test contamination through the global singleton cache. + """ + return create_registry().get_context().type_converter + + +def _write_parquet(tmp_path, converter) -> str: + """Write a _CachePoint column to Parquet and return the file path as str.""" + converter.register_python_class(_CachePoint) + arrow_schema = converter.python_schema_to_arrow_schema({"point": _CachePoint}) + rows = [{"point": _CachePoint(x=1, y=2)}] + table = converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + parquet_path = tmp_path / "cache_test.parquet" + pq.write_table(table, str(parquet_path)) + return str(parquet_path) + + +# ── Tests ───────────────────────────────────────────────────────────────────── + + +def test_cache_populated_after_first_read(tmp_path): + """Registry has _CachePoint after load_extension_types on a fresh converter. + + Before reading: the fresh converter's registry does not know about _CachePoint. + After reading: register_discovered_extensions triggers reconstruct_from_arrow + which registers _CachePoint, populating the cache. + """ + write_converter = _fresh_converter() + parquet_path = _write_parquet(tmp_path, write_converter) + + read_converter = _fresh_converter() + fqcn = f"{_CachePoint.__module__}.{_CachePoint.__qualname__}" + + # Before read: not registered + assert read_converter._logical_type_registry.get_by_arrow_extension_name(fqcn) is None + + read_converter.load_extension_types(pq.read_table(parquet_path)) + + # After read: registered (cache populated) + assert read_converter._logical_type_registry.get_by_arrow_extension_name(fqcn) is not None + + +def test_factory_not_called_on_second_read(tmp_path): + """reconstruct_from_arrow called once on first read, zero times on second read. + + On first read, register_discovered_extensions finds _CachePoint's extension + name in the schema, dispatches to the factory (call count = 1), and stores + the result in the registry. + + On second read, register_discovered_extensions finds the extension name already + in the registry and short-circuits — the factory is not called again + (call count remains 1). + """ + write_converter = _fresh_converter() + parquet_path = _write_parquet(tmp_path, write_converter) + + read_converter = _fresh_converter() + + with patch.object( + DataclassLogicalTypeFactory, + "reconstruct_from_arrow", + autospec=True, + wraps=DataclassLogicalTypeFactory.reconstruct_from_arrow, + ) as spy: + # First read: factory is called once + read_converter.load_extension_types(pq.read_table(parquet_path)) + assert spy.call_count == 1, f"Expected 1 factory call, got {spy.call_count}" + + # Second read on the same file: registry hit — factory not called again + read_converter.load_extension_types(pq.read_table(parquet_path)) + assert spy.call_count == 1, ( + f"Expected still 1 factory call after second read, got {spy.call_count}" + ) diff --git a/tests/test_extension_types/test_database_hooks.py b/tests/test_extension_types/test_database_hooks.py new file mode 100644 index 00000000..12403203 --- /dev/null +++ b/tests/test_extension_types/test_database_hooks.py @@ -0,0 +1,272 @@ +"""Tests for register_discovered_extensions in database_hooks.""" + +from __future__ import annotations + +import json +import uuid + +import pyarrow as pa +import pytest + +from orcapod.extension_types.registry import LogicalTypeRegistry, make_arrow_extension_type +from orcapod.semantic_types.universal_converter import UniversalTypeConverter + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _unique_name() -> str: + """Unique Arrow extension name to avoid cross-test global-registry collisions.""" + return f"test.hooks.{uuid.uuid4().hex[:8]}" + + +def _make_ext_schema( + arrow_name: str, + metadata: bytes | None = None, + storage: pa.DataType | None = None, +) -> pa.Schema: + """Build a ``pa.Schema`` with one extension-typed field using ``make_arrow_extension_type``. + + Only call this when you have control over the metadata content — the resulting + field's type is an in-memory ``pa.ExtensionType`` instance, not raw field metadata. + """ + _storage = storage or pa.large_utf8() + ext_cls = make_arrow_extension_type(arrow_name, _storage, metadata=metadata) + return pa.schema([pa.field("col", ext_cls())]) + + +def _make_field_metadata_schema( + arrow_name: str, + metadata: bytes, + storage: pa.DataType | None = None, +) -> pa.Schema: + """Build a schema where the extension is described by raw Arrow field metadata. + + This simulates a Parquet/IPC read where the extension type was not registered + in the current process, so ``field.type`` is a plain Arrow storage type rather + than a ``pa.ExtensionType`` instance. + """ + _storage = storage or pa.large_utf8() + field = pa.field("col", _storage).with_metadata({ + b"ARROW:extension:name": arrow_name.encode(), + b"ARROW:extension:metadata": metadata, + }) + return pa.schema([field]) + + +def _make_stub_factory(): + """Return a minimal LogicalTypeFactory stub whose calls are recorded. + + The factory auto-creates a fresh ``LogicalType`` stub keyed by arrow name. + Registering this factory in a registry causes it to also register a Polars + extension type, which requires the Arrow ext type to be in PyArrow's global + registry. To avoid cross-test collisions, each test uses a unique arrow name. + """ + class _Factory: + def __init__(self): + self.calls: list[tuple] = [] + + def supports_class(self, python_type): + return False + + def reconstruct_from_arrow(self, arrow_extension_name, storage_type, metadata, converter): + import polars as pl + from orcapod.extension_types.registry import make_arrow_extension_type + + self.calls.append((arrow_extension_name, storage_type, metadata)) + + _name = arrow_extension_name + _arrow_cls = make_arrow_extension_type(_name, storage_type) + _pl_storage = pl.from_arrow(pa.array([], type=storage_type)).dtype + + class _PolarsExt(pl.BaseExtension): + def __init__(self): + super().__init__(_name, _pl_storage, None) + @classmethod + def ext_from_params(cls, ext_name, storage_dtype, metadata_str): + return cls() + + class _StubLT: + @property + def logical_type_name(self): + return _name + @property + def python_type(self): + return str + def get_arrow_extension_type(self): + return _arrow_cls() + def get_polars_extension_type(self): + return _PolarsExt() + def python_to_storage(self, value, converter=None): + return str(value) + def storage_to_python(self, storage_value, converter=None): + return storage_value + + return _StubLT() + + def create_for_python_type(self, python_type, converter): + pass + + return _Factory() + + +def _make_converter(factory=None, category=None) -> UniversalTypeConverter: + """Make a UniversalTypeConverter with an optional factory registered.""" + registry = LogicalTypeRegistry() + converter = UniversalTypeConverter(logical_type_registry=registry) + if factory is not None and category is not None: + converter.register_logical_type_factory(factory, category=category) + return converter + + +# --------------------------------------------------------------------------- +# Fixture +# --------------------------------------------------------------------------- + +@pytest.fixture +def fresh_converter(): + """A fresh, isolated converter (with empty registry) for each test.""" + return _make_converter() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +def test_no_extension_types_is_noop(fresh_converter): + """Schema with only primitives — register_discovered_extensions returns without touching registry.""" + from orcapod.extension_types.database_hooks import register_discovered_extensions + + schema = pa.schema([ + pa.field("id", pa.int64()), + pa.field("name", pa.large_utf8()), + ]) + register_discovered_extensions(fresh_converter, schema) + # fresh registry is empty — no error means no spurious lookup was triggered + assert fresh_converter._logical_type_registry.get_by_arrow_extension_name("anything") is None + + +def test_known_type_is_registered(): + """Schema with one extension type whose factory is registered — type is registered after call.""" + from orcapod.extension_types.database_hooks import register_discovered_extensions + + arrow_name = _unique_name() + factory = _make_stub_factory() + converter = _make_converter(factory=factory, category="TestCat") + + metadata_bytes = json.dumps({"category": "TestCat"}).encode() + schema = _make_ext_schema(arrow_name, metadata=metadata_bytes) + + register_discovered_extensions(converter, schema) + + assert converter._logical_type_registry.get_by_arrow_extension_name(arrow_name) is not None + assert len(factory.calls) == 1 + + +def test_already_registered_is_skipped(): + """Calling register_discovered_extensions twice does not raise and factory is called once.""" + from orcapod.extension_types.database_hooks import register_discovered_extensions + + arrow_name = _unique_name() + factory = _make_stub_factory() + converter = _make_converter(factory=factory, category="TestCat") + + metadata_bytes = json.dumps({"category": "TestCat"}).encode() + schema = _make_ext_schema(arrow_name, metadata=metadata_bytes) + + register_discovered_extensions(converter, schema) + register_discovered_extensions(converter, schema) # second call + + assert len(factory.calls) == 1 # factory invoked exactly once + + +def test_none_metadata_already_registered_noop(): + """Extension type with None metadata that IS already in the registry — silent no-op.""" + from orcapod.extension_types.database_hooks import register_discovered_extensions + + arrow_name = _unique_name() + factory = _make_stub_factory() + converter = _make_converter(factory=factory, category="TestCat") + + # First: register via metadata so it ends up in the registry. + metadata_bytes = json.dumps({"category": "TestCat"}).encode() + schema_with_meta = _make_ext_schema(arrow_name, metadata=metadata_bytes) + register_discovered_extensions(converter, schema_with_meta) + + # Now: same arrow name but with no metadata (simulates reading the schema without + # metadata — e.g. after an IPC round-trip where the type is now registered in-process). + schema_no_meta = _make_ext_schema(arrow_name, metadata=None) + register_discovered_extensions(converter, schema_no_meta) # should NOT raise + + +def test_none_metadata_not_registered_raises(): + """Unregistered extension type with None metadata raises ValueError.""" + from orcapod.extension_types.database_hooks import register_discovered_extensions + + arrow_name = _unique_name() + converter = _make_converter() + schema = _make_ext_schema(arrow_name, metadata=None) + + with pytest.raises(ValueError, match="Pre-register them explicitly"): + register_discovered_extensions(converter, schema) + + +def test_metadata_not_json_raises(): + """Unregistered extension type with non-JSON metadata bytes raises ValueError.""" + from orcapod.extension_types.database_hooks import register_discovered_extensions + + arrow_name = _unique_name() + converter = _make_converter() + schema = _make_field_metadata_schema(arrow_name, metadata=b"not-json!") + + with pytest.raises(ValueError, match="not valid UTF-8 JSON"): + register_discovered_extensions(converter, schema) + + +def test_metadata_json_missing_category_raises(): + """Unregistered extension type with valid JSON but no 'category' key raises ValueError.""" + from orcapod.extension_types.database_hooks import register_discovered_extensions + + arrow_name = _unique_name() + converter = _make_converter() + schema = _make_field_metadata_schema( + arrow_name, metadata=json.dumps({"version": 1}).encode() + ) + + with pytest.raises(ValueError, match='"category"'): + register_discovered_extensions(converter, schema) + + +def test_unknown_metadata_raises(): + """Unregistered extension type with valid JSON and 'category' but no matching factory raises ValueError.""" + from orcapod.extension_types.database_hooks import register_discovered_extensions + + arrow_name = _unique_name() + converter = _make_converter() + schema = _make_field_metadata_schema( + arrow_name, metadata=json.dumps({"category": "NoSuchFactory"}).encode() + ) + + with pytest.raises(ValueError, match="NoSuchFactory"): + register_discovered_extensions(converter, schema) + + +def test_nested_extension_type(): + """Extension type inside a struct column is discovered and registered.""" + from orcapod.extension_types.database_hooks import register_discovered_extensions + + arrow_name = _unique_name() + factory = _make_stub_factory() + converter = _make_converter(factory=factory, category="TestCat") + + metadata_bytes = json.dumps({"category": "TestCat"}).encode() + inner_ext_cls = make_arrow_extension_type(arrow_name, pa.large_utf8(), metadata=metadata_bytes) + + struct_type = pa.struct([pa.field("inner", inner_ext_cls())]) + schema = pa.schema([pa.field("outer", struct_type)]) + + register_discovered_extensions(converter, schema) + + assert converter._logical_type_registry.get_by_arrow_extension_name(arrow_name) is not None + assert len(factory.calls) == 1 diff --git a/tests/test_extension_types/test_dataclass_logical_type_factory.py b/tests/test_extension_types/test_dataclass_logical_type_factory.py new file mode 100644 index 00000000..57607efe --- /dev/null +++ b/tests/test_extension_types/test_dataclass_logical_type_factory.py @@ -0,0 +1,580 @@ +"""Tests for DataclassLogicalType and DataclassLogicalTypeFactory.""" + +from __future__ import annotations + +import dataclasses +import uuid as _uuid_module +from typing import Any + +import pyarrow as pa +import pytest + + +# ── Helpers ───────────────────────────────────────────────────────────────── + +class _StubConverter: + """Minimal converter stub for DataclassLogicalType tests.""" + + def python_to_storage(self, value, annotation): + if annotation is str: + return str(value) + if annotation is int: + return int(value) + return value + + def storage_to_python(self, storage_value, annotation): + if annotation is str: + return str(storage_value) + if annotation is int: + return int(storage_value) + return storage_value + + def register_python_class(self, annotation): + if annotation is str: + return pa.large_string() + if annotation is int: + return pa.int64() + raise ValueError(f"No mapping for {annotation}") + + +# ── DataclassLogicalType tests ─────────────────────────────────────────────── + +def test_dataclass_logical_type_is_importable(): + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalType + assert DataclassLogicalType is not None + + +def test_dataclass_logical_type_protocol_conformance(): + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalType + from orcapod.extension_types.protocols import LogicalTypeProtocol + + @dataclasses.dataclass + class _MyDC: + name: str + count: int + + storage = pa.struct([pa.field("name", pa.large_string()), pa.field("count", pa.int64())]) + field_annotations = [("name", str), ("count", int)] + lt = DataclassLogicalType( + logical_name="tests.MyDC", + python_type=_MyDC, + storage_type=storage, + field_annotations=field_annotations, + ) + assert isinstance(lt, LogicalTypeProtocol) + + +def test_dataclass_logical_type_python_to_storage(): + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalType + + @dataclasses.dataclass + class _Point: + x: int + y: int + + storage = pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]) + lt = DataclassLogicalType("tests.Point", _Point, storage, [("x", int), ("y", int)]) + converter = _StubConverter() + + result = lt.python_to_storage(_Point(x=3, y=7), converter) + assert result == {"x": 3, "y": 7} + + +def test_dataclass_logical_type_storage_to_python(): + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalType + + @dataclasses.dataclass + class _Point: + x: int + y: int + + storage = pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]) + lt = DataclassLogicalType("tests.Point", _Point, storage, [("x", int), ("y", int)]) + converter = _StubConverter() + + result = lt.storage_to_python({"x": 3, "y": 7}, converter) + assert isinstance(result, _Point) + assert result.x == 3 + assert result.y == 7 + + +def test_dataclass_logical_type_logical_type_name(): + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalType + + @dataclasses.dataclass + class _Foo: + val: str + + storage = pa.struct([pa.field("val", pa.large_string())]) + lt = DataclassLogicalType("mymod.Foo", _Foo, storage, [("val", str)]) + assert lt.logical_type_name == "mymod.Foo" + + +def test_dataclass_logical_type_python_type(): + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalType + + @dataclasses.dataclass + class _Bar: + val: str + + storage = pa.struct([pa.field("val", pa.large_string())]) + lt = DataclassLogicalType("mymod.Bar", _Bar, storage, [("val", str)]) + assert lt.python_type is _Bar + + +# ── DataclassLogicalTypeFactory helpers ────────────────────────────────────────── + +def _make_full_converter(): + """Make a UniversalTypeConverter with builtin types + DataclassLogicalTypeFactory.""" + from orcapod.extension_types.builtin_logical_types import LogicalPath, LogicalUUID, LogicalUPath + from orcapod.extension_types.registry import LogicalTypeRegistry + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory, DATACLASS_CATEGORY + from orcapod.semantic_types.universal_converter import UniversalTypeConverter + + registry = LogicalTypeRegistry(logical_types=[LogicalPath(), LogicalUUID(), LogicalUPath()]) + factory = DataclassLogicalTypeFactory() + registry.register_logical_type_factory(factory, category=DATACLASS_CATEGORY, python_bases=[object]) + return UniversalTypeConverter(logical_type_registry=registry) + + +# ── DataclassLogicalTypeFactory write-path tests ───────────────────────────────── + +def test_factory_supports_class_dataclass(): + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory + + @dataclasses.dataclass + class _Dummy: + x: int + + factory = DataclassLogicalTypeFactory() + assert factory.supports_class(_Dummy) is True + + +def test_factory_supports_class_non_dataclass(): + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory + + factory = DataclassLogicalTypeFactory() + assert factory.supports_class(str) is False + assert factory.supports_class(int) is False + + +@dataclasses.dataclass +class _Flat: + name: str + count: int + + +@dataclasses.dataclass +class _WithUUID: + id: _uuid_module.UUID + label: str + + +@dataclasses.dataclass +class _WithList: + tags: list[str] + count: int + + +@dataclasses.dataclass +class _WithDict: + meta: dict[str, int] + + +@dataclasses.dataclass +class _InnerForRegistrationTest: + """Module-level inner dataclass for registration completeness test.""" + value: int + + +@dataclasses.dataclass +class _OuterForRegistrationTest: + """Module-level outer dataclass for registration completeness test.""" + inner: _InnerForRegistrationTest + label: str + + +# ── Module-level dataclasses for list[dataclass[dataclass]] round-trip test ── + +@dataclasses.dataclass +class _ListItemDC: + """Inner dataclass used as element type in list[_ListItemDC] field.""" + x: int + y: int + + +@dataclasses.dataclass +class _ListContainerDC: + """Outer dataclass with a list[_ListItemDC] field.""" + items: list[_ListItemDC] + label: str + + +def test_factory_create_flat_dataclass(): + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory, DataclassLogicalType + + factory = DataclassLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_Flat, converter=converter) + + assert isinstance(lt, DataclassLogicalType) + storage = lt.get_arrow_extension_type().storage_type + assert pa.types.is_struct(storage) + assert storage.field("name").type == pa.large_string() + assert storage.field("count").type == pa.int64() + + +def test_factory_create_dataclass_with_uuid_field(): + """UUID field → plain storage type (large_binary) in the struct, not extension type. + + ``pa.Table.from_pylist`` (and Polars dtype inference) cannot handle a struct + whose fields are ``pa.ExtensionType`` nodes. ``DataclassLogicalTypeFactory`` strips + extension types from struct field types so that Arrow array construction works. + The UUID's extension type (``orcapod.uuid``) is still registered and used for + value conversion; only the struct field schema uses the stripped storage type. + """ + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory + + factory = DataclassLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_WithUUID, converter=converter) + + storage = lt.get_arrow_extension_type().storage_type + id_field_type = storage.field("id").type + # Stripped to plain storage type — NOT an extension type in the struct. + assert id_field_type == pa.large_binary() + assert not isinstance(id_field_type, pa.ExtensionType) + + +def test_factory_create_dataclass_with_list_field(): + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory + + factory = DataclassLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_WithList, converter=converter) + + storage = lt.get_arrow_extension_type().storage_type + assert pa.types.is_large_list(storage.field("tags").type) + assert storage.field("tags").type.value_type == pa.large_string() + + +def test_factory_create_dataclass_with_dict_field(): + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory + + factory = DataclassLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_WithDict, converter=converter) + + storage = lt.get_arrow_extension_type().storage_type + meta_type = storage.field("meta").type + assert pa.types.is_large_list(meta_type) + assert pa.types.is_struct(meta_type.value_type) + field_names = {meta_type.value_type.field(i).name for i in range(meta_type.value_type.num_fields)} + assert field_names == {"key", "value"} + + +def test_factory_rejects_local_class(): + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory + + def _make_local(): + @dataclasses.dataclass + class _Local: + x: int + return _Local + + LocalClass = _make_local() + factory = DataclassLogicalTypeFactory() + converter = _make_full_converter() + with pytest.raises(ValueError, match="local"): + factory.create_for_python_type(LocalClass, converter=converter) + + +def test_register_python_class_dispatches_to_dataclass_factory(): + """register_python_class on a dataclass triggers DataclassLogicalTypeFactory.""" + converter = _make_full_converter() + + # For this test, use UUID as a proxy (already registered as built-in). + result = converter.register_python_class(_uuid_module.UUID) + assert isinstance(result, pa.ExtensionType) + assert result.extension_name == "orcapod.uuid" + + +# ── Module-level dataclasses for round-trip tests ──────────────────────────── + +@dataclasses.dataclass +class _RoundTripPoint: + """Module-level dataclass for round-trip testing.""" + x: int + y: int + + +@dataclasses.dataclass +class _RoundTripRecord: + """Module-level dataclass with a UUID field.""" + record_id: _uuid_module.UUID + label: str + + +# ── Read-path tests ─────────────────────────────────────────────────────────── + +def test_factory_reconstruct_from_arrow(): + """reconstruct_from_arrow rebuilds the logical type from the Arrow struct.""" + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory, DataclassLogicalType + + storage = pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]) + metadata = {"category": "orcapod.dataclass"} + fqcn = f"{_RoundTripPoint.__module__}.{_RoundTripPoint.__qualname__}" + + factory = DataclassLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.reconstruct_from_arrow(fqcn, storage, metadata, converter=converter) + + assert isinstance(lt, DataclassLogicalType) + assert lt.python_type is _RoundTripPoint + assert lt.logical_type_name == fqcn + + +def test_factory_reconstruct_from_arrow_invalid_fqcn(): + """ImportError if the FQCN cannot be resolved.""" + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory + + storage = pa.struct([pa.field("x", pa.int64())]) + factory = DataclassLogicalTypeFactory() + converter = _make_full_converter() + + with pytest.raises(ImportError): + factory.reconstruct_from_arrow( + "nonexistent.module.NoSuchClass", storage, {"category": "orcapod.dataclass"}, converter + ) + + +def test_reconstruct_from_arrow_registers_nested_types(): + """reconstruct_from_arrow for Outer must register Inner as a side effect.""" + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory + + # Build the storage type for _OuterForRegistrationTest manually (as it would come + # from Parquet): outer struct with an inner struct field (Inner is stored as a struct, + # NOT as an extension type inside the struct field — that's the ET1 constraint). + inner_storage = pa.struct([pa.field("value", pa.int64())]) + outer_storage = pa.struct([ + pa.field("inner", inner_storage), + pa.field("label", pa.large_string()), + ]) + outer_fqcn = f"{_OuterForRegistrationTest.__module__}.{_OuterForRegistrationTest.__qualname__}" + + factory = DataclassLogicalTypeFactory() + converter = _make_full_converter() + + # Inner is NOT pre-registered + assert converter._logical_type_registry.get_by_python_type(_InnerForRegistrationTest) is None + + # reconstruct_from_arrow for Outer should trigger registration of Inner as a side effect + lt = factory.reconstruct_from_arrow(outer_fqcn, outer_storage, {"category": "orcapod.dataclass"}, converter) + + # Inner must now be registered + assert converter._logical_type_registry.get_by_python_type(_InnerForRegistrationTest) is not None + + +def test_dataclass_python_to_storage_round_trip(): + """python_to_storage → storage_to_python returns an equivalent dataclass.""" + converter = _make_full_converter() + + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory + factory = DataclassLogicalTypeFactory() + lt = factory.create_for_python_type(_RoundTripPoint, converter=converter) + converter.register_logical_type(lt) + + point = _RoundTripPoint(x=10, y=20) + storage_value = lt.python_to_storage(point, converter) + assert storage_value == {"x": 10, "y": 20} + + reconstructed = lt.storage_to_python(storage_value, converter) + assert isinstance(reconstructed, _RoundTripPoint) + assert reconstructed.x == 10 + assert reconstructed.y == 20 + + +def test_dataclass_with_uuid_round_trip(): + """Round-trip a dataclass with a UUID field through python_to_storage / storage_to_python.""" + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalTypeFactory + + converter = _make_full_converter() + factory = DataclassLogicalTypeFactory() + lt = factory.create_for_python_type(_RoundTripRecord, converter=converter) + converter.register_logical_type(lt) + + u = _uuid_module.UUID("12345678-1234-5678-1234-567812345678") + record = _RoundTripRecord(record_id=u, label="hello") + + storage_value = lt.python_to_storage(record, converter) + assert storage_value["label"] == "hello" + # UUID stored as bytes + assert storage_value["record_id"] == u.bytes + + reconstructed = lt.storage_to_python(storage_value, converter) + assert isinstance(reconstructed, _RoundTripRecord) + assert reconstructed.record_id == u + assert reconstructed.label == "hello" + + +# ── _import_from_fqcn nested class tests ───────────────────────────────────── + +@dataclasses.dataclass +class _OuterForNestedTest: + """Module-level outer class for testing nested-class FQCN import.""" + + @dataclasses.dataclass + class Inner: + x: int + y: str + + +def test_import_from_fqcn_nested_class(): + """_import_from_fqcn resolves module-level nested dataclasses via attribute walk.""" + from orcapod.extension_types.dataclass_logical_type_factory import _import_from_fqcn + + # _OuterForNestedTest.Inner lives in this test module; its FQCN uses '.' for nesting + module = _OuterForNestedTest.__module__ + outer_qualname = _OuterForNestedTest.__qualname__ + inner_qualname = _OuterForNestedTest.Inner.__qualname__ # e.g. "_OuterForNestedTest.Inner" + + fqcn = f"{module}.{inner_qualname}" + cls = _import_from_fqcn(fqcn) + assert cls is _OuterForNestedTest.Inner + assert dataclasses.is_dataclass(cls) + + +def test_python_to_storage_raises_when_converter_none(): + """DataclassLogicalType.python_to_storage raises ValueError when converter is None.""" + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalType + + @dataclasses.dataclass + class _DC: + x: int + + storage = pa.struct([pa.field("x", pa.int64())]) + lt = DataclassLogicalType("mymod._DC", _DC, storage, [("x", int)]) + with pytest.raises(ValueError, match="converter"): + lt.python_to_storage(_DC(x=1), None) + + +def test_storage_to_python_raises_when_converter_none(): + """DataclassLogicalType.storage_to_python raises ValueError when converter is None.""" + from orcapod.extension_types.dataclass_logical_type_factory import DataclassLogicalType + + @dataclasses.dataclass + class _DC: + x: int + + storage = pa.struct([pa.field("x", pa.int64())]) + lt = DataclassLogicalType("mymod._DC2", _DC, storage, [("x", int)]) + with pytest.raises(ValueError, match="converter"): + lt.storage_to_python({"x": 1}, None) + + +def test_nested_dataclass_parquet_roundtrip(tmp_path): + """Fresh-process Parquet round-trip for a two-level nested dataclass. + + Verifies that register_discovered_extensions triggers the chain: + register_arrow_extension("Outer") -> reconstruct_from_arrow + -> register_python_class(Inner) -> registers Inner + so that storage_to_python can reconstruct the full nested object. + """ + import pyarrow.parquet as pq + from orcapod.extension_types.database_hooks import register_discovered_extensions, apply_extension_types + + # ── Write path ─────────────────────────────────────────────────────────── + write_converter = _make_full_converter() + + inner = _InnerForRegistrationTest(value=42) + outer = _OuterForRegistrationTest(inner=inner, label="hello") + + # Register Outer (which also registers Inner via create_for_python_type) + write_converter.register_python_class(_OuterForRegistrationTest) + + # Serialise: python_schema_to_arrow_schema gives the column-level Arrow schema + # (with extension types at the top level); python_dicts_to_arrow_table converts rows. + arrow_schema = write_converter.python_schema_to_arrow_schema({"item": _OuterForRegistrationTest}) + rows = [{"item": outer}] + table = write_converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + + parquet_path = tmp_path / "nested.parquet" + pq.write_table(table, parquet_path) + + # ── Read path (fresh converter — neither Inner nor Outer pre-registered) ── + read_converter = _make_full_converter() + read_table = pq.read_table(parquet_path) + + # register_discovered_extensions triggers: Outer -> reconstruct_from_arrow + # -> register_python_class(Inner) -> registers Inner + register_discovered_extensions(read_converter, read_table.schema) + read_table = apply_extension_types(read_table, read_converter._logical_type_registry) + + # Both types must now be registered + assert read_converter._logical_type_registry.get_by_python_type(_OuterForRegistrationTest) is not None + assert read_converter._logical_type_registry.get_by_python_type(_InnerForRegistrationTest) is not None + + # Convert back to Python and verify full nested object + rows_out = read_converter.arrow_table_to_python_dicts(read_table) + assert len(rows_out) == 1 + reconstructed = rows_out[0]["item"] + assert isinstance(reconstructed, _OuterForRegistrationTest) + assert isinstance(reconstructed.inner, _InnerForRegistrationTest) + assert reconstructed.inner.value == 42 + assert reconstructed.label == "hello" + + +@pytest.mark.xfail( + reason=( + "list[T] where T is a logical type (e.g. a dataclass) is not yet supported. " + "Arrow cannot preserve extension types inside list value fields (ET2 in " + "DESIGN_ISSUES.md). Planned in PLT-1732 (ListLogicalType / StructLogicalType)." + ), + raises=ValueError, + strict=True, +) +def test_list_of_nested_dataclass_parquet_roundtrip(tmp_path): + """Parquet round-trip for a dataclass whose field is list[AnotherDataclass]. + + This test documents the PLT-1732 gap: registering a dataclass that contains a + list[T] field where T is itself a logical type currently raises ValueError because + Arrow cannot represent extension types inside list value fields. + + Once PLT-1732 (ListLogicalType) is implemented, this test should pass and the + xfail marker should be removed. + """ + import pyarrow.parquet as pq + from orcapod.extension_types.database_hooks import register_discovered_extensions, apply_extension_types + + # ── Write path ─────────────────────────────────────────────────────────── + write_converter = _make_full_converter() + + items = [_ListItemDC(x=1, y=2), _ListItemDC(x=3, y=4)] + container = _ListContainerDC(items=items, label="test") + + # This raises ValueError currently: list[_ListItemDC] contains a logical type + # (_ListItemDC is a dataclass → extension type) in a list value field position. + write_converter.register_python_class(_ListContainerDC) + + arrow_schema = write_converter.python_schema_to_arrow_schema({"record": _ListContainerDC}) + rows = [{"record": container}] + table = write_converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + + parquet_path = tmp_path / "list_nested.parquet" + pq.write_table(table, parquet_path) + + # ── Read path (fresh converter) ────────────────────────────────────────── + read_converter = _make_full_converter() + read_table = pq.read_table(parquet_path) + register_discovered_extensions(read_converter, read_table.schema) + read_table = apply_extension_types(read_table, read_converter._logical_type_registry) + + rows_out = read_converter.arrow_table_to_python_dicts(read_table) + assert len(rows_out) == 1 + reconstructed = rows_out[0]["record"] + assert isinstance(reconstructed, _ListContainerDC) + assert len(reconstructed.items) == 2 + assert isinstance(reconstructed.items[0], _ListItemDC) + assert reconstructed.items[0].x == 1 + assert reconstructed.items[1].y == 4 + assert reconstructed.label == "test" diff --git a/tests/test_extension_types/test_default_context_factories.py b/tests/test_extension_types/test_default_context_factories.py new file mode 100644 index 00000000..ed914b84 --- /dev/null +++ b/tests/test_extension_types/test_default_context_factories.py @@ -0,0 +1,163 @@ +"""Tests for LogicalTypeRegistry factories parameter and default context factory wiring.""" + +from __future__ import annotations + +import dataclasses + +import pyarrow as pa +import pyarrow.parquet as pq +from pydantic import BaseModel + +from orcapod.contexts import create_registry +from orcapod.extension_types.dataclass_logical_type_factory import ( + DataclassLogicalTypeFactory, + DATACLASS_CATEGORY, +) +from orcapod.extension_types.pydantic_logical_type_factory import ( + PydanticLogicalTypeFactory, + PYDANTIC_CATEGORY, +) +from orcapod.extension_types.registry import LogicalTypeRegistry + + +# ── Module-level dataclasses (local classes cannot be registered) ──────────── + +@dataclasses.dataclass +class _SimplePoint: + """Minimal dataclass used as a test fixture.""" + x: int + y: int + + +class _SimpleModel(BaseModel): + """Minimal pydantic model used as a test fixture.""" + name: str + score: float + + +# ── Registry constructor unit tests ───────────────────────────────────────── + +def test_registry_factories_param_registers_category(): + """factories param registers the factory under the given category.""" + factory = DataclassLogicalTypeFactory() + registry = LogicalTypeRegistry( + factories=[{"factory": factory, "category": DATACLASS_CATEGORY, "python_bases": [object]}] + ) + assert registry._category_factories.get(DATACLASS_CATEGORY) is factory + + +def test_registry_factories_param_registers_python_base(): + """factories param registers the factory under each python_base.""" + factory = DataclassLogicalTypeFactory() + registry = LogicalTypeRegistry( + factories=[{"factory": factory, "category": DATACLASS_CATEGORY, "python_bases": [object]}] + ) + assert registry._python_class_factories.get(object) is factory + + +def test_registry_factories_param_empty_list_is_noop(): + """factories=[] constructs successfully with no registered factories.""" + registry = LogicalTypeRegistry(factories=[]) + assert registry._category_factories == {} + assert registry._python_class_factories == {} + + +def test_registry_factories_param_none_is_noop(): + """factories=None (default) constructs successfully.""" + registry = LogicalTypeRegistry(factories=None) + assert registry._category_factories == {} + assert registry._python_class_factories == {} + + +# ── Default context integration tests ──────────────────────────────────────── +# +# All tests use create_registry().get_context() — NOT get_default_context() — +# to avoid cross-test contamination via the global singleton cache. + + +def test_default_context_has_dataclass_factory(): + """Default context registers DataclassLogicalTypeFactory under orcapod.dataclass.""" + ctx = create_registry().get_context() + registry = ctx.type_converter._logical_type_registry + factory = registry._category_factories.get(DATACLASS_CATEGORY) + assert isinstance(factory, DataclassLogicalTypeFactory) + + +def test_default_context_has_pydantic_factory(): + """Default context registers PydanticLogicalTypeFactory under orcapod.pydantic.""" + ctx = create_registry().get_context() + registry = ctx.type_converter._logical_type_registry + factory = registry._category_factories.get(PYDANTIC_CATEGORY) + assert isinstance(factory, PydanticLogicalTypeFactory) + + +# ── Auto-registration tests ─────────────────────────────────────────────────── + + +def test_default_context_dataclass_auto_registered_on_use(): + """register_python_class on a dataclass works zero-setup via the default context.""" + converter = create_registry().get_context().type_converter + arrow_type = converter.register_python_class(_SimplePoint) + assert isinstance(arrow_type, pa.ExtensionType) + fqcn = f"{_SimplePoint.__module__}.{_SimplePoint.__qualname__}" + assert arrow_type.extension_name == fqcn + + +def test_default_context_pydantic_auto_registered_on_use(): + """register_python_class on a pydantic model works zero-setup via the default context.""" + converter = create_registry().get_context().type_converter + arrow_type = converter.register_python_class(_SimpleModel) + assert isinstance(arrow_type, pa.ExtensionType) + fqcn = f"{_SimpleModel.__module__}.{_SimpleModel.__qualname__}" + assert arrow_type.extension_name == fqcn + + +# ── Parquet round-trip tests ───────────────────────────────────────────────── + + +def test_default_context_dataclass_parquet_roundtrip(tmp_path): + """Dataclass round-trips through Parquet with no manual factory registration.""" + # Write path — fresh context, no manual factory setup + write_converter = create_registry().get_context().type_converter + write_converter.register_python_class(_SimplePoint) + arrow_schema = write_converter.python_schema_to_arrow_schema({"point": _SimplePoint}) + rows = [{"point": _SimplePoint(x=3, y=7)}] + table = write_converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + + parquet_path = tmp_path / "point.parquet" + pq.write_table(table, parquet_path) + + # Read path — another fresh context, no manual factory setup + read_converter = create_registry().get_context().type_converter + read_table = read_converter.load_extension_types(pq.read_table(parquet_path)) + + rows_out = read_converter.arrow_table_to_python_dicts(read_table) + assert len(rows_out) == 1 + result = rows_out[0]["point"] + assert isinstance(result, _SimplePoint) + assert result.x == 3 + assert result.y == 7 + + +def test_default_context_pydantic_parquet_roundtrip(tmp_path): + """Pydantic model round-trips through Parquet with no manual factory registration.""" + # Write path — fresh context, no manual factory setup + write_converter = create_registry().get_context().type_converter + write_converter.register_python_class(_SimpleModel) + arrow_schema = write_converter.python_schema_to_arrow_schema({"model": _SimpleModel}) + rows = [{"model": _SimpleModel(name="alice", score=9.5)}] + table = write_converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + + parquet_path = tmp_path / "model.parquet" + pq.write_table(table, parquet_path) + + # Read path — another fresh context, no manual factory setup + read_converter = create_registry().get_context().type_converter + read_table = read_converter.load_extension_types(pq.read_table(parquet_path)) + + rows_out = read_converter.arrow_table_to_python_dicts(read_table) + assert len(rows_out) == 1 + result = rows_out[0]["model"] + assert isinstance(result, _SimpleModel) + assert result.name == "alice" + assert result.score == 9.5 diff --git a/tests/test_extension_types/test_protocols.py b/tests/test_extension_types/test_protocols.py index 71892fdd..dee88998 100644 --- a/tests/test_extension_types/test_protocols.py +++ b/tests/test_extension_types/test_protocols.py @@ -1,56 +1,156 @@ -"""Tests for ExtensionTypeConverter protocol.""" +"""Tests for LogicalTypeProtocol, LogicalTypeFactoryProtocol, and TypeConverterProtocol.""" from __future__ import annotations import pyarrow as pa +import polars as pl -from orcapod.extension_types.protocols import ExtensionTypeConverter +from orcapod.extension_types.protocols import LogicalTypeProtocol +from orcapod.extension_types.registry import make_arrow_extension_type -class _StubConverter: - """Minimal conforming implementation of ExtensionTypeConverter for use in tests.""" +class _StubLogicalType: + """Minimal conforming implementation of LogicalTypeProtocol for use in tests.""" - @property - def extension_name(self) -> str: - return "test.module.MyType" + _ArrowExtClass = make_arrow_extension_type("test.module.MyType", pa.large_string()) @property - def extension_metadata(self) -> bytes | None: - return b"test.category" - - @property - def storage_type(self) -> pa.DataType: - return pa.large_string() + def logical_type_name(self) -> str: + return "test.module.MyType" @property def python_type(self) -> type: return str - def python_to_storage(self, value): + def get_arrow_extension_type(self) -> pa.ExtensionType: + return self._ArrowExtClass() + + def get_polars_extension_type(self) -> pl.BaseExtension: + class _PolarsExt(pl.BaseExtension): + def __init__(self): + super().__init__("test.module.MyType", pl.String, None) + @classmethod + def ext_from_params(cls, ext_name, storage_dtype, metadata_str): + return cls() + return _PolarsExt() + + def python_to_storage(self, value, converter): # converter param added return str(value) - def storage_to_python(self, storage_value): + def storage_to_python(self, storage_value, converter): # converter param added return storage_value +class _StubFactory: + """Minimal conforming implementation of LogicalTypeFactoryProtocol for use in tests.""" + + def supports_class(self, python_type): # new method + return True + + def reconstruct_from_arrow(self, arrow_extension_name, storage_type, metadata, converter): + return _StubLogicalType() + + def create_for_python_type(self, python_type, converter): + return _StubLogicalType() + + +def test_type_converter_protocol_is_importable(): + from orcapod.extension_types.protocols import TypeConverterProtocol + assert TypeConverterProtocol is not None + + +def test_factory_supports_class_method_required(): + """LogicalTypeFactoryProtocol requires supports_class.""" + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol + + class _BadFactory: + def reconstruct_from_arrow(self, name, storage_type, metadata, converter): + pass + def create_for_python_type(self, python_type, converter): + pass + # Missing supports_class + + assert not isinstance(_BadFactory(), LogicalTypeFactoryProtocol) + + +def test_factory_with_supports_class_satisfies_protocol(): + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol + + class _GoodFactory: + def supports_class(self, python_type): + return True + def reconstruct_from_arrow(self, name, storage_type, metadata, converter): + pass + def create_for_python_type(self, python_type, converter): + pass + + assert isinstance(_GoodFactory(), LogicalTypeFactoryProtocol) + + +def test_logical_type_python_to_storage_accepts_converter(): + """LogicalTypeProtocol.python_to_storage now requires converter param.""" + from orcapod.extension_types.protocols import LogicalTypeProtocol + + class _GoodLT: + @property + def logical_type_name(self): return "test.lt" + @property + def python_type(self): return str + def get_arrow_extension_type(self): pass + def get_polars_extension_type(self): pass + def python_to_storage(self, value, converter): return value + def storage_to_python(self, storage_value, converter): return storage_value + + assert isinstance(_GoodLT(), LogicalTypeProtocol) + + +def test_logical_type_factory_protocol_is_importable(): + """LogicalTypeFactoryProtocol can be imported from extension_types.protocols.""" + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol + assert LogicalTypeFactoryProtocol is not None + + +def test_logical_type_factory_conforming_class_satisfies_protocol(): + """A conforming class is recognized as a LogicalTypeFactoryProtocol instance.""" + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol + assert isinstance(_StubFactory(), LogicalTypeFactoryProtocol) + + +def test_logical_type_factory_create_returns_logical_type(): + """A conforming factory returns a LogicalTypeProtocol from reconstruct_from_arrow.""" + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol, LogicalTypeProtocol + factory: LogicalTypeFactoryProtocol = _StubFactory() + result = factory.reconstruct_from_arrow( + "test.ext", pa.large_utf8(), {"category": "Test"}, converter=None + ) + assert isinstance(result, LogicalTypeProtocol) + + def test_protocol_is_importable(): - """ExtensionTypeConverter can be imported from extension_types.protocols.""" - assert ExtensionTypeConverter is not None + """LogicalTypeProtocol can be imported from extension_types.protocols.""" + assert LogicalTypeProtocol is not None def test_protocol_defines_required_members(): - """A conforming class is recognized as an ExtensionTypeConverter instance.""" - assert isinstance(_StubConverter(), ExtensionTypeConverter) + """A conforming class is recognized as a LogicalTypeProtocol instance.""" + assert isinstance(_StubLogicalType(), LogicalTypeProtocol) def test_conforming_class_satisfies_protocol(): """A class implementing all required members works correctly via the protocol interface.""" - converter: ExtensionTypeConverter = _StubConverter() - assert converter.extension_name == "test.module.MyType" - assert converter.extension_metadata == b"test.category" - assert converter.storage_type == pa.large_string() - assert converter.python_type is str - assert converter.python_to_storage(42) == "42" - assert converter.storage_to_python("hello") == "hello" - - + lt: LogicalTypeProtocol = _StubLogicalType() + assert lt.logical_type_name == "test.module.MyType" + assert lt.python_type is str + assert lt.get_arrow_extension_type().extension_name == "test.module.MyType" + assert isinstance(lt.get_polars_extension_type(), pl.BaseExtension) + assert lt.python_to_storage(42, None) == "42" # pass converter=None + assert lt.storage_to_python("hello", None) == "hello" # pass converter=None + + +def test_factory_create_for_python_type_conformance(): + """A conforming factory implements create_for_python_type and returns LogicalTypeProtocol.""" + from orcapod.extension_types.protocols import LogicalTypeFactoryProtocol, LogicalTypeProtocol + factory: LogicalTypeFactoryProtocol = _StubFactory() + assert isinstance(factory, LogicalTypeFactoryProtocol) + result = factory.create_for_python_type(str, converter=None) + assert isinstance(result, LogicalTypeProtocol) diff --git a/tests/test_extension_types/test_pydantic_logical_type_factory.py b/tests/test_extension_types/test_pydantic_logical_type_factory.py new file mode 100644 index 00000000..9c8afaf1 --- /dev/null +++ b/tests/test_extension_types/test_pydantic_logical_type_factory.py @@ -0,0 +1,460 @@ +"""Tests for PydanticLogicalType and PydanticLogicalTypeFactory.""" + +from __future__ import annotations + +from typing import Any + +import uuid as _uuid_module + +import pyarrow as pa +import pytest +from pydantic import BaseModel, PrivateAttr + + +# ── Helpers ────────────────────────────────────────────────────────────────── + +class _StubConverter: + """Minimal converter stub for PydanticLogicalType tests.""" + + def python_to_storage(self, value, annotation): + if annotation is str: + return str(value) + if annotation is int: + return int(value) + return value + + def storage_to_python(self, storage_value, annotation): + if annotation is str: + return str(storage_value) + if annotation is int: + return int(storage_value) + return storage_value + + def register_python_class(self, annotation): + if annotation is str: + return pa.large_string() + if annotation is int: + return pa.int64() + raise ValueError(f"No mapping for {annotation}") + + +# ── PydanticLogicalType tests ──────────────────────────────────────────────── + +def test_pydantic_logical_type_is_importable(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + assert PydanticLogicalType is not None + + +def test_pydantic_logical_type_protocol_conformance(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + from orcapod.extension_types.protocols import LogicalTypeProtocol + + class _MyModel(BaseModel): + name: str + count: int + + storage = pa.struct([pa.field("name", pa.large_string()), pa.field("count", pa.int64())]) + lt = PydanticLogicalType( + logical_name="tests._MyModel", + python_type=_MyModel, + storage_type=storage, + field_annotations=[("name", str), ("count", int)], + ) + assert isinstance(lt, LogicalTypeProtocol) + + +def test_pydantic_logical_type_python_to_storage(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + + class _Point(BaseModel): + x: int + y: int + + storage = pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]) + lt = PydanticLogicalType("tests._Point", _Point, storage, [("x", int), ("y", int)]) + converter = _StubConverter() + + result = lt.python_to_storage(_Point(x=3, y=7), converter) + assert result == {"x": 3, "y": 7} + + +def test_pydantic_logical_type_storage_to_python(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + + class _Point(BaseModel): + x: int + y: int + + storage = pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]) + lt = PydanticLogicalType("tests._Point2", _Point, storage, [("x", int), ("y", int)]) + converter = _StubConverter() + + result = lt.storage_to_python({"x": 3, "y": 7}, converter) + assert isinstance(result, _Point) + assert result.x == 3 + assert result.y == 7 + + +def test_pydantic_logical_type_logical_type_name(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + + class _Foo(BaseModel): + val: str + + storage = pa.struct([pa.field("val", pa.large_string())]) + lt = PydanticLogicalType("mymod.Foo", _Foo, storage, [("val", str)]) + assert lt.logical_type_name == "mymod.Foo" + + +def test_pydantic_logical_type_python_type(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + + class _Bar(BaseModel): + val: str + + storage = pa.struct([pa.field("val", pa.large_string())]) + lt = PydanticLogicalType("mymod.Bar", _Bar, storage, [("val", str)]) + assert lt.python_type is _Bar + + +def test_python_to_storage_raises_when_converter_none(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + + class _DC(BaseModel): + x: int + + storage = pa.struct([pa.field("x", pa.int64())]) + lt = PydanticLogicalType("mymod._DC", _DC, storage, [("x", int)]) + with pytest.raises(ValueError, match="converter"): + lt.python_to_storage(_DC(x=1), None) + + +def test_storage_to_python_raises_when_converter_none(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalType + + class _DC2(BaseModel): + x: int + + storage = pa.struct([pa.field("x", pa.int64())]) + lt = PydanticLogicalType("mymod._DC2", _DC2, storage, [("x", int)]) + with pytest.raises(ValueError, match="converter"): + lt.storage_to_python({"x": 1}, None) + + +# ── Module-level models for factory tests ──────────────────────────────────── +# Must be at module scope (not inside functions) so FQCN reconstruction works. + +class _FlatModel(BaseModel): + name: str + count: int + + +class _ModelWithUUID(BaseModel): + id: _uuid_module.UUID + label: str + + +class _ModelWithList(BaseModel): + tags: list[str] + count: int + + +class _ModelWithDict(BaseModel): + meta: dict[str, int] + + +class _InnerModel(BaseModel): + value: int + + +class _OuterModel(BaseModel): + inner: _InnerModel + label: str + + +class _ModelWithPrivateAttr(BaseModel): + name: str + _cache: str = PrivateAttr(default="") + + +# ── Module-level models for read-path and round-trip tests ─────────────────── + +class _RoundTripPoint(BaseModel): + x: int + y: int + + +class _RoundTripRecord(BaseModel): + record_id: _uuid_module.UUID + label: str + + +# ── Factory helper ──────────────────────────────────────────────────────────── + +def _make_full_converter(): + """Make a UniversalTypeConverter with builtin types + PydanticLogicalTypeFactory.""" + from pydantic import BaseModel as _BaseModel + from orcapod.extension_types.builtin_logical_types import LogicalPath, LogicalUUID, LogicalUPath + from orcapod.extension_types.registry import LogicalTypeRegistry + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory, PYDANTIC_CATEGORY + from orcapod.semantic_types.universal_converter import UniversalTypeConverter + + registry = LogicalTypeRegistry(logical_types=[LogicalPath(), LogicalUUID(), LogicalUPath()]) + factory = PydanticLogicalTypeFactory() + registry.register_logical_type_factory(factory, category=PYDANTIC_CATEGORY, python_bases=[_BaseModel]) + return UniversalTypeConverter(logical_type_registry=registry) + + +# ── PydanticLogicalTypeFactory write-path tests ─────────────────────────────── + +def test_factory_supports_class_pydantic_model(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + factory = PydanticLogicalTypeFactory() + assert factory.supports_class(_FlatModel) is True + + +def test_factory_supports_class_non_pydantic(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + import dataclasses + + @dataclasses.dataclass + class _DC: + x: int + + factory = PydanticLogicalTypeFactory() + assert factory.supports_class(str) is False + assert factory.supports_class(int) is False + assert factory.supports_class(_DC) is False + + +def test_factory_create_flat_model(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory, PydanticLogicalType + + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_FlatModel, converter=converter) + + assert isinstance(lt, PydanticLogicalType) + storage = lt.get_arrow_extension_type().storage_type + assert pa.types.is_struct(storage) + assert storage.field("name").type == pa.large_string() + assert storage.field("count").type == pa.int64() + + +def test_factory_create_model_with_uuid_field(): + """UUID field → plain storage type (large_binary) in the struct, not extension type (ET1).""" + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_ModelWithUUID, converter=converter) + + storage = lt.get_arrow_extension_type().storage_type + id_field_type = storage.field("id").type + assert id_field_type == pa.large_binary() + assert not isinstance(id_field_type, pa.ExtensionType) + + +def test_factory_create_model_with_list_field(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_ModelWithList, converter=converter) + + storage = lt.get_arrow_extension_type().storage_type + assert pa.types.is_large_list(storage.field("tags").type) + assert storage.field("tags").type.value_type == pa.large_string() + + +def test_factory_create_model_with_dict_field(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_ModelWithDict, converter=converter) + + storage = lt.get_arrow_extension_type().storage_type + meta_type = storage.field("meta").type + assert pa.types.is_large_list(meta_type) + assert pa.types.is_struct(meta_type.value_type) + field_names = {meta_type.value_type.field(i).name for i in range(meta_type.value_type.num_fields)} + assert field_names == {"key", "value"} + + +def test_factory_rejects_local_class(): + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + def _make_local(): + class _Local(BaseModel): + x: int + return _Local + + LocalModel = _make_local() + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + with pytest.raises(ValueError, match="local"): + factory.create_for_python_type(LocalModel, converter=converter) + + +def test_private_fields_not_stored(): + """Private attributes (PrivateAttr) must not appear in the Arrow struct.""" + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.create_for_python_type(_ModelWithPrivateAttr, converter=converter) + + storage = lt.get_arrow_extension_type().storage_type + field_names = {storage.field(i).name for i in range(storage.num_fields)} + assert "name" in field_names + assert "_cache" not in field_names + assert storage.num_fields == 1 + + +# ── PydanticLogicalTypeFactory read-path tests ──────────────────────────────── + +def test_factory_reconstruct_from_arrow(): + """reconstruct_from_arrow rebuilds the logical type from the Arrow struct.""" + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory, PydanticLogicalType + + storage = pa.struct([pa.field("x", pa.int64()), pa.field("y", pa.int64())]) + metadata = {"category": "orcapod.pydantic"} + fqcn = f"{_RoundTripPoint.__module__}.{_RoundTripPoint.__qualname__}" + + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + lt = factory.reconstruct_from_arrow(fqcn, storage, metadata, converter=converter) + + assert isinstance(lt, PydanticLogicalType) + assert lt.python_type is _RoundTripPoint + assert lt.logical_type_name == fqcn + + +def test_factory_reconstruct_from_arrow_invalid_fqcn(): + """ImportError if the FQCN cannot be resolved.""" + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + storage = pa.struct([pa.field("x", pa.int64())]) + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + + with pytest.raises(ImportError): + factory.reconstruct_from_arrow( + "nonexistent.module.NoSuchModel", storage, {"category": "orcapod.pydantic"}, converter + ) + + +def test_reconstruct_from_arrow_registers_nested_types(): + """reconstruct_from_arrow for Outer must register Inner as a side effect.""" + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + inner_storage = pa.struct([pa.field("value", pa.int64())]) + outer_storage = pa.struct([ + pa.field("inner", inner_storage), + pa.field("label", pa.large_string()), + ]) + outer_fqcn = f"{_OuterModel.__module__}.{_OuterModel.__qualname__}" + + factory = PydanticLogicalTypeFactory() + converter = _make_full_converter() + + # Inner is NOT pre-registered + assert converter._logical_type_registry.get_by_python_type(_InnerModel) is None + + factory.reconstruct_from_arrow(outer_fqcn, outer_storage, {"category": "orcapod.pydantic"}, converter) + + # Inner must now be registered as a side effect + assert converter._logical_type_registry.get_by_python_type(_InnerModel) is not None + + +# ── Value round-trip tests ──────────────────────────────────────────────────── + +def test_pydantic_python_to_storage_round_trip(): + """python_to_storage → storage_to_python returns an equivalent model.""" + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + converter = _make_full_converter() + factory = PydanticLogicalTypeFactory() + lt = factory.create_for_python_type(_RoundTripPoint, converter=converter) + converter.register_logical_type(lt) + + point = _RoundTripPoint(x=10, y=20) + storage_value = lt.python_to_storage(point, converter) + assert storage_value == {"x": 10, "y": 20} + + reconstructed = lt.storage_to_python(storage_value, converter) + assert isinstance(reconstructed, _RoundTripPoint) + assert reconstructed.x == 10 + assert reconstructed.y == 20 + + +def test_pydantic_with_uuid_round_trip(): + """Round-trip a pydantic model with a UUID field.""" + from orcapod.extension_types.pydantic_logical_type_factory import PydanticLogicalTypeFactory + + converter = _make_full_converter() + factory = PydanticLogicalTypeFactory() + lt = factory.create_for_python_type(_RoundTripRecord, converter=converter) + converter.register_logical_type(lt) + + u = _uuid_module.UUID("12345678-1234-5678-1234-567812345678") + record = _RoundTripRecord(record_id=u, label="hello") + + storage_value = lt.python_to_storage(record, converter) + assert storage_value["label"] == "hello" + assert storage_value["record_id"] == u.bytes + + reconstructed = lt.storage_to_python(storage_value, converter) + assert isinstance(reconstructed, _RoundTripRecord) + assert reconstructed.record_id == u + assert reconstructed.label == "hello" + + +# ── Parquet integration test ────────────────────────────────────────────────── + +def test_nested_pydantic_model_parquet_roundtrip(tmp_path): + """Fresh-process Parquet round-trip for a two-level nested pydantic model. + + Verifies that register_discovered_extensions triggers the chain: + register_arrow_extension("Outer") -> reconstruct_from_arrow + -> register_python_class(Inner) -> registers Inner + so that storage_to_python can reconstruct the full nested object. + """ + import pyarrow.parquet as pq + from orcapod.extension_types.database_hooks import register_discovered_extensions, apply_extension_types + + # ── Write path ─────────────────────────────────────────────────────────── + write_converter = _make_full_converter() + + inner = _InnerModel(value=42) + outer = _OuterModel(inner=inner, label="hello") + + write_converter.register_python_class(_OuterModel) + + arrow_schema = write_converter.python_schema_to_arrow_schema({"item": _OuterModel}) + rows = [{"item": outer}] + table = write_converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + + parquet_path = tmp_path / "nested_pydantic.parquet" + pq.write_table(table, parquet_path) + + # ── Read path (fresh converter — neither Inner nor Outer pre-registered) ── + read_converter = _make_full_converter() + read_table = pq.read_table(parquet_path) + + register_discovered_extensions(read_converter, read_table.schema) + read_table = apply_extension_types(read_table, read_converter._logical_type_registry) + + assert read_converter._logical_type_registry.get_by_python_type(_OuterModel) is not None + assert read_converter._logical_type_registry.get_by_python_type(_InnerModel) is not None + + rows_out = read_converter.arrow_table_to_python_dicts(read_table) + assert len(rows_out) == 1 + reconstructed = rows_out[0]["item"] + assert isinstance(reconstructed, _OuterModel) + assert isinstance(reconstructed.inner, _InnerModel) + assert reconstructed.inner.value == 42 + assert reconstructed.label == "hello" diff --git a/tests/test_extension_types/test_registry.py b/tests/test_extension_types/test_registry.py new file mode 100644 index 00000000..970bbf72 --- /dev/null +++ b/tests/test_extension_types/test_registry.py @@ -0,0 +1,638 @@ +"""Tests for LogicalTypeRegistry and make_arrow_extension_type.""" + +from __future__ import annotations + +import json +import pathlib +import tempfile +import uuid +import warnings + +import polars as pl +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +from orcapod.extension_types.protocols import LogicalTypeProtocol, LogicalTypeFactoryProtocol +from orcapod.extension_types.registry import LogicalTypeRegistry, make_arrow_extension_type, make_polars_extension_type + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _unique_name() -> str: + """Unique extension name to avoid cross-test global-registry collisions.""" + return f"test.registry.{uuid.uuid4().hex[:8]}" + + +def _make_stub( + arrow_name: str | None = None, + logical_name: str | None = None, + storage: pa.DataType | None = None, + py_type: type = str, +) -> LogicalTypeProtocol: + """Factory for minimal LogicalTypeProtocol conforming stubs. + + ``arrow_name`` defaults to ``logical_name`` (or a unique name if both are + omitted) so that callers can pass a single name and get consistent arrow + and logical names. + """ + _arrow_name = arrow_name or logical_name or _unique_name() + _logical_name = logical_name or _arrow_name + _storage = storage if storage is not None else pa.large_utf8() + + ArrowExtClass = make_arrow_extension_type(_arrow_name, _storage) + + _pl_storage = pl.from_arrow(pa.array([], type=_storage)).dtype + + class _PolarsExt(pl.BaseExtension): + def __init__(self): + super().__init__(_arrow_name, _pl_storage, None) + @classmethod + def ext_from_params(cls, ext_name, storage_dtype, metadata_str): + return cls() + + class _Stub: + @property + def logical_type_name(self) -> str: + return _logical_name + + @property + def python_type(self) -> type: + return py_type + + def get_arrow_extension_type(self) -> pa.ExtensionType: + return ArrowExtClass() + + def get_polars_extension_type(self) -> pl.BaseExtension: + return _PolarsExt() + + def python_to_storage(self, value): + return str(value) + + def storage_to_python(self, storage_value): + return storage_value + + return _Stub() + + +def _make_stub_factory(return_lt: LogicalTypeProtocol | None = None) -> LogicalTypeFactoryProtocol: + """Factory for minimal LogicalTypeFactoryProtocol conforming stubs. + + If ``return_lt`` is given, ``reconstruct_from_arrow`` returns it; otherwise + it creates a fresh stub using ``_make_stub`` keyed on the arrow name. + ``calls`` records every invocation as ``(arrow_extension_name, storage_type, metadata)``. + ``python_type_calls`` records every ``create_for_python_type`` invocation. + """ + _return_lt = return_lt + + class _Factory: + def __init__(self): + self.calls: list[tuple] = [] + self.python_type_calls: list[type] = [] + + def reconstruct_from_arrow(self, arrow_extension_name, storage_type, metadata): + self.calls.append((arrow_extension_name, storage_type, metadata)) + if _return_lt is not None: + return _return_lt + return _make_stub(arrow_name=arrow_extension_name, storage=storage_type) + + def create_for_python_type(self, python_type): + self.python_type_calls.append(python_type) + if _return_lt is not None: + return _return_lt + return _make_stub(py_type=python_type) + + return _Factory() + + +# --------------------------------------------------------------------------- +# make_arrow_extension_type tests +# --------------------------------------------------------------------------- + +def test_make_arrow_extension_type_returns_class(): + """make_arrow_extension_type returns a pa.ExtensionType subclass.""" + cls = make_arrow_extension_type("test.MakeExt", pa.large_utf8()) + assert issubclass(cls, pa.ExtensionType) + + +def test_make_arrow_extension_type_instance_has_correct_name(): + """Instantiating the returned class yields the correct extension_name.""" + name = _unique_name() + cls = make_arrow_extension_type(name, pa.large_utf8()) + inst = cls() + assert inst.extension_name == name + + +def test_make_arrow_extension_type_instance_has_correct_storage(): + """Instantiating the returned class yields the correct storage_type.""" + cls = make_arrow_extension_type(_unique_name(), pa.large_binary()) + inst = cls() + assert inst.storage_type == pa.large_binary() + + +def test_make_arrow_extension_type_metadata_defaults_to_empty(): + """Without metadata, __arrow_ext_serialize__ returns empty bytes.""" + cls = make_arrow_extension_type(_unique_name(), pa.large_utf8()) + inst = cls() + assert inst.__arrow_ext_serialize__() == b"" + + +def test_make_arrow_extension_type_metadata_roundtrip(): + """With metadata, __arrow_ext_serialize__ returns the provided bytes.""" + meta = b"orcapod.test" + cls = make_arrow_extension_type(_unique_name(), pa.large_utf8(), metadata=meta) + inst = cls() + assert inst.__arrow_ext_serialize__() == meta + + +# --------------------------------------------------------------------------- +# LogicalTypeRegistry unit tests +# Each test uses a fresh LogicalTypeRegistry() instance. Registering does +# touch the global PA/Polars registries, but unique extension names (via +# _unique_name()) prevent cross-test collisions. +# --------------------------------------------------------------------------- + +def test_register_stores_logical_type(): + registry = LogicalTypeRegistry() + lt = _make_stub() + registry.register_logical_type(lt) + assert registry.get_by_logical_name(lt.logical_type_name) is lt + + +def test_register_same_instance_twice_is_idempotent(): + """Re-registering the exact same instance does not raise.""" + registry = LogicalTypeRegistry() + lt = _make_stub() + registry.register_logical_type(lt) + registry.register_logical_type(lt) # should not raise + assert registry.get_by_logical_name(lt.logical_type_name) is lt + + +def test_register_conflict_on_logical_name_raises(): + """Two different instances with the same logical_type_name raise ValueError.""" + registry = LogicalTypeRegistry() + name = _unique_name() + lt1 = _make_stub(logical_name=name, py_type=str) + lt2 = _make_stub(logical_name=name, py_type=bytes) + registry.register_logical_type(lt1) + with pytest.raises(ValueError, match="logical_type_name"): + registry.register_logical_type(lt2) + + +def test_register_conflict_on_arrow_name_raises(): + """Two different logical types sharing the same Arrow extension name raise ValueError.""" + registry = LogicalTypeRegistry() + arrow_name = _unique_name() + lt1 = _make_stub(arrow_name=arrow_name, logical_name=_unique_name(), py_type=str) + lt2 = _make_stub(arrow_name=arrow_name, logical_name=_unique_name(), py_type=bytes) + registry.register_logical_type(lt1) + with pytest.raises(ValueError, match="arrow_extension_name"): + registry.register_logical_type(lt2) + + +def test_register_conflict_on_python_type_raises(): + """Two different logical types sharing the same python_type raise ValueError.""" + registry = LogicalTypeRegistry() + lt1 = _make_stub(py_type=float) + lt2 = _make_stub(py_type=float) + registry.register_logical_type(lt1) + with pytest.raises(ValueError, match="python_type"): + registry.register_logical_type(lt2) + + +def test_get_by_logical_name_miss(): + registry = LogicalTypeRegistry() + assert registry.get_by_logical_name("does.not.exist") is None + + +def test_get_by_python_type_exact(): + registry = LogicalTypeRegistry() + lt = _make_stub(py_type=bytes) + registry.register_logical_type(lt) + assert registry.get_by_python_type(bytes) is lt + + +def test_get_by_python_type_subclass(): + class _Base: + pass + + class _Child(_Base): + pass + + registry = LogicalTypeRegistry() + lt = _make_stub(py_type=_Base) + registry.register_logical_type(lt) + assert registry.get_by_python_type(_Child) is lt + + +def test_get_by_python_type_miss(): + registry = LogicalTypeRegistry() + assert registry.get_by_python_type(int) is None + + +def test_get_by_arrow_extension_name(): + registry = LogicalTypeRegistry() + arrow_name = _unique_name() + lt = _make_stub(arrow_name=arrow_name) + registry.register_logical_type(lt) + assert registry.get_by_arrow_extension_name(arrow_name) is lt + + +def test_get_by_arrow_extension_name_miss(): + registry = LogicalTypeRegistry() + assert registry.get_by_arrow_extension_name("does.not.exist") is None + + +# --------------------------------------------------------------------------- +# LogicalTypeRegistry constructor logical_types param tests +# --------------------------------------------------------------------------- + +def test_registry_init_with_logical_types_preregisters(): + """LogicalTypeRegistry(logical_types=[lt]) makes the type immediately retrievable.""" + lt = _make_stub() + registry = LogicalTypeRegistry(logical_types=[lt]) + assert registry.get_by_logical_name(lt.logical_type_name) is lt + assert registry.get_by_python_type(lt.python_type) is lt + assert registry.get_by_arrow_extension_name(lt.get_arrow_extension_type().extension_name) is lt + + +def test_registry_init_with_none_is_empty(): + """LogicalTypeRegistry(logical_types=None) starts empty without error.""" + registry = LogicalTypeRegistry(logical_types=None) + assert registry.get_by_logical_name("anything") is None + + +def test_registry_init_with_empty_list_is_empty(): + """LogicalTypeRegistry(logical_types=[]) starts empty without error.""" + registry = LogicalTypeRegistry(logical_types=[]) + assert registry.get_by_logical_name("anything") is None + + +def test_registry_init_with_multiple_logical_types(): + """LogicalTypeRegistry(logical_types=[lt1, lt2]) registers both.""" + lt1 = _make_stub(py_type=int) + lt2 = _make_stub(py_type=float) + registry = LogicalTypeRegistry(logical_types=[lt1, lt2]) + assert registry.get_by_logical_name(lt1.logical_type_name) is lt1 + assert registry.get_by_logical_name(lt2.logical_type_name) is lt2 + + +# --------------------------------------------------------------------------- +# register_logical_type_factory tests +# --------------------------------------------------------------------------- + +def test_register_logical_type_factory_no_error(): + """register_logical_type_factory completes without raising.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, category="TestCat") # should not raise + + +def test_register_logical_type_factory_same_instance_idempotent(): + """Re-registering the same factory instance for the same category does not raise.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, category="Cat") + registry.register_logical_type_factory(factory, category="Cat") # should not raise + + +def test_register_duplicate_category_raises(): + """Registering a different factory for an already-registered category raises ValueError.""" + registry = LogicalTypeRegistry() + f1 = _make_stub_factory() + f2 = _make_stub_factory() + registry.register_logical_type_factory(f1, category="Cat") + with pytest.raises(ValueError, match="Cat"): + registry.register_logical_type_factory(f2, category="Cat") + + +def test_register_logical_type_factory_keyword_category(): + """register_logical_type_factory accepts factory as first arg, category as keyword.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, category="TestCat") # no error + + +def test_register_logical_type_factory_keyword_python_bases(): + """register_logical_type_factory accepts python_bases as keyword.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, python_bases=[str]) # no error + + +def test_register_logical_type_factory_both_axes(): + """register_logical_type_factory accepts both category and python_bases.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, category="Cat", python_bases=[str, int]) + + +def test_register_logical_type_factory_no_axes_raises(): + """register_logical_type_factory raises ValueError when called with no axes.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + with pytest.raises(ValueError, match="At least one of"): + registry.register_logical_type_factory(factory) + + +def test_register_logical_type_factory_python_base_duplicate_different_factory_raises(): + """Registering a different factory for the same python_base raises ValueError.""" + registry = LogicalTypeRegistry() + f1 = _make_stub_factory() + f2 = _make_stub_factory() + registry.register_logical_type_factory(f1, python_bases=[str]) + with pytest.raises(ValueError, match="different factory"): + registry.register_logical_type_factory(f2, python_bases=[str]) + + +def test_register_logical_type_factory_python_base_same_factory_idempotent(): + """Registering the same factory twice for the same python_base is a no-op.""" + registry = LogicalTypeRegistry() + factory = _make_stub_factory() + registry.register_logical_type_factory(factory, python_bases=[str]) + registry.register_logical_type_factory(factory, python_bases=[str]) # no error + +# --------------------------------------------------------------------------- +# PyArrow global registry tests +# --------------------------------------------------------------------------- + +def test_register_populates_arrow_registry(): + """After register(), PA global registry contains the extension type.""" + lt = _make_stub() + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + # If the name is registered, attempting to re-register the same type raises + # ArrowKeyError. This is the only stable public signal PyArrow provides. + with pytest.raises(pa.lib.ArrowKeyError): + pa.register_extension_type(lt.get_arrow_extension_type()) + + +def test_register_arrow_preexisting_external_accepted_silently(): + """A name already registered externally in PyArrow is accepted silently (no raise).""" + name = _unique_name() + + class _External(pa.ExtensionType): + def __init__(self): + pa.ExtensionType.__init__(self, pa.large_utf8(), name) + def __arrow_ext_serialize__(self): + return b"" + @classmethod + def __arrow_ext_deserialize__(cls, st, se): + return cls() + + pa.register_extension_type(_External()) # bypass our registry + + # New semantics: pre-existing registrations are accepted silently. + lt = _make_stub(arrow_name=name) + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) # should NOT raise + assert registry.get_by_logical_name(lt.logical_type_name) is lt + + +def test_register_same_instance_two_registries(): + """The same LogicalTypeProtocol instance can be registered in two different registry instances.""" + lt = _make_stub() + r1 = LogicalTypeRegistry() + r2 = LogicalTypeRegistry() + r1.register_logical_type(lt) + r2.register_logical_type(lt) # should not raise (same instance, PA/Polars accept silently) + assert r2.get_by_logical_name(lt.logical_type_name) is lt + + +# --------------------------------------------------------------------------- +# Polars global registry tests +# --------------------------------------------------------------------------- + +def test_register_populates_polars_registry(): + """After register(), Polars knows the extension type.""" + arrow_name = _unique_name() + lt = _make_stub(arrow_name=arrow_name) + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + # Verify by attempting to create a Polars series from a PA extension array. + ext_type = lt.get_arrow_extension_type() + storage_arr = pa.array(["a", "b"], type=ext_type.storage_type) + ext_arr = storage_arr.cast(ext_type) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + pl_series = pl.from_arrow(ext_arr) + + assert isinstance(pl_series.dtype, pl.BaseExtension) + assert pl_series.dtype.ext_name() == arrow_name + + +def test_register_polars_preexisting_external_accepted_silently(): + """A name already registered externally in Polars is accepted silently.""" + name = _unique_name() + + class _ExternalPL(pl.BaseExtension): + def __init__(self): + super().__init__(name, pl.String, None) + @classmethod + def ext_from_params(cls, n, s, m): + return cls() + + class _ExternalPA(pa.ExtensionType): + def __init__(self): + pa.ExtensionType.__init__(self, pa.large_utf8(), name) + def __arrow_ext_serialize__(self): + return b"" + @classmethod + def __arrow_ext_deserialize__(cls, st, se): + return cls() + + pa.register_extension_type(_ExternalPA()) + pl.register_extension_type(name, _ExternalPL) + + lt = _make_stub(arrow_name=name) + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) # should NOT raise + assert registry.get_by_logical_name(lt.logical_type_name) is lt + + +# --------------------------------------------------------------------------- +# End-to-end integration tests +# --------------------------------------------------------------------------- + + +class _Color: + """Minimal Python class used to exercise the LogicalTypeProtocol contract end-to-end.""" + def __init__(self, hex_str: str) -> None: + self.hex_str = hex_str + def __eq__(self, other: object) -> bool: + return isinstance(other, _Color) and self.hex_str == other.hex_str + def __repr__(self) -> str: + return f"Color({self.hex_str!r})" + + +def _make_color_logical_type() -> LogicalTypeProtocol: + """LogicalTypeProtocol for _Color, backed by pa.large_utf8() storage.""" + _name = _unique_name() + _ArrowExtClass = make_arrow_extension_type(_name, pa.large_utf8(), metadata=b"test.color") + + class _PolarsExt(pl.BaseExtension): + def __init__(self): + super().__init__(_name, pl.String, "test.color") + @classmethod + def ext_from_params(cls, ext_name, storage_dtype, metadata_str): + return cls() + + class _ColorLogicalType: + @property + def logical_type_name(self) -> str: + return _name + + @property + def python_type(self) -> type: + return _Color + + def get_arrow_extension_type(self) -> pa.ExtensionType: + return _ArrowExtClass() + + def get_polars_extension_type(self) -> pl.BaseExtension: + return _PolarsExt() + + def python_to_storage(self, value: _Color) -> str: + return value.hex_str + + def storage_to_python(self, storage_value: str) -> _Color: + return _Color(storage_value) + + return _ColorLogicalType() + + +def _build_ext_array( + lt: LogicalTypeProtocol, + values: list, +) -> pa.Array: + """Build a PA extension array from Python values using the logical type. + + Global registration (via ``registry.register_logical_type(lt)``) is NOT required for + this helper — ``cast()`` works with any ``pa.ExtensionType`` instance. + Registration is only needed for IPC/Parquet *deserialization*, where Arrow + maps the ``extension_name`` string back to the registered Python type. + """ + storage_values = [lt.python_to_storage(v) for v in values] + arrow_ext = lt.get_arrow_extension_type() + storage_arr = pa.array(storage_values, type=arrow_ext.storage_type) + return storage_arr.cast(arrow_ext) + + +def test_python_class_round_trip(): + """Python objects -> Arrow extension array -> Python objects via logical type methods.""" + lt = _make_color_logical_type() + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + originals = [_Color("#ff0000"), _Color("#00ff00"), _Color("#0000ff")] + ext_arr = _build_ext_array(lt, originals) + + recovered = [lt.storage_to_python(v.as_py()) for v in ext_arr.storage] + assert recovered == originals + + +def test_arrow_polars_round_trip(): + """PA ext array -> pl.from_arrow -> to_arrow() preserves extension type and values.""" + lt = _make_color_logical_type() + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + originals = [_Color("#aabbcc"), _Color("#112233")] + ext_arr = _build_ext_array(lt, originals) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + pl_series = pl.from_arrow(ext_arr) + + assert isinstance(pl_series.dtype, pl.BaseExtension) + assert pl_series.dtype.ext_name() == lt.get_arrow_extension_type().extension_name + + arr_back = pl_series.to_arrow() + assert arr_back.type.extension_name == lt.get_arrow_extension_type().extension_name + + recovered = [lt.storage_to_python(v.as_py()) for v in arr_back.storage] + assert recovered == originals + + +def test_parquet_round_trip(): + """PA ext array -> Parquet -> read back via PyArrow; extension type and values preserved.""" + lt = _make_color_logical_type() + registry = LogicalTypeRegistry() + registry.register_logical_type(lt) + + originals = [_Color("#deadbe"), _Color("#cafeba")] + ext_arr = _build_ext_array(lt, originals) + arrow_ext = lt.get_arrow_extension_type() + schema = pa.schema([pa.field("color", arrow_ext), pa.field("id", pa.int32())]) + table = pa.table( + {"color": ext_arr, "id": pa.array([1, 2], type=pa.int32())}, + schema=schema, + ) + + with tempfile.TemporaryDirectory() as tmp: + path = pathlib.Path(tmp) / "test.parquet" + pq.write_table(table, path) + table_back = pq.read_table(path) + + assert table_back.schema.field("color").type.extension_name == arrow_ext.extension_name + storage_arr = table_back.column("color").combine_chunks().storage + recovered = [lt.storage_to_python(v.as_py()) for v in storage_arr] + assert recovered == originals + + +# --------------------------------------------------------------------------- +# make_polars_extension_type tests +# --------------------------------------------------------------------------- + + +def test_make_polars_extension_type_returns_class(): + """make_polars_extension_type returns a pl.BaseExtension subclass.""" + cls = make_polars_extension_type("test.MakePolarsExt", pa.large_utf8()) + assert issubclass(cls, pl.BaseExtension) + + +def test_make_polars_extension_type_instance_has_correct_name(): + """Instantiating the returned class yields the correct ext_name.""" + name = _unique_name() + cls = make_polars_extension_type(name, pa.large_utf8()) + inst = cls() + assert inst.ext_name() == name + + +def test_make_polars_extension_type_ext_from_params_returns_instance(): + """ext_from_params classmethod returns an instance of the class.""" + name = _unique_name() + cls = make_polars_extension_type(name, pa.large_utf8()) + inst = cls.ext_from_params(name, pl.String, None) + assert isinstance(inst, cls) + + +def test_make_polars_extension_type_with_binary_storage(): + """make_polars_extension_type works with pa.binary(16) storage (UUID case).""" + name = _unique_name() + cls = make_polars_extension_type(name, pa.binary(16), None) + inst = cls() + assert inst.ext_name() == name + + +def test_make_polars_extension_type_with_metadata(): + """make_polars_extension_type captures metadata in the class.""" + name = _unique_name() + cls = make_polars_extension_type(name, pa.large_utf8(), "test.metadata") + # Instantiating should not raise; ext_name is correct. + inst = cls() + assert inst.ext_name() == name + + + + +def test_registry_does_not_expose_ensure_methods(): + """ensure_logical_type_for_python_class and ensure_extension_type are removed.""" + registry = LogicalTypeRegistry() + assert not hasattr(registry, "ensure_logical_type_for_python_class") + assert not hasattr(registry, "ensure_extension_type") diff --git a/tests/test_extension_types/test_roundtrips.py b/tests/test_extension_types/test_roundtrips.py new file mode 100644 index 00000000..afac59dc --- /dev/null +++ b/tests/test_extension_types/test_roundtrips.py @@ -0,0 +1,376 @@ +"""End-to-end integration tests for extension type round-trips. + +Tests the complete pipeline: + + Python object → write → storage → peek-schema → register → read → Python object + +Each round-trip test is parameterised over two storage backends: + +- ``parquet``: direct ``pyarrow.parquet`` write/read. +- ``delta``: ``deltalake.write_deltalake`` / ``DeltaTable.to_pyarrow_dataset(as_large_types=True).to_table()``. + +SQLite (``ConnectorArrowDatabase`` + ``SQLiteConnector``) is excluded because +``SQLiteConnector`` maps Arrow types to SQL column types and discards +``ARROW:extension:*`` field metadata. Without that metadata, the +peek-register-read pattern cannot auto-register extension types on the read +path. The ``ExtensionAwareDatabase`` wrapper behaviour over SQLite is already +tested in ``tests/test_databases/test_extension_aware_database.py``. +""" +from __future__ import annotations + +import dataclasses +import pathlib +import uuid as uuid_module +from pathlib import Path +from typing import Callable + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest +from upath import UPath + +from orcapod.contexts import create_registry +from orcapod.semantic_types.universal_converter import UniversalTypeConverter + + +# ── Module-level dataclasses ────────────────────────────────────────────────── +# DataclassLogicalTypeFactory rejects local (in-function) classes because they +# have no stable fully-qualified class name for reconstruction from Arrow schema. + +@dataclasses.dataclass +class _PointA: + x: int + y: int + + +@dataclasses.dataclass +class _PointB: + """Same struct shape as _PointA, different class name.""" + x: int + y: int + + +@dataclasses.dataclass +class _Inner: + value: int + + +@dataclasses.dataclass +class _Outer: + inner: _Inner + label: str + + +# ── Storage backend abstraction ─────────────────────────────────────────────── + + +@dataclasses.dataclass +class _StorageBackend: + """Encapsulates backend-specific write and read logic for parameterised tests. + + Args: + name: Short identifier used in pytest test IDs (e.g. ``"parquet"``). + write: Callable that writes an Arrow table to a directory. + read: Callable that reads from that directory and returns an Arrow table + with extension types registered and applied. Must return only the + original user data columns (no ``__record_id`` or similar). + """ + name: str + write: Callable[[pa.Table, Path], None] + read: Callable[[Path, UniversalTypeConverter], pa.Table] + + +def _parquet_write(table: pa.Table, base_path: Path) -> None: + pq.write_table(table, str(base_path / "data.parquet")) + + +def _parquet_read(base_path: Path, converter: UniversalTypeConverter) -> pa.Table: + return converter.load_extension_types(pq.read_table(str(base_path / "data.parquet"))) + + +def _delta_write(table: pa.Table, base_path: Path) -> None: + import deltalake + deltalake.write_deltalake(str(base_path / "delta"), table) + + +def _delta_read(base_path: Path, converter: UniversalTypeConverter) -> pa.Table: + import deltalake + dt = deltalake.DeltaTable(str(base_path / "delta")) + # as_large_types=True preserves large_string / large_binary rather than + # normalising them to string / binary (Delta Lake's default behaviour). + # Without this flag, extension types that use large_string or large_binary + # as storage fail to deserialise because the _deserialize method strictly + # checks that the storage type matches the registered one. + raw = dt.to_pyarrow_dataset(as_large_types=True).to_table() + return converter.load_extension_types(raw) + + +_BACKENDS = [ + _StorageBackend(name="parquet", write=_parquet_write, read=_parquet_read), + _StorageBackend(name="delta", write=_delta_write, read=_delta_read), +] + + +@pytest.fixture(params=_BACKENDS, ids=lambda b: b.name) +def storage_backend(request: pytest.FixtureRequest) -> _StorageBackend: + """Yield one storage backend per parametrised run.""" + return request.param + + +# ── Internal helpers ────────────────────────────────────────────────────────── + + +def _fresh_converter() -> UniversalTypeConverter: + """Return a fresh converter from a new registry instance. + + Uses ``create_registry()`` instead of ``get_default_context()`` to avoid + cross-test contamination through the global singleton cache. + """ + return create_registry().get_context().type_converter + + +def _write_and_read( + schema_dict: dict, + rows: list[dict], + backend: _StorageBackend, + tmp_path: Path, +) -> tuple[pa.Table, UniversalTypeConverter]: + """Write rows with a fresh write converter and read back with a fresh read converter. + + Returns the resulting Arrow table (with extension types applied) and the + read-side converter (needed for ``arrow_table_to_python_dicts``). + """ + write_converter = _fresh_converter() + # Pre-register each type so the converter can map it to an Arrow extension + # type before python_schema_to_arrow_schema inspects it. Built-in types + # (Path, UPath, UUID) are already registered in the context; dataclass types + # are auto-discovered on the first register_python_class call. + for python_type in schema_dict.values(): + write_converter.register_python_class(python_type) + arrow_schema = write_converter.python_schema_to_arrow_schema(schema_dict) + table = write_converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + backend.write(table, tmp_path) + + read_converter = _fresh_converter() + result = backend.read(tmp_path, read_converter) + return result, read_converter + + +# ── Built-in type round-trip tests ─────────────────────────────────────────── + + +def test_builtin_path_round_trip(storage_backend: _StorageBackend, tmp_path: Path) -> None: + """pathlib.Path round-trips through storage with extension name ``orcapod.path``. + + Built-in types (Path, UPath, UUID) are pre-registered in the default context + so the read-side converter already knows about them. The test verifies that: + + 1. The Arrow field carries the ``orcapod.path`` extension type after read. + 2. The Python value is reconstructed as a ``pathlib.Path`` instance. + """ + p = pathlib.Path("/tmp/orcapod/integration/test.txt") + result, read_converter = _write_and_read( + {"col": pathlib.Path}, + [{"col": p}], + storage_backend, + tmp_path, + ) + + field = result.schema.field("col") + assert hasattr(field.type, "extension_name"), ( + f"Expected extension type on field 'col', got plain type {field.type!r}" + ) + assert field.type.extension_name == "orcapod.path" + + rows = read_converter.arrow_table_to_python_dicts(result) + assert len(rows) == 1 + assert isinstance(rows[0]["col"], pathlib.Path) + assert rows[0]["col"] == p + + +def test_builtin_upath_round_trip(storage_backend: _StorageBackend, tmp_path: Path) -> None: + """UPath round-trips through storage with extension name ``orcapod.upath``.""" + u = UPath("s3://my-bucket/data/file.parquet") + result, read_converter = _write_and_read( + {"col": UPath}, + [{"col": u}], + storage_backend, + tmp_path, + ) + + field = result.schema.field("col") + assert hasattr(field.type, "extension_name"), ( + f"Expected extension type on field 'col', got plain type {field.type!r}" + ) + assert field.type.extension_name == "orcapod.upath" + + rows = read_converter.arrow_table_to_python_dicts(result) + assert len(rows) == 1 + assert isinstance(rows[0]["col"], UPath) + assert str(rows[0]["col"]) == str(u) + + +def test_builtin_uuid_round_trip(storage_backend: _StorageBackend, tmp_path: Path) -> None: + """uuid.UUID round-trips through storage with extension name ``orcapod.uuid``.""" + u = uuid_module.UUID("12345678-1234-5678-1234-567812345678") + result, read_converter = _write_and_read( + {"col": uuid_module.UUID}, + [{"col": u}], + storage_backend, + tmp_path, + ) + + field = result.schema.field("col") + assert hasattr(field.type, "extension_name"), ( + f"Expected extension type on field 'col', got plain type {field.type!r}" + ) + assert field.type.extension_name == "orcapod.uuid" + + rows = read_converter.arrow_table_to_python_dicts(result) + assert len(rows) == 1 + assert isinstance(rows[0]["col"], uuid_module.UUID) + assert rows[0]["col"] == u + + +# ── Dataclass round-trip tests ──────────────────────────────────────────────── + + +def test_simple_dataclass_round_trip(storage_backend: _StorageBackend, tmp_path: Path) -> None: + """Simple dataclass round-trips with correct FQCN as the Arrow extension name. + + The read-side converter starts with no knowledge of _PointA. After read, + register_discovered_extensions triggers DataclassLogicalTypeFactory which + imports _PointA from its fully-qualified class name and registers it. + """ + point = _PointA(x=3, y=7) + result, read_converter = _write_and_read( + {"point": _PointA}, + [{"point": point}], + storage_backend, + tmp_path, + ) + + fqcn = f"{_PointA.__module__}.{_PointA.__qualname__}" + field = result.schema.field("point") + assert hasattr(field.type, "extension_name"), ( + f"Expected extension type on field 'point', got {field.type!r}" + ) + assert field.type.extension_name == fqcn + + rows = read_converter.arrow_table_to_python_dicts(result) + assert len(rows) == 1 + reconstructed = rows[0]["point"] + assert isinstance(reconstructed, _PointA) + assert reconstructed.x == 3 + assert reconstructed.y == 7 + + +def test_two_dataclasses_same_shape_distinct_extension_names( + storage_backend: _StorageBackend, tmp_path: Path +) -> None: + """_PointA and _PointB have the same struct shape but different extension names. + + Writing _PointA and reading it back must NOT reconstruct a _PointB, even + though their on-disk struct shapes (x: int, y: int) are identical. The + extension name (FQCN) is the sole identity signal. + """ + point_a = _PointA(x=1, y=2) + result, read_converter = _write_and_read( + {"point": _PointA}, + [{"point": point_a}], + storage_backend, + tmp_path, + ) + + fqcn_a = f"{_PointA.__module__}.{_PointA.__qualname__}" + fqcn_b = f"{_PointB.__module__}.{_PointB.__qualname__}" + + field = result.schema.field("point") + assert hasattr(field.type, "extension_name") + assert field.type.extension_name == fqcn_a + assert field.type.extension_name != fqcn_b # distinct from _PointB + + rows = read_converter.arrow_table_to_python_dicts(result) + reconstructed = rows[0]["point"] + assert isinstance(reconstructed, _PointA) + assert not isinstance(reconstructed, _PointB) + + +def test_nested_dataclass_round_trip(storage_backend: _StorageBackend, tmp_path: Path) -> None: + """Nested dataclass: _Outer and _Inner both registered; full object reconstructed. + + register_discovered_extensions triggers DataclassLogicalTypeFactory for _Outer. + That factory's reconstruct_from_arrow calls converter.register_python_class(_Inner) + as a side-effect, so _Inner is also registered without an explicit peek step. + """ + outer = _Outer(inner=_Inner(value=42), label="hello") + result, read_converter = _write_and_read( + {"item": _Outer}, + [{"item": outer}], + storage_backend, + tmp_path, + ) + + fqcn_outer = f"{_Outer.__module__}.{_Outer.__qualname__}" + fqcn_inner = f"{_Inner.__module__}.{_Inner.__qualname__}" + + assert read_converter._logical_type_registry.get_by_arrow_extension_name(fqcn_outer) is not None, ( + "_Outer should be registered after read" + ) + assert read_converter._logical_type_registry.get_by_arrow_extension_name(fqcn_inner) is not None, ( + "_Inner should be registered transitively after read" + ) + + rows = read_converter.arrow_table_to_python_dicts(result) + assert len(rows) == 1 + reconstructed = rows[0]["item"] + assert isinstance(reconstructed, _Outer) + assert isinstance(reconstructed.inner, _Inner) + assert reconstructed.inner.value == 42 + assert reconstructed.label == "hello" + + +# ── Delta Lake: Polars native read ─────────────────────────────────────────── + + +def test_delta_polars_read_delta(tmp_path: Path) -> None: + """Write a dataclass column to Delta; read back via pl.read_delta; extension type survives. + + The write-side converter registers _PointA in both PyArrow's and Polars' + global registries (``register_python_class`` calls ``make_polars_extension_type`` + which registers with Polars). ``pl.read_delta`` can therefore decode the column + as the correct Polars extension type, not a plain ``Struct``. + + Note: ``pl.DataFrame.to_arrow()`` exports Polars extension types as PyArrow + extension arrays but with empty serialized bytes (Polars does not forward + ``__arrow_ext_metadata__`` through its Arrow export). Python-object + reconstruction via the Polars-to-Arrow path is therefore not possible; that + path is tested by the separate ``parquet`` / ``delta`` parametrised tests + which read underlying Parquet files directly. + """ + import deltalake + import polars as pl + + delta_path = str(tmp_path / "polars_delta") + fqcn = f"{_PointA.__module__}.{_PointA.__qualname__}" + + # Write — registers _PointA in PyArrow + Polars global registries. + write_converter = _fresh_converter() + write_converter.register_python_class(_PointA) + arrow_schema = write_converter.python_schema_to_arrow_schema({"point": _PointA}) + rows = [{"point": _PointA(x=5, y=9)}] + table = write_converter.python_dicts_to_arrow_table(rows, arrow_schema=arrow_schema) + deltalake.write_deltalake(delta_path, table) + + # Read via Polars native Delta reader. + # _PointA is already in the Polars global registry from the write step above. + df = pl.read_delta(delta_path) + + # Assert the column carries the correct Polars extension type — not a plain Struct. + col_dtype = df.dtypes[0] + assert col_dtype.is_extension(), ( + f"Expected a Polars extension type on column 'point', got {col_dtype!r}" + ) + assert col_dtype.ext_name() == fqcn, ( + f"Expected extension name {fqcn!r}, got {col_dtype.ext_name()!r}" + ) diff --git a/tests/test_extension_types/test_schema_compatibility.py b/tests/test_extension_types/test_schema_compatibility.py new file mode 100644 index 00000000..f15d190d --- /dev/null +++ b/tests/test_extension_types/test_schema_compatibility.py @@ -0,0 +1,106 @@ +"""Integration tests for extension-type-backed schema compatibility. + +Two complementary angles: + +Arrow-level identity + ``converter.python_schema_to_arrow_schema`` assigns each dataclass a unique + Arrow extension name derived from its fully-qualified class name. Two + dataclasses with identical struct shapes but different class names therefore + produce *different* extension names — the core identity guarantee of the + extension type system. + +Python-type-level compatibility + ``check_schema_compatibility`` from ``schema_utils`` uses beartype + ``is_subhint`` to compare Python type annotations. Same class → compatible; + different class with the same struct shape → incompatible. This is the + property that prevents silent data corruption when two unrelated dataclasses + happen to share the same fields. +""" +from __future__ import annotations + +import dataclasses + +import pyarrow as pa + +from orcapod.contexts import create_registry +from orcapod.types import Schema +from orcapod.utils.schema_utils import check_schema_compatibility + + +# Module-level dataclasses — DataclassLogicalTypeFactory rejects local classes +# because they have no stable fully-qualified class name for reconstruction. + +@dataclasses.dataclass +class _PointA: + x: int + y: int + + +@dataclasses.dataclass +class _PointB: + """Same struct shape as _PointA but a different class name.""" + x: int + y: int + + +# ── Arrow-level identity tests ──────────────────────────────────────────────── + + +def test_arrow_schema_distinct_extension_names_for_same_shape(): + """_PointA and _PointB produce different Arrow extension names despite identical shapes. + + This is the core identity guarantee: struct shape alone does not determine + type identity in the extension type system. + """ + converter_a = create_registry().get_context().type_converter + converter_b = create_registry().get_context().type_converter + + type_a = converter_a.register_python_class(_PointA) + type_b = converter_b.register_python_class(_PointB) + + assert isinstance(type_a, pa.ExtensionType) + assert isinstance(type_b, pa.ExtensionType) + + fqcn_a = f"{_PointA.__module__}.{_PointA.__qualname__}" + fqcn_b = f"{_PointB.__module__}.{_PointB.__qualname__}" + assert type_a.extension_name == fqcn_a + assert type_b.extension_name == fqcn_b + assert type_a.extension_name != type_b.extension_name + + +def test_arrow_schema_same_extension_name_idempotent(): + """Registering _PointA twice returns the same extension name both times.""" + converter = create_registry().get_context().type_converter + + type_first = converter.register_python_class(_PointA) + type_second = converter.register_python_class(_PointA) + + assert isinstance(type_first, pa.ExtensionType) + assert isinstance(type_second, pa.ExtensionType) + assert type_first.extension_name == type_second.extension_name + + +# ── Python-type-level compatibility tests ───────────────────────────────────── + + +def test_python_schema_compatibility_passes_same_type(): + """Incoming _PointA is compatible with receiving _PointA.""" + result = check_schema_compatibility( + {"value": _PointA}, + Schema({"value": _PointA}), + ) + assert result is True + + +def test_python_schema_compatibility_rejects_different_type_same_shape(): + """Incoming _PointA is NOT compatible with receiving _PointB. + + Both dataclasses share the same struct shape {x: int, y: int}, but they + are different Python types. The old shape-based system would have accepted + this silently; the extension type system correctly rejects it. + """ + result = check_schema_compatibility( + {"value": _PointA}, + Schema({"value": _PointB}), + ) + assert result is False diff --git a/tests/test_extension_types/test_schema_walker.py b/tests/test_extension_types/test_schema_walker.py new file mode 100644 index 00000000..33fe1bfa --- /dev/null +++ b/tests/test_extension_types/test_schema_walker.py @@ -0,0 +1,261 @@ +"""Tests for schema_walker — recursive Arrow extension type discovery.""" + +from __future__ import annotations + +import re +import uuid + +import pyarrow as pa +import pytest + +from orcapod.extension_types.schema_walker import ( + ExtensionTypeInfo, + walk_field, + walk_schema, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _unique_name() -> str: + """Return a unique extension name to avoid cross-test collisions.""" + return f"test.walker.{uuid.uuid4().hex[:8]}" + + +def _make_reg_field( + field_name: str, + ext_name: str, + storage: pa.DataType | None = None, + metadata: bytes = b"test.cat", +) -> pa.Field: + """Create a ``pa.Field`` with an in-memory ``pa.ExtensionType`` (registered channel). + + The extension type is NOT registered in PyArrow's global registry — this + is intentional. ``pa.types.is_extension(field.type)`` returns ``True`` + for any ``pa.ExtensionType`` instance regardless of global registration. + """ + _n = ext_name + _s = storage if storage is not None else pa.large_utf8() + _m = metadata + ExtType = type( + f"_RegExt_{re.sub(r'[^A-Za-z0-9]', '_', ext_name)}", + (pa.ExtensionType,), + { + "__init__": lambda self: pa.ExtensionType.__init__(self, _s, _n), + "__arrow_ext_serialize__": lambda self: _m, + "__arrow_ext_deserialize__": classmethod(lambda cls, st, se: cls()), + }, + ) + return pa.field(field_name, ExtType()) + + +def _make_unreg_field( + field_name: str, + ext_name: str, + storage: pa.DataType | None = None, + metadata: bytes = b"test.cat", +) -> pa.Field: + """Create a ``pa.Field`` with raw Arrow extension metadata (unregistered channel).""" + _s = storage if storage is not None else pa.large_utf8() + return pa.field( + field_name, + _s, + metadata={ + b"ARROW:extension:name": ext_name.encode(), + b"ARROW:extension:metadata": metadata, + }, + ) + + +# --------------------------------------------------------------------------- +# Task 1 tests: top-level detection and deduplication +# --------------------------------------------------------------------------- + + +def test_empty_schema(): + result = walk_schema(pa.schema([])) + assert result == [] + + +def test_no_extension_types(): + schema = pa.schema([ + pa.field("x", pa.int64()), + pa.field("y", pa.large_utf8()), + ]) + assert walk_schema(schema) == [] + + +def test_top_level_registered(): + name = _unique_name() + schema = pa.schema([_make_reg_field("col", name, metadata=b"my.cat")]) + result = walk_schema(schema) + assert len(result) == 1 + assert result[0].extension_name == name + assert result[0].extension_metadata == b"my.cat" + assert result[0].storage_type == pa.large_utf8() + + +def test_top_level_unregistered(): + name = _unique_name() + schema = pa.schema([_make_unreg_field("col", name, metadata=b"my.cat")]) + result = walk_schema(schema) + assert len(result) == 1 + assert result[0].extension_name == name + assert result[0].extension_metadata == b"my.cat" + assert result[0].storage_type == pa.large_utf8() + + +def test_empty_metadata_normalised_to_none_registered(): + """b'' from __arrow_ext_serialize__ is normalised to None.""" + name = _unique_name() + _n, _s = name, pa.large_utf8() + ExtType = type( + "_EmptyMetaExt", + (pa.ExtensionType,), + { + "__init__": lambda self: pa.ExtensionType.__init__(self, _s, _n), + "__arrow_ext_serialize__": lambda self: b"", + "__arrow_ext_deserialize__": classmethod(lambda cls, st, se: cls()), + }, + ) + result = walk_field(pa.field("col", ExtType())) + assert len(result) == 1 + assert result[0].extension_metadata is None + + +def test_empty_metadata_normalised_to_none_unregistered(): + """b'' ARROW:extension:metadata value is normalised to None.""" + name = _unique_name() + field = pa.field( + "col", + pa.large_utf8(), + metadata={ + b"ARROW:extension:name": name.encode(), + b"ARROW:extension:metadata": b"", + }, + ) + result = walk_field(field) + assert len(result) == 1 + assert result[0].extension_metadata is None + + +def test_walk_field_returns_single_field_result(): + name = _unique_name() + field = _make_reg_field("col", name, metadata=b"cat") + result = walk_field(field) + assert len(result) == 1 + assert result[0].extension_name == name + + +def test_deduplication(): + """Same (extension_name, extension_metadata) in two columns → one result.""" + name = _unique_name() + meta = b"test.cat" + schema = pa.schema([ + _make_reg_field("col_a", name, metadata=meta), + _make_reg_field("col_b", name, metadata=meta), + ]) + result = walk_schema(schema) + assert len(result) == 1 + assert result[0].extension_name == name + assert result[0].extension_metadata == meta + + +# --------------------------------------------------------------------------- +# Task 2 tests: container recursion +# --------------------------------------------------------------------------- + + +def test_list_of_registered(): + """Registered extension type as the value field of a list.""" + name = _unique_name() + value_field = _make_reg_field("item", name, metadata=b"my.cat") + list_field = pa.field("col", pa.list_(value_field)) + result = walk_schema(pa.schema([list_field])) + assert len(result) == 1 + assert result[0].extension_name == name + assert result[0].extension_metadata == b"my.cat" + + +def test_list_of_unregistered(): + """Unregistered extension type as the value field of a list.""" + name = _unique_name() + value_field = _make_unreg_field("item", name, metadata=b"my.cat") + list_field = pa.field("col", pa.list_(value_field)) + result = walk_schema(pa.schema([list_field])) + assert len(result) == 1 + assert result[0].extension_name == name + assert result[0].extension_metadata == b"my.cat" + + +def test_struct_containing_registered(): + """Registered extension type as a field inside a struct.""" + name = _unique_name() + struct_field = pa.field( + "col", + pa.struct([ + _make_reg_field("a", name, metadata=b"my.cat"), + pa.field("b", pa.int64()), + ]), + ) + result = walk_schema(pa.schema([struct_field])) + assert len(result) == 1 + assert result[0].extension_name == name + assert result[0].extension_metadata == b"my.cat" + + +def test_struct_containing_unregistered(): + """Unregistered extension type as a field inside a struct.""" + name = _unique_name() + struct_field = pa.field( + "col", + pa.struct([ + _make_unreg_field("a", name, metadata=b"my.cat"), + pa.field("b", pa.int64()), + ]), + ) + result = walk_schema(pa.schema([struct_field])) + assert len(result) == 1 + assert result[0].extension_name == name + assert result[0].extension_metadata == b"my.cat" + + +def test_nested_list_struct(): + """Registered extension type nested inside list>.""" + name = _unique_name() + struct_type = pa.struct([ + _make_reg_field("x", name, metadata=b"deep.cat"), + pa.field("y", pa.int32()), + ]) + value_field = pa.field("item", struct_type) + col = pa.field("col", pa.list_(value_field)) + result = walk_schema(pa.schema([col])) + assert len(result) == 1 + assert result[0].extension_name == name + assert result[0].extension_metadata == b"deep.cat" + + +def test_map_type(): + """Extension type as the item type of a map (registered channel).""" + name = _unique_name() + _n, _m, _s = name, b"map.cat", pa.large_utf8() + # Build a pa.ExtensionType instance — it IS a pa.DataType and can be + # passed directly to pa.map_() as the item type. + ExtType = type( + "_MapItemExt", + (pa.ExtensionType,), + { + "__init__": lambda self: pa.ExtensionType.__init__(self, _s, _n), + "__arrow_ext_serialize__": lambda self: _m, + "__arrow_ext_deserialize__": classmethod(lambda cls, st, se: cls()), + }, + ) + map_field = pa.field("col", pa.map_(pa.large_utf8(), ExtType())) + result = walk_schema(pa.schema([map_field])) + # _collect uses getattr(t, "item_field") to retrieve the item pa.Field. + # pa.types.is_extension(item_field.type) will be True for the ExtType above. + assert len(result) == 1 + assert result[0].extension_name == name diff --git a/tests/test_extension_types/test_type_utils.py b/tests/test_extension_types/test_type_utils.py new file mode 100644 index 00000000..470b3512 --- /dev/null +++ b/tests/test_extension_types/test_type_utils.py @@ -0,0 +1,176 @@ +"""Tests for extension_types.type_utils helpers.""" + +from __future__ import annotations + +from typing import Optional, Union + +import pytest + +from orcapod.extension_types.type_utils import _extract_leaf_classes as extract_leaf_classes +from orcapod.extension_types.type_utils import _walk_fqcn + + +class _A: + pass + + +class _B: + pass + + +def test_plain_class(): + assert list(extract_leaf_classes(int)) == [int] + + +def test_plain_custom_class(): + assert list(extract_leaf_classes(_A)) == [_A] + + +def test_list_of_class(): + assert list(extract_leaf_classes(list[int])) == [int] + + +def test_dict_of_classes(): + result = set(extract_leaf_classes(dict[str, int])) + assert result == {str, int} + + +def test_optional_unwraps_none(): + """Optional[X] yields X but not NoneType.""" + result = list(extract_leaf_classes(Optional[int])) + assert result == [int] + + +def test_union_yields_all_non_none(): + result = set(extract_leaf_classes(Union[int, str])) + assert result == {int, str} + + +def test_union_with_none_excludes_none(): + result = set(extract_leaf_classes(Union[int, None])) + assert type(None) not in result + assert int in result + + +def test_nested_list_of_dict(): + """list[dict[_A, list[_B]]] yields _A and _B.""" + result = set(extract_leaf_classes(list[dict[_A, list[_B]]])) + assert result == {_A, _B} + + +def test_deeply_nested(): + """list[dict[str, list[dict[int, _A]]]] yields str, int, _A.""" + result = set(extract_leaf_classes(list[dict[str, list[dict[int, _A]]]])) + assert result == {str, int, _A} + + +def test_non_generic_non_type_is_skipped(): + """Annotations that are not types and not generic aliases yield nothing.""" + # e.g. a string annotation that failed resolution — should not crash + result = list(extract_leaf_classes("unresolved_string")) + assert result == [] + + +def test_none_type_plain(): + """type(None) itself yields type(None) as a leaf (not filtered at this level).""" + result = list(extract_leaf_classes(type(None))) + assert result == [type(None)] + + +# ── _walk_fqcn tests ───────────────────────────────────────────────────────── + + +def test_walk_fqcn_resolves_module_level_class(): + """_walk_fqcn resolves a top-level class from its FQCN.""" + import pathlib + obj = _walk_fqcn("pathlib.Path") + assert obj is pathlib.Path + + +def test_walk_fqcn_resolves_nested_attribute(): + """_walk_fqcn walks nested attribute chains (e.g. module.Outer.Inner).""" + import os.path + # os.path.join is a function reachable via attribute walk + obj = _walk_fqcn("os.path.join") + assert obj is os.path.join + + +def test_walk_fqcn_raises_import_error_on_bad_module(): + """_walk_fqcn raises ImportError when no module prefix can be imported.""" + with pytest.raises(ImportError): + _walk_fqcn("nonexistent.module.NoSuchClass") + + +def test_walk_fqcn_raises_import_error_on_missing_attr(): + """_walk_fqcn raises ImportError when module exists but attribute does not.""" + with pytest.raises(ImportError): + _walk_fqcn("pathlib.NoSuchClass") + + +def test_walk_fqcn_raises_import_error_on_single_part(): + """_walk_fqcn raises ImportError when FQCN has no module separator.""" + with pytest.raises(ImportError): + _walk_fqcn("justname") + + +def test_walk_fqcn_reraises_real_import_failure(monkeypatch): + """_walk_fqcn propagates ImportError from a module that exists but fails to import. + + Simulates the case where a module on disk raises ModuleNotFoundError for + one of its own optional dependencies (exc.name is the missing dep, not the + module being imported). The error must not be swallowed and replaced with + a generic "no valid module+attribute" ImportError. + """ + import importlib as _importlib + + original = _importlib.import_module + + def _patched(name: str, *args, **kwargs): + if name == "pathlib": + # "pathlib" exists but pretend it tries to import a missing dep. + err = ModuleNotFoundError("No module named 'some_optional_dep'") + err.name = "some_optional_dep" + raise err + return original(name, *args, **kwargs) + + monkeypatch.setattr(_importlib, "import_module", _patched) + + with pytest.raises(ModuleNotFoundError, match="some_optional_dep"): + _walk_fqcn("pathlib.Path") + + +def test_walk_fqcn_reraises_when_dep_name_is_bare_prefix_of_module(monkeypatch): + """_walk_fqcn does not swallow errors when exc.name is a bare substring of module_path. + + Regression: the old ``module_path.startswith(exc.name)`` check would + incorrectly swallow a ModuleNotFoundError for a dep named ``"path"`` while + importing ``"pathlib"``, because ``"pathlib".startswith("path")`` is True. + The fix requires an exact match or a dotted-prefix match. + """ + import importlib as _importlib + + original = _importlib.import_module + + def _patched(name: str, *args, **kwargs): + if name == "pathlib": + # dep name "path" is a bare prefix of "pathlib" — must not be swallowed. + err = ModuleNotFoundError("No module named 'path'") + err.name = "path" + raise err + return original(name, *args, **kwargs) + + monkeypatch.setattr(_importlib, "import_module", _patched) + + with pytest.raises(ModuleNotFoundError, match="'path'"): + _walk_fqcn("pathlib.Path") + + +# ── _import_from_fqcn tests ────────────────────────────────────────────────── + + +def test_import_from_fqcn_raises_for_non_dataclass(): + """_import_from_fqcn raises ImportError when FQCN resolves to a non-dataclass.""" + from orcapod.extension_types.dataclass_logical_type_factory import _import_from_fqcn + # pathlib.Path is importable via _walk_fqcn but is not a dataclass + with pytest.raises(ImportError): + _import_from_fqcn("pathlib.Path") diff --git a/tests/test_hashing/generate_hash_examples.py b/tests/test_hashing/generate_hash_examples.py index 5edbef3f..f9e58e7f 100644 --- a/tests/test_hashing/generate_hash_examples.py +++ b/tests/test_hashing/generate_hash_examples.py @@ -3,8 +3,7 @@ # throughout the tests to ensure consistent hashing behavior across different runs # and revisions of the codebase. # -# Uses the new BaseSemanticHasher API (get_default_semantic_hasher) rather than -# the legacy hash_to_hex / hash_to_int / hash_to_uuid functions. +# Uses SemanticAwarePythonHasher via get_default_semantic_hasher. import json from collections import OrderedDict @@ -27,7 +26,7 @@ def generate_hash_examples(): - """Generate hash examples for various data structures using BaseSemanticHasher.""" + """Generate hash examples for various data structures using ``SemanticAwarePythonHasher``.""" hasher = get_default_semantic_hasher() examples = [] diff --git a/tests/test_hashing/test_extension_type_hashing.py b/tests/test_hashing/test_extension_type_hashing.py new file mode 100644 index 00000000..72106df4 --- /dev/null +++ b/tests/test_hashing/test_extension_type_hashing.py @@ -0,0 +1,203 @@ +"""Tests for extension type column hashing via SemanticHashingVisitor.""" + +from __future__ import annotations + +import pyarrow as pa +import pytest +from pathlib import Path + +from orcapod.hashing.visitors import SemanticHashingVisitor +from orcapod.contexts import get_default_context + + +@pytest.fixture +def ctx(): + return get_default_context() + + +class TestArrowTypeDataVisitorExtension: + def test_visit_dispatches_to_visit_extension_for_extension_types(self, ctx, tmp_path): + """visit() routes ExtensionType columns to visit_extension(), not visit_struct().""" + # Create a real file so visit_extension can complete without errors + real_file = tmp_path / "dummy.txt" + real_file.write_text("dispatch test") + + arrow_type = ctx.type_converter.register_python_class(Path) + assert isinstance(arrow_type, pa.ExtensionType), ( + "Path must be registered as an Arrow extension type" + ) + storage_val = ctx.type_converter.python_to_storage(Path(real_file), Path) + + calls = [] + + class TrackingVisitor(SemanticHashingVisitor): + def visit_extension(self, ext_type, storage_value): + calls.append("visit_extension") + return super().visit_extension(ext_type, storage_value) + + def visit_struct(self, struct_type, data): + calls.append("visit_struct") + return super().visit_struct(struct_type, data) + + visitor = TrackingVisitor(ctx.type_converter, ctx.semantic_hasher) + visitor.visit(arrow_type, storage_val) + assert "visit_extension" in calls + assert "visit_struct" not in calls + + +class TestSemanticHashingVisitorExtension: + def test_path_column_hashed_to_large_binary(self, ctx, tmp_path): + """Path extension columns are replaced with pa.large_binary() hash tokens.""" + file = tmp_path / "test.txt" + file.write_text("hello") + + arrow_type = ctx.type_converter.register_python_class(Path) + storage_val = ctx.type_converter.python_to_storage(Path(file), Path) + + visitor = SemanticHashingVisitor(ctx.type_converter, ctx.semantic_hasher) + new_type, new_data = visitor.visit(arrow_type, storage_val) + + assert new_type == pa.large_binary() + assert isinstance(new_data, bytes) + + def test_same_content_same_hash(self, ctx, tmp_path): + """Two paths pointing to files with identical content produce the same hash bytes.""" + file1 = tmp_path / "a.txt" + file2 = tmp_path / "b.txt" + file1.write_text("identical content") + file2.write_text("identical content") + + arrow_type = ctx.type_converter.register_python_class(Path) + storage1 = ctx.type_converter.python_to_storage(Path(file1), Path) + storage2 = ctx.type_converter.python_to_storage(Path(file2), Path) + + visitor = SemanticHashingVisitor(ctx.type_converter, ctx.semantic_hasher) + _, hash1 = visitor.visit(arrow_type, storage1) + _, hash2 = visitor.visit(arrow_type, storage2) + + assert hash1 == hash2 + + def test_different_content_different_hash(self, ctx, tmp_path): + """Files with different content produce different hash bytes.""" + file1 = tmp_path / "x.txt" + file2 = tmp_path / "y.txt" + file1.write_text("content A") + file2.write_text("content B") + + arrow_type = ctx.type_converter.register_python_class(Path) + storage1 = ctx.type_converter.python_to_storage(Path(file1), Path) + storage2 = ctx.type_converter.python_to_storage(Path(file2), Path) + + visitor = SemanticHashingVisitor(ctx.type_converter, ctx.semantic_hasher) + _, hash1 = visitor.visit(arrow_type, storage1) + _, hash2 = visitor.visit(arrow_type, storage2) + + assert hash1 != hash2 + + def test_binary_encoding_format(self, ctx, tmp_path): + """Hash bytes have format b':::'.""" + file = tmp_path / "test.txt" + file.write_text("test") + + arrow_type = ctx.type_converter.register_python_class(Path) + storage_val = ctx.type_converter.python_to_storage(Path(file), Path) + + visitor = SemanticHashingVisitor(ctx.type_converter, ctx.semantic_hasher) + _, hash_bytes = visitor.visit(arrow_type, storage_val) + + assert b"::" in hash_bytes + type_prefix, hash_part = hash_bytes.split(b"::", 1) + # Extension name "orcapod.path" → dots replaced with colons + assert type_prefix == b"orcapod:path" + # hash_part should be "method:digest" — at least one colon + assert b":" in hash_part + + def test_null_value_passthrough(self, ctx): + """Null storage values pass through as-is.""" + arrow_type = ctx.type_converter.register_python_class(Path) + + visitor = SemanticHashingVisitor(ctx.type_converter, ctx.semantic_hasher) + new_type, new_data = visitor.visit(arrow_type, None) + + assert new_type == arrow_type + assert new_data is None + + def test_unregistered_python_type_passes_through(self, ctx): + """Extension types with no registered semantic hasher pass through unchanged.""" + import uuid + from orcapod.hashing.semantic_hashing.type_handler_registry import PythonTypeHandlerRegistry + from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher + + # Build a hasher with a registry that has NO entry for UUID + empty_registry = PythonTypeHandlerRegistry() + stripped_hasher = SemanticAwarePythonHasher( + hasher_id="test_v0", + type_handler_registry=empty_registry, + ) + + arrow_type = ctx.type_converter.register_python_class(uuid.UUID) + storage_val = ctx.type_converter.python_to_storage(uuid.UUID("12345678-1234-5678-1234-567812345678"), uuid.UUID) + + visitor = SemanticHashingVisitor(ctx.type_converter, stripped_hasher) + new_type, new_data = visitor.visit(arrow_type, storage_val) + + # Should be completely unchanged since UUID has no semantic hasher + assert new_type == arrow_type + assert new_data == storage_val + + +class TestCrossPathConsistency: + """Verify that the Arrow visitor path and the direct Python hasher path produce + identical hash tokens for the same underlying file content. + + The Arrow path (SemanticHashingVisitor.visit_extension) converts the extension + storage value back to a Python object and calls semantic_hasher.hash_object — + exactly the same call as the direct Python path. These tests make that + structural guarantee explicit and regression-proof. + + Hash encoding: + - Arrow path produces: b":::" + - Python path produces: ContentHash with to_prefixed_digest() → b":" + Stripping the type-name prefix from the Arrow encoding yields an identical + b":" byte string. + """ + + def test_arrow_and_semantic_hash_same_file_content(self, ctx, tmp_path): + """Arrow visitor path and direct Python hasher path embed the same digest.""" + file = tmp_path / "shared.txt" + file.write_text("shared content for both paths") + + arrow_type = ctx.type_converter.register_python_class(Path) + storage_val = ctx.type_converter.python_to_storage(Path(file), Path) + + # Arrow path: visit_extension encodes as b":::" + visitor = SemanticHashingVisitor(ctx.type_converter, ctx.semantic_hasher) + _, arrow_hash_bytes = visitor.visit(arrow_type, storage_val) + # Strip the "orcapod:path::" type prefix to get b":" + prefixed_from_arrow = arrow_hash_bytes.split(b"::", 1)[1] + + # Python path: hash_object returns ContentHash directly + python_content_hash = ctx.semantic_hasher.hash_object(Path(file)) + prefixed_from_python = python_content_hash.to_prefixed_digest() + + assert prefixed_from_arrow == prefixed_from_python + + def test_same_content_two_files_cross_path(self, ctx, tmp_path): + """Two files with identical content: Arrow path and Python path agree.""" + file_arrow = tmp_path / "file_arrow.txt" + file_python = tmp_path / "file_python.txt" + content = "same content for cross-path test" + file_arrow.write_text(content) + file_python.write_text(content) + + arrow_type = ctx.type_converter.register_python_class(Path) + storage_val = ctx.type_converter.python_to_storage(Path(file_arrow), Path) + + visitor = SemanticHashingVisitor(ctx.type_converter, ctx.semantic_hasher) + _, arrow_hash_bytes = visitor.visit(arrow_type, storage_val) + prefixed_from_arrow = arrow_hash_bytes.split(b"::", 1)[1] + + python_content_hash = ctx.semantic_hasher.hash_object(Path(file_python)) + prefixed_from_python = python_content_hash.to_prefixed_digest() + + assert prefixed_from_arrow == prefixed_from_python diff --git a/tests/test_hashing/test_file_hashing_consistency.py b/tests/test_hashing/test_file_hashing_consistency.py deleted file mode 100644 index e5bd4bbf..00000000 --- a/tests/test_hashing/test_file_hashing_consistency.py +++ /dev/null @@ -1,219 +0,0 @@ -""" -Integration tests verifying that file hashing is consistent across both paths: - -1. **Arrow hasher path**: SemanticArrowHasher processes an Arrow table containing a - path struct column → calls PythonPathStructConverter.hash_struct_dict → file_hasher. -2. **Semantic hasher path**: BaseSemanticHasher hashes a Python Path object → - calls PathContentHandler.handle → file_hasher. - -Both paths must delegate to the same FileContentHasherProtocol so that identical -file content always produces identical hashes, regardless of entry point. -""" - -from pathlib import Path - -import pyarrow as pa -import pytest - -from orcapod.hashing.arrow_hashers import SemanticArrowHasher -from orcapod.hashing.file_hashers import BasicFileHasher -from orcapod.hashing.semantic_hashing.builtin_handlers import ( - register_builtin_handlers, -) -from orcapod.hashing.semantic_hashing.semantic_hasher import BaseSemanticHasher -from orcapod.hashing.semantic_hashing.type_handler_registry import TypeHandlerRegistry -from orcapod.semantic_types.semantic_registry import SemanticTypeRegistry -from orcapod.semantic_types.semantic_struct_converters import PythonPathStructConverter - - -# --------------------------------------------------------------------------- -# Shared fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def file_hasher(): - """Single file hasher instance shared by both paths.""" - return BasicFileHasher(algorithm="sha256") - - -@pytest.fixture -def path_converter(file_hasher): - return PythonPathStructConverter(file_hasher=file_hasher) - - -@pytest.fixture -def arrow_hasher(path_converter): - """SemanticArrowHasher wired with the shared file_hasher via PythonPathStructConverter.""" - registry = SemanticTypeRegistry() - registry.register_converter("path", path_converter) - return SemanticArrowHasher(semantic_registry=registry) - - -@pytest.fixture -def semantic_hasher(file_hasher): - """BaseSemanticHasher wired with the shared file_hasher via PathContentHandler.""" - registry = TypeHandlerRegistry() - register_builtin_handlers(registry, file_hasher=file_hasher) - return BaseSemanticHasher( - hasher_id="test_v1", type_handler_registry=registry, strict=True - ) - - -# --------------------------------------------------------------------------- -# Arrow struct hasher: path column tests -# --------------------------------------------------------------------------- - - -class TestArrowStructPathHashing: - """Tests for file hashing through the Arrow hasher path.""" - - def test_same_content_different_paths_same_hash(self, arrow_hasher, tmp_path): - """Two distinct files with identical content produce the same table hash.""" - file1 = tmp_path / "a.txt" - file2 = tmp_path / "b.txt" - file1.write_text("identical content") - file2.write_text("identical content") - - table1 = pa.table( - {"file": [{"path": str(file1)}]}, - schema=pa.schema( - [pa.field("file", pa.struct([pa.field("path", pa.large_string())]))] - ), - ) - table2 = pa.table( - {"file": [{"path": str(file2)}]}, - schema=pa.schema( - [pa.field("file", pa.struct([pa.field("path", pa.large_string())]))] - ), - ) - - hash1 = arrow_hasher.hash_table(table1) - hash2 = arrow_hasher.hash_table(table2) - assert hash1.digest == hash2.digest - - def test_modified_content_different_hash(self, arrow_hasher, tmp_path): - """Same path with modified content between hashes yields different hash.""" - file = tmp_path / "mutable.txt" - file.write_text("version 1") - - schema = pa.schema( - [pa.field("file", pa.struct([pa.field("path", pa.large_string())]))] - ) - table_v1 = pa.table({"file": [{"path": str(file)}]}, schema=schema) - hash1 = arrow_hasher.hash_table(table_v1) - - file.write_text("version 2") - table_v2 = pa.table({"file": [{"path": str(file)}]}, schema=schema) - hash2 = arrow_hasher.hash_table(table_v2) - - assert hash1.digest != hash2.digest - - def test_different_content_different_hash(self, arrow_hasher, tmp_path): - """Two files with different content produce different table hashes.""" - file1 = tmp_path / "x.txt" - file2 = tmp_path / "y.txt" - file1.write_text("content A") - file2.write_text("content B") - - schema = pa.schema( - [pa.field("file", pa.struct([pa.field("path", pa.large_string())]))] - ) - table1 = pa.table({"file": [{"path": str(file1)}]}, schema=schema) - table2 = pa.table({"file": [{"path": str(file2)}]}, schema=schema) - - hash1 = arrow_hasher.hash_table(table1) - hash2 = arrow_hasher.hash_table(table2) - assert hash1.digest != hash2.digest - - -# --------------------------------------------------------------------------- -# Semantic hasher: Path object tests -# --------------------------------------------------------------------------- - - -class TestSemanticPathHashing: - """Tests for file hashing through the semantic hasher path.""" - - def test_same_content_different_paths_same_hash(self, semantic_hasher, tmp_path): - """Two distinct Path objects pointing to files with identical content.""" - file1 = tmp_path / "a.txt" - file2 = tmp_path / "b.txt" - file1.write_text("identical content") - file2.write_text("identical content") - - hash1 = semantic_hasher.hash_object(Path(file1)) - hash2 = semantic_hasher.hash_object(Path(file2)) - assert hash1.digest == hash2.digest - - def test_modified_content_different_hash(self, semantic_hasher, tmp_path): - """Same Path with modified content between hashes.""" - file = tmp_path / "mutable.txt" - file.write_text("version 1") - hash1 = semantic_hasher.hash_object(Path(file)) - - file.write_text("version 2") - hash2 = semantic_hasher.hash_object(Path(file)) - assert hash1.digest != hash2.digest - - def test_different_content_different_hash(self, semantic_hasher, tmp_path): - """Two Paths pointing to different content produce different hashes.""" - file1 = tmp_path / "x.txt" - file2 = tmp_path / "y.txt" - file1.write_text("content A") - file2.write_text("content B") - - hash1 = semantic_hasher.hash_object(Path(file1)) - hash2 = semantic_hasher.hash_object(Path(file2)) - assert hash1.digest != hash2.digest - - -# --------------------------------------------------------------------------- -# Cross-path consistency -# --------------------------------------------------------------------------- - - -class TestCrossPathConsistency: - """Verify that the arrow hasher and semantic hasher use the same file_hasher - and produce equivalent file content hashes for the same underlying file.""" - - def test_arrow_and_semantic_hash_same_file_content( - self, path_converter, semantic_hasher, file_hasher, tmp_path - ): - """The file content hash extracted by PythonPathStructConverter.hash_struct_dict - must embed the same digest as ContentHash produced by PathContentHandler.handle - (which the semantic hasher uses internally for Path objects). - - Both paths ultimately call file_hasher.hash_file(path), so the raw digest - must be identical. hash_struct_dict always returns the fully-prefixed form - "path:sha256:", so we strip the prefix when comparing. - """ - file = tmp_path / "shared.txt" - file.write_text("shared content for both paths") - - # Arrow path: PythonPathStructConverter.hash_struct_dict — always prefixed - arrow_hash = path_converter.hash_struct_dict({"path": str(file)}) - # Strip "path:sha256:" prefix to get the raw hex - arrow_hash_hex = arrow_hash.split(":")[-1] - - # Semantic path: file_hasher.hash_file directly (same as PathContentHandler) - semantic_content_hash = file_hasher.hash_file(file) - - assert arrow_hash_hex == semantic_content_hash.digest.hex() - - def test_arrow_and_semantic_same_content_two_files( - self, path_converter, file_hasher, tmp_path - ): - """Two files with identical content: arrow struct hash_struct_dict and - direct file_hasher.hash_file produce the same digest.""" - file1 = tmp_path / "file_arrow.txt" - file2 = tmp_path / "file_semantic.txt" - content = "same content for cross-path test" - file1.write_text(content) - file2.write_text(content) - - # hash_struct_dict always returns "path:sha256:" — strip prefix - arrow_hex = path_converter.hash_struct_dict({"path": str(file1)}).split(":")[-1] - semantic_hex = file_hasher.hash_file(file2).digest.hex() - - assert arrow_hex == semantic_hex diff --git a/tests/test_hashing/test_hash_samples.py b/tests/test_hashing/test_hash_samples.py index 4caff744..b255f818 100644 --- a/tests/test_hashing/test_hash_samples.py +++ b/tests/test_hashing/test_hash_samples.py @@ -1,7 +1,7 @@ """ Tests for hash samples consistency. -Verifies that BaseSemanticHasher produces identical hashes across runs for a +Verifies that SemanticAwarePythonHasher produces identical hashes across runs for a fixed set of recorded input values. The sample file is generated (or regenerated) by running generate_hash_examples.py. diff --git a/tests/test_hashing/test_semantic_hasher.py b/tests/test_hashing/test_semantic_hasher.py index b2719b4a..3fe4fd38 100644 --- a/tests/test_hashing/test_semantic_hasher.py +++ b/tests/test_hashing/test_semantic_hasher.py @@ -1,17 +1,17 @@ """ -Comprehensive test suite for the BaseSemanticHasher system. +Comprehensive test suite for the SemanticAwarePythonHasher system. Covers: - - BaseSemanticHasher: primitives, container type-tagging, determinism, + - SemanticAwarePythonHasher: primitives, container type-tagging, determinism, circular references, strict vs non-strict mode - ContentIdentifiableProtocol protocol: independent hashing, composability - - TypeHandlerRegistry: registration, MRO-aware lookup, unregister - - Built-in handlers: bytes, UUID, Path, functions, type objects + - PythonTypeHandlerRegistry: registration, MRO-aware lookup, unregister + - Built-in hashers: bytes, UUID, Path, functions, type objects - ContentHash as terminal: returned as-is without re-hashing - ContentIdentifiableMixin: content_hash, __eq__, __hash__, caching, cache invalidation, injectable hasher - - Custom type handler registration and extension - - get_default_semantic_hasher / get_default_type_handler_registry + - Custom type hasher registration and extension + - get_default_semantic_hasher / get_default_python_type_handler_registry """ from __future__ import annotations @@ -27,17 +27,19 @@ import pytest from orcapod.hashing.defaults import get_default_semantic_hasher -from orcapod.hashing.semantic_hashing.builtin_handlers import register_builtin_handlers +from orcapod.hashing.semantic_hashing.builtin_handlers import ( + register_builtin_python_type_handlers, +) from orcapod.hashing.semantic_hashing.content_identifiable_mixin import ( ContentIdentifiableMixin, ) from orcapod.hashing.semantic_hashing.semantic_hasher import ( - BaseSemanticHasher, + SemanticAwarePythonHasher, _is_namedtuple, ) from orcapod.hashing.semantic_hashing.type_handler_registry import ( - TypeHandlerRegistry, - get_default_type_handler_registry, + PythonTypeHandlerRegistry, + get_default_python_type_handler_registry, ) from orcapod.types import ContentHash @@ -46,22 +48,22 @@ # --------------------------------------------------------------------------- -def make_hasher(strict: bool = True) -> BaseSemanticHasher: - """Create a fresh BaseSemanticHasher with an isolated registry.""" - registry = TypeHandlerRegistry() - register_builtin_handlers(registry) - return BaseSemanticHasher( +def make_hasher(strict: bool = True) -> SemanticAwarePythonHasher: + """Create a fresh SemanticAwarePythonHasher with an isolated registry.""" + registry = PythonTypeHandlerRegistry() + register_builtin_python_type_handlers(registry) + return SemanticAwarePythonHasher( hasher_id="test_v1", type_handler_registry=registry, strict=strict ) @pytest.fixture -def hasher() -> BaseSemanticHasher: +def hasher() -> SemanticAwarePythonHasher: return make_hasher(strict=True) @pytest.fixture -def lenient_hasher() -> BaseSemanticHasher: +def lenient_hasher() -> SemanticAwarePythonHasher: return make_hasher(strict=False) @@ -108,7 +110,7 @@ def identity_structure(self) -> Any: # --------------------------------------------------------------------------- -# 1. BaseSemanticHasher: primitives +# 1. SemanticAwarePythonHasher: primitives # --------------------------------------------------------------------------- @@ -152,7 +154,7 @@ def test_same_primitive_same_hash(self, hasher): # --------------------------------------------------------------------------- -# 2. BaseSemanticHasher: container type-tagging and determinism +# 2. SemanticAwarePythonHasher: container type-tagging and determinism # --------------------------------------------------------------------------- @@ -213,7 +215,7 @@ def test_hash_returns_content_hash(self, hasher): # --------------------------------------------------------------------------- -# 3. BaseSemanticHasher: namedtuples +# 3. SemanticAwarePythonHasher: namedtuples # --------------------------------------------------------------------------- @@ -249,7 +251,7 @@ def test_is_namedtuple_helper(self): # --------------------------------------------------------------------------- -# 4. BaseSemanticHasher: circular references +# 4. SemanticAwarePythonHasher: circular references # --------------------------------------------------------------------------- @@ -284,7 +286,7 @@ def test_circular_differs_from_non_circular(self, hasher): # --------------------------------------------------------------------------- -# 5. BaseSemanticHasher: strict vs non-strict mode +# 5. SemanticAwarePythonHasher: strict vs non-strict mode # --------------------------------------------------------------------------- @@ -297,7 +299,7 @@ def __init__(self, x: int) -> None: class TestStrictMode: def test_strict_raises_on_unknown_type(self, hasher): - with pytest.raises(TypeError, match="no TypeHandlerProtocol registered"): + with pytest.raises(TypeError, match="no implementation of PythonTypeHandlerProtocol registered"): hasher.hash_object(Unhandled(1)) def test_non_strict_returns_content_hash(self, lenient_hasher): @@ -310,8 +312,8 @@ def test_non_strict_same_object_same_hash(self, lenient_hasher): assert h1 == h2 def test_strict_mode_flag(self): - strict = BaseSemanticHasher(hasher_id="s", strict=True) - lenient = BaseSemanticHasher(hasher_id="s", strict=False) + strict = SemanticAwarePythonHasher(hasher_id="s", strict=True) + lenient = SemanticAwarePythonHasher(hasher_id="s", strict=False) assert strict.strict is True assert lenient.strict is False @@ -795,7 +797,7 @@ def test_usable_in_set(self, hasher): assert len(s) == 2 def test_injectable_hasher(self): - custom_hasher = BaseSemanticHasher(hasher_id="injected_v9") + custom_hasher = SemanticAwarePythonHasher(hasher_id="injected_v9") rec = SimpleRecord("foo", 1, semantic_hasher=custom_hasher) assert rec.content_hash().method == "injected_v9" @@ -820,15 +822,16 @@ def test_repr_includes_hash(self, hasher): # --------------------------------------------------------------------------- -# 14. TypeHandlerRegistry +# 14. PythonTypeHandlerRegistry # --------------------------------------------------------------------------- -class _DummyHandler: +class _DummySemanticHasher: def __init__(self, tag: str) -> None: self.tag = tag def handle(self, obj: Any, hasher: Any) -> Any: + # Returns a representative Python structure; outer hasher performs final hashing return f"{self.tag}:{obj}" @@ -844,95 +847,95 @@ class GrandChild(Child): pass -class TestTypeHandlerRegistry: +class TestPythonTypeHandlerRegistry: def test_register_and_get_exact(self): - reg = TypeHandlerRegistry() - h = _DummyHandler("base") + reg = PythonTypeHandlerRegistry() + h = _DummySemanticHasher("base") reg.register(Base, h) assert reg.get_handler(Base()) is h def test_mro_lookup_child(self): - reg = TypeHandlerRegistry() - h = _DummyHandler("base") + reg = PythonTypeHandlerRegistry() + h = _DummySemanticHasher("base") reg.register(Base, h) assert reg.get_handler(Child()) is h def test_mro_lookup_grandchild(self): - reg = TypeHandlerRegistry() - h = _DummyHandler("base") + reg = PythonTypeHandlerRegistry() + h = _DummySemanticHasher("base") reg.register(Base, h) assert reg.get_handler(GrandChild()) is h def test_more_specific_handler_wins(self): - reg = TypeHandlerRegistry() - h_base = _DummyHandler("base") - h_child = _DummyHandler("child") + reg = PythonTypeHandlerRegistry() + h_base = _DummySemanticHasher("base") + h_child = _DummySemanticHasher("child") reg.register(Base, h_base) reg.register(Child, h_child) assert reg.get_handler(Child()) is h_child assert reg.get_handler(GrandChild()) is h_child def test_unregistered_returns_none(self): - reg = TypeHandlerRegistry() + reg = PythonTypeHandlerRegistry() assert reg.get_handler(Base()) is None def test_unregister_removes_handler(self): - reg = TypeHandlerRegistry() - h = _DummyHandler("base") + reg = PythonTypeHandlerRegistry() + h = _DummySemanticHasher("base") reg.register(Base, h) assert reg.unregister(Base) is True assert reg.get_handler(Base()) is None def test_unregister_nonexistent_returns_false(self): - reg = TypeHandlerRegistry() + reg = PythonTypeHandlerRegistry() assert reg.unregister(Base) is False def test_replace_existing_handler(self): - reg = TypeHandlerRegistry() - h1 = _DummyHandler("first") - h2 = _DummyHandler("second") + reg = PythonTypeHandlerRegistry() + h1 = _DummySemanticHasher("first") + h2 = _DummySemanticHasher("second") reg.register(Base, h1) reg.register(Base, h2) assert reg.get_handler(Base()) is h2 def test_register_non_type_raises(self): - reg = TypeHandlerRegistry() + reg = PythonTypeHandlerRegistry() with pytest.raises(TypeError): - reg.register("not_a_type", _DummyHandler("x")) # type: ignore[arg-type] + reg.register("not_a_type", _DummySemanticHasher("x")) # type: ignore[arg-type] def test_has_handler_exact(self): - reg = TypeHandlerRegistry() - reg.register(Base, _DummyHandler("b")) + reg = PythonTypeHandlerRegistry() + reg.register(Base, _DummySemanticHasher("b")) assert reg.has_handler(Base) is True def test_has_handler_via_mro(self): - reg = TypeHandlerRegistry() - reg.register(Base, _DummyHandler("b")) + reg = PythonTypeHandlerRegistry() + reg.register(Base, _DummySemanticHasher("b")) assert reg.has_handler(Child) is True def test_has_handler_false(self): - reg = TypeHandlerRegistry() + reg = PythonTypeHandlerRegistry() assert reg.has_handler(Base) is False def test_registered_types_snapshot(self): - reg = TypeHandlerRegistry() - reg.register(Base, _DummyHandler("b")) - reg.register(Child, _DummyHandler("c")) + reg = PythonTypeHandlerRegistry() + reg.register(Base, _DummySemanticHasher("b")) + reg.register(Child, _DummySemanticHasher("c")) types = reg.registered_types() assert Base in types assert Child in types def test_len(self): - reg = TypeHandlerRegistry() + reg = PythonTypeHandlerRegistry() assert len(reg) == 0 - reg.register(Base, _DummyHandler("b")) + reg.register(Base, _DummySemanticHasher("b")) assert len(reg) == 1 - reg.register(Child, _DummyHandler("c")) + reg.register(Child, _DummySemanticHasher("c")) assert len(reg) == 2 def test_get_handler_for_type(self): - reg = TypeHandlerRegistry() - h = _DummyHandler("b") + reg = PythonTypeHandlerRegistry() + h = _DummySemanticHasher("b") reg.register(Base, h) assert reg.get_handler_for_type(Base) is h assert reg.get_handler_for_type(Child) is h # via MRO @@ -956,19 +959,19 @@ def handle(self, obj: Any, hasher: Any) -> Any: class TestCustomHandlerRegistration: def test_register_custom_type(self): - registry = TypeHandlerRegistry() - register_builtin_handlers(registry) + registry = PythonTypeHandlerRegistry() + register_builtin_python_type_handlers(registry) registry.register(Celsius, CelsiusHandler()) - custom_hasher = BaseSemanticHasher( + custom_hasher = SemanticAwarePythonHasher( hasher_id="custom_v1", type_handler_registry=registry, strict=True ) assert isinstance(custom_hasher.hash_object(Celsius(100.0)), ContentHash) def test_custom_handler_determinism(self): - registry = TypeHandlerRegistry() - register_builtin_handlers(registry) + registry = PythonTypeHandlerRegistry() + register_builtin_python_type_handlers(registry) registry.register(Celsius, CelsiusHandler()) - custom_hasher = BaseSemanticHasher( + custom_hasher = SemanticAwarePythonHasher( hasher_id="custom_v1", type_handler_registry=registry ) h1 = custom_hasher.hash_object(Celsius(37.5)) @@ -976,10 +979,10 @@ def test_custom_handler_determinism(self): assert h1 == h2 def test_custom_handler_different_values_differ(self): - registry = TypeHandlerRegistry() - register_builtin_handlers(registry) + registry = PythonTypeHandlerRegistry() + register_builtin_python_type_handlers(registry) registry.register(Celsius, CelsiusHandler()) - custom_hasher = BaseSemanticHasher( + custom_hasher = SemanticAwarePythonHasher( hasher_id="custom_v1", type_handler_registry=registry ) assert custom_hasher.hash_object(Celsius(0.0)) != custom_hasher.hash_object( @@ -987,15 +990,15 @@ def test_custom_handler_different_values_differ(self): ) def test_unregistered_type_still_strict(self): - hasher = BaseSemanticHasher(hasher_id="strict_v1", strict=True) + hasher = SemanticAwarePythonHasher(hasher_id="strict_v1", strict=True) with pytest.raises(TypeError): hasher.hash_object(Celsius(42.0)) def test_custom_handler_in_nested_structure(self): - registry = TypeHandlerRegistry() - register_builtin_handlers(registry) + registry = PythonTypeHandlerRegistry() + register_builtin_python_type_handlers(registry) registry.register(Celsius, CelsiusHandler()) - custom_hasher = BaseSemanticHasher( + custom_hasher = SemanticAwarePythonHasher( hasher_id="custom_v1", type_handler_registry=registry ) h = custom_hasher.hash_object({"temp": Celsius(36.6), "unit": "C"}) @@ -1008,10 +1011,10 @@ class DirectHashHandler: def handle(self, obj: Any, hasher: Any) -> ContentHash: return ContentHash("direct", b"\xaa" * 32) - registry = TypeHandlerRegistry() - register_builtin_handlers(registry) + registry = PythonTypeHandlerRegistry() + register_builtin_python_type_handlers(registry) registry.register(Celsius, DirectHashHandler()) - custom_hasher = BaseSemanticHasher( + custom_hasher = SemanticAwarePythonHasher( hasher_id="custom_v1", type_handler_registry=registry ) result = custom_hasher.hash_object(Celsius(0.0)) @@ -1022,10 +1025,10 @@ def test_mro_aware_custom_handler(self): class FancyCelsius(Celsius): pass - registry = TypeHandlerRegistry() - register_builtin_handlers(registry) + registry = PythonTypeHandlerRegistry() + register_builtin_python_type_handlers(registry) registry.register(Celsius, CelsiusHandler()) - custom_hasher = BaseSemanticHasher( + custom_hasher = SemanticAwarePythonHasher( hasher_id="custom_v1", type_handler_registry=registry ) h = custom_hasher.hash_object(FancyCelsius(20.0)) @@ -1042,7 +1045,7 @@ class KelvinHandler: def handle(self, obj: Any, hasher: Any) -> Any: return {"__type__": "Kelvin", "k": obj.k} - global_registry = get_default_type_handler_registry() + global_registry = get_default_python_type_handler_registry() global_registry.register(Kelvin, KelvinHandler()) try: default_hasher = get_default_semantic_hasher() @@ -1058,14 +1061,14 @@ def handle(self, obj: Any, hasher: Any) -> Any: class TestGlobalSingletons: def test_get_default_semantic_hasher_returns_semantic_hasher(self): - assert isinstance(get_default_semantic_hasher(), BaseSemanticHasher) + assert isinstance(get_default_semantic_hasher(), SemanticAwarePythonHasher) def test_get_default_semantic_hasher_has_versioned_id(self): assert get_default_semantic_hasher().hasher_id == "semantic_v0.1" def test_get_default_type_handler_registry_is_singleton(self): - r1 = get_default_type_handler_registry() - r2 = get_default_type_handler_registry() + r1 = get_default_python_type_handler_registry() + r2 = get_default_python_type_handler_registry() assert r1 is r2 def test_default_registry_has_builtin_handlers(self): @@ -1073,7 +1076,7 @@ def test_default_registry_has_builtin_handlers(self): import typing as _typing - reg = get_default_type_handler_registry() + reg = get_default_python_type_handler_registry() assert reg.has_handler(bytes) assert reg.has_handler(bytearray) assert reg.has_handler(UUID) @@ -1087,7 +1090,7 @@ def test_default_registry_has_builtin_handlers(self): def test_default_registry_has_no_content_hash_handler(self): """ContentHash is handled as a terminal -- no registry entry needed.""" - reg = get_default_type_handler_registry() + reg = get_default_python_type_handler_registry() assert not reg.has_handler(ContentHash) def test_default_hasher_can_hash_common_types(self): @@ -1118,7 +1121,7 @@ def test_content_hash_conversion_methods(self): def _sha256_json(obj: Any, hasher_id: str) -> "ContentHash": - """Manually JSON-serialize *obj* with the same settings as BaseSemanticHasher + """Manually JSON-serialize *obj* with the same settings as SemanticAwarePythonHasher and return the resulting ContentHash.""" json_bytes = json.dumps( obj, @@ -1134,7 +1137,7 @@ class TestJsonNormalizationConsistency: """Verify that hash_object produces hashes identical to directly SHA-256 hashing the canonical tagged-JSON form that _expand_structure produces. - These tests treat BaseSemanticHasher as a black box and anchor its output to + These tests treat SemanticAwarePythonHasher as a black box and anchor its output to a human-verifiable serialization format, ensuring the algorithm is transparent and reproducible without the library. """ @@ -1142,7 +1145,7 @@ class TestJsonNormalizationConsistency: HASHER_ID = "test_v1" @pytest.fixture - def h(self) -> BaseSemanticHasher: + def h(self) -> SemanticAwarePythonHasher: return make_hasher(strict=True) # ------------------------------------------------------------------ @@ -1284,7 +1287,7 @@ def test_no_resolver_uses_obj_content_hash(self): """Without a resolver hash_object returns obj.content_hash() -- using the object's own hasher.""" calling_hasher = make_hasher(strict=True) - obj_hasher = BaseSemanticHasher(hasher_id="obj_hasher_v1") + obj_hasher = SemanticAwarePythonHasher(hasher_id="obj_hasher_v1") rec = SimpleRecord("hello", 1, semantic_hasher=obj_hasher) result = calling_hasher.hash_object(rec) @@ -1294,7 +1297,7 @@ def test_no_resolver_uses_obj_content_hash(self): def test_resolver_overrides_default(self): """When a resolver is provided it takes priority over obj.content_hash().""" calling_hasher = make_hasher(strict=True) - obj_hasher = BaseSemanticHasher(hasher_id="obj_hasher_v1") + obj_hasher = SemanticAwarePythonHasher(hasher_id="obj_hasher_v1") rec = SimpleRecord("hello", 1, semantic_hasher=obj_hasher) # Resolver that uses the calling hasher instead of the object's own hasher @@ -1307,7 +1310,7 @@ def test_resolver_overrides_default(self): def test_resolver_differs_from_no_resolver_when_hashers_differ(self): """When the object's hasher differs from the calling hasher, resolver and no-resolver produce different results.""" - obj_hasher = BaseSemanticHasher(hasher_id="obj_v99") + obj_hasher = SemanticAwarePythonHasher(hasher_id="obj_v99") calling_hasher = make_hasher(strict=True) rec = SimpleRecord("data", 42, semantic_hasher=obj_hasher) @@ -1324,7 +1327,7 @@ def test_resolver_differs_from_no_resolver_when_hashers_differ(self): def test_resolver_propagates_through_list(self): """Resolver is applied to CI objects nested inside a list.""" calling_hasher = make_hasher(strict=True) - obj_hasher = BaseSemanticHasher(hasher_id="inner_v1") + obj_hasher = SemanticAwarePythonHasher(hasher_id="inner_v1") inner = SimpleRecord("inner", 99, semantic_hasher=obj_hasher) # With no resolver the embedded token uses inner's own hasher_id @@ -1344,7 +1347,7 @@ def test_resolver_propagates_through_list(self): def test_resolver_propagates_through_tuple(self): """Resolver is applied to CI objects nested inside a tuple.""" calling_hasher = make_hasher(strict=True) - obj_hasher = BaseSemanticHasher(hasher_id="inner_v1") + obj_hasher = SemanticAwarePythonHasher(hasher_id="inner_v1") inner = SimpleRecord("x", 1, semantic_hasher=obj_hasher) resolver = lambda obj: calling_hasher.hash_object(obj.identity_structure()) @@ -1356,7 +1359,7 @@ def test_resolver_propagates_through_tuple(self): def test_resolver_propagates_through_dict(self): """Resolver is applied to CI objects nested inside a dict value.""" calling_hasher = make_hasher(strict=True) - obj_hasher = BaseSemanticHasher(hasher_id="inner_v1") + obj_hasher = SemanticAwarePythonHasher(hasher_id="inner_v1") inner = SimpleRecord("v", 2, semantic_hasher=obj_hasher) resolver = lambda obj: calling_hasher.hash_object(obj.identity_structure()) @@ -1388,7 +1391,7 @@ def test_resolver_propagates_through_handler_result(self): """When a registered handler returns a ContentIdentifiable, the resolver is applied to that result.""" calling_hasher = make_hasher(strict=True) - obj_hasher = BaseSemanticHasher(hasher_id="inner_v1") + obj_hasher = SemanticAwarePythonHasher(hasher_id="inner_v1") inner = SimpleRecord("inner", 5, semantic_hasher=obj_hasher) resolved = [] @@ -1406,7 +1409,7 @@ def resolver(obj): def test_cached_result_reused_across_calls(self): """content_hash() caches by hasher_id -- the same ContentHash object is returned on repeated calls with the same hasher.""" - obj_hasher = BaseSemanticHasher(hasher_id="cached_v1") + obj_hasher = SemanticAwarePythonHasher(hasher_id="cached_v1") rec = SimpleRecord("y", 5, semantic_hasher=obj_hasher) first = rec.content_hash() @@ -1432,8 +1435,8 @@ class TestUniformHasherPropagation: def test_entry_point_hasher_overrides_nested_hasher(self): """outer.content_hash() uses outer's hasher for inner, even though inner holds a different hasher.""" - hasher_a = BaseSemanticHasher(hasher_id="hasher_a") - hasher_b = BaseSemanticHasher(hasher_id="hasher_b") + hasher_a = SemanticAwarePythonHasher(hasher_id="hasher_a") + hasher_b = SemanticAwarePythonHasher(hasher_id="hasher_b") inner = SimpleRecord("inner", 1, semantic_hasher=hasher_a) outer = NestedRecord("outer", inner, semantic_hasher=hasher_b) @@ -1461,9 +1464,9 @@ def test_entry_point_hasher_overrides_nested_hasher(self): def test_three_level_chain_uses_entry_hasher_throughout(self): """In a three-level chain A→B→C, calling C.content_hash() uses C's hasher for A and B as well, even though each holds a different hasher.""" - hasher_a = BaseSemanticHasher(hasher_id="hasher_a") - hasher_b = BaseSemanticHasher(hasher_id="hasher_b") - hasher_c = BaseSemanticHasher(hasher_id="hasher_c") + hasher_a = SemanticAwarePythonHasher(hasher_id="hasher_a") + hasher_b = SemanticAwarePythonHasher(hasher_id="hasher_b") + hasher_c = SemanticAwarePythonHasher(hasher_id="hasher_c") a = SimpleRecord("a", 1, semantic_hasher=hasher_a) b = NestedRecord("b", a, semantic_hasher=hasher_b) @@ -1494,8 +1497,8 @@ def test_three_level_chain_uses_entry_hasher_throughout(self): def test_independent_call_still_uses_own_hasher(self): """When an intermediate object is called directly (not as part of a larger chain), it uses its own stored hasher as before.""" - hasher_a = BaseSemanticHasher(hasher_id="hasher_a") - hasher_b = BaseSemanticHasher(hasher_id="hasher_b") + hasher_a = SemanticAwarePythonHasher(hasher_id="hasher_a") + hasher_b = SemanticAwarePythonHasher(hasher_id="hasher_b") inner = SimpleRecord("inner", 1, semantic_hasher=hasher_a) outer = NestedRecord("outer", inner, semantic_hasher=hasher_b) @@ -1507,8 +1510,8 @@ def test_independent_call_still_uses_own_hasher(self): def test_cache_keyed_by_hasher_id_avoids_recomputation(self): """The cache is keyed by hasher_id, so a nested object computed under hasher_c is cached and reused on a second call with hasher_c.""" - hasher_a = BaseSemanticHasher(hasher_id="hasher_a") - hasher_c = BaseSemanticHasher(hasher_id="hasher_c") + hasher_a = SemanticAwarePythonHasher(hasher_id="hasher_a") + hasher_c = SemanticAwarePythonHasher(hasher_id="hasher_c") inner = SimpleRecord("inner", 42, semantic_hasher=hasher_a) diff --git a/tests/test_hashing/test_starfix_arrow_hasher.py b/tests/test_hashing/test_starfix_arrow_hasher.py index 77e52f76..4734e436 100644 --- a/tests/test_hashing/test_starfix_arrow_hasher.py +++ b/tests/test_hashing/test_starfix_arrow_hasher.py @@ -32,7 +32,6 @@ _CURRENT_ARROW_HASHER_ID, get_versioned_semantic_arrow_hasher, ) -from orcapod.semantic_types import SemanticTypeRegistry from orcapod.types import ContentHash @@ -46,8 +45,11 @@ def _make_hasher() -> StarfixArrowHasher: + from orcapod.contexts import get_default_context + ctx = get_default_context() return StarfixArrowHasher( - semantic_registry=SemanticTypeRegistry(), + type_converter=ctx.type_converter, + semantic_hasher=ctx.semantic_hasher, hasher_id=HASHER_ID, ) diff --git a/tests/test_hashing/test_uuid_handler.py b/tests/test_hashing/test_uuid_handler.py index 8b69d78b..a4692510 100644 --- a/tests/test_hashing/test_uuid_handler.py +++ b/tests/test_hashing/test_uuid_handler.py @@ -1,32 +1,51 @@ -"""Tests for UUIDHandler low-level handle() method behaviour. +"""Tests for UUIDHandler handle() dispatch via SemanticAwarePythonHasher. -Verifies that UUIDHandler returns the 16-byte binary representation of a -UUID, consistent with OrcaPod's canonical ``pa.binary(16)`` Arrow storage -format. +Verifies that UUIDHandler produces a ContentHash based on the 16-byte +binary representation of a UUID, consistent with OrcaPod's canonical +``pa.binary(16)`` Arrow storage format. """ from __future__ import annotations import uuid as _uuid +from orcapod.hashing.semantic_hashing.semantic_hasher import SemanticAwarePythonHasher +from orcapod.types import ContentHash -def test_uuid_handler_returns_bytes(): - """UUIDHandler should return the 16-byte binary representation.""" - from orcapod.hashing.semantic_hashing.builtin_handlers import UUIDHandler - handler = UUIDHandler() +def _make_hasher() -> SemanticAwarePythonHasher: + from orcapod.hashing.semantic_hashing.builtin_handlers import ( + register_builtin_python_type_handlers, + ) + from orcapod.hashing.semantic_hashing.type_handler_registry import ( + PythonTypeHandlerRegistry, + ) + + registry = PythonTypeHandlerRegistry() + register_builtin_python_type_handlers(registry) + return SemanticAwarePythonHasher( + hasher_id="test_v1", type_handler_registry=registry, strict=True + ) + + +def test_uuid_handler_returns_content_hash(): + """UUIDHandler should return a ContentHash for a UUID.""" + hasher = _make_hasher() u = _uuid.UUID("550e8400-e29b-41d4-a716-446655440000") - result = handler.handle(u, hasher=None) # type: ignore[arg-type] - assert result == u.bytes - assert isinstance(result, bytes) - assert len(result) == 16 + result = hasher.hash_object(u) + assert isinstance(result, ContentHash) -def test_uuid_handler_different_uuids_produce_different_bytes(): - """Different UUID values must produce different byte sequences.""" - from orcapod.hashing.semantic_hashing.builtin_handlers import UUIDHandler +def test_uuid_handler_same_uuid_same_hash(): + """Same UUID value produces the same ContentHash.""" + hasher = _make_hasher() + u = _uuid.UUID("550e8400-e29b-41d4-a716-446655440000") + assert hasher.hash_object(u) == hasher.hash_object(u) + - handler = UUIDHandler() +def test_uuid_handler_different_uuids_produce_different_hashes(): + """Different UUID values must produce different ContentHash objects.""" + hasher = _make_hasher() u1 = _uuid.uuid4() u2 = _uuid.uuid4() - assert handler.handle(u1, None) != handler.handle(u2, None) # type: ignore[arg-type] + assert hasher.hash_object(u1) != hasher.hash_object(u2) diff --git a/tests/test_semantic_types/test_dataclass_encoding.py b/tests/test_semantic_types/test_dataclass_encoding.py deleted file mode 100644 index b0f34ba8..00000000 --- a/tests/test_semantic_types/test_dataclass_encoding.py +++ /dev/null @@ -1,804 +0,0 @@ -# tests/test_semantic_types/test_dataclass_encoding.py -from __future__ import annotations - -import dataclasses -import os -import tempfile -import typing -from unittest.mock import MagicMock, patch - -import pyarrow as pa -import pytest - -from orcapod.semantic_types.dataclass_encoding import ( - DATACLASS_TYPE_FIELD, - DATACLASS_TYPE_PREFIX, - _DATACLASS_REGISTRY, - dataclass_to_arrow_struct_type, - dataclass_to_struct_dict, - has_dataclass_type_sentinel, - register_dataclass, - struct_dict_to_dataclass, -) -import orcapod.semantic_types.dataclass_encoding as _dc_enc -from orcapod.semantic_types.universal_converter import UniversalTypeConverter -from orcapod.types import Schema - - -@dataclasses.dataclass -class _Simple: - a: int - b: str - - -def test_constants(): - assert DATACLASS_TYPE_FIELD == "__dataclass." - assert DATACLASS_TYPE_PREFIX == "dataclass:" - - -def test_register_explicit(): - register_dataclass(_Simple) - key = f"{_Simple.__module__}.{_Simple.__qualname__}" - assert _DATACLASS_REGISTRY[key] is _Simple - - -def test_register_returns_class(): - result = register_dataclass(_Simple) - assert result is _Simple - - -def test_register_as_decorator(): - @register_dataclass - @dataclasses.dataclass - class _Decorated: - x: float - - key = f"{_Decorated.__module__}.{_Decorated.__qualname__}" - assert _DATACLASS_REGISTRY[key] is _Decorated - - -def test_register_non_dataclass_raises(): - with pytest.raises(TypeError, match="not a dataclass"): - register_dataclass(int) - - -def test_sentinel_large_string(): - t = pa.struct([pa.field("__dataclass.", pa.large_string()), pa.field("a", pa.int64())]) - assert has_dataclass_type_sentinel(t) is True - - -def test_sentinel_string_compat(): - # older Arrow versions wrote pa.string() instead of pa.large_string() - t = pa.struct([pa.field("__dataclass.", pa.string()), pa.field("a", pa.int64())]) - assert has_dataclass_type_sentinel(t) is True - - -def test_sentinel_missing_field(): - t = pa.struct([pa.field("a", pa.int64()), pa.field("b", pa.large_string())]) - assert has_dataclass_type_sentinel(t) is False - - -def test_sentinel_non_struct(): - assert has_dataclass_type_sentinel(pa.int64()) is False - - -def test_struct_type_basic_fields(): - @dataclasses.dataclass - class _Point: - x: int - y: float - - converter = UniversalTypeConverter() - result = dataclass_to_arrow_struct_type(_Point, converter) - - assert pa.types.is_struct(result) - # __dataclass. must be the first field - assert result[0].name == "__dataclass." - assert result[0].type == pa.large_string() - assert result.field("x").type == pa.int64() - assert result.field("y").type == pa.float64() - - -def test_struct_type_string_field(): - @dataclasses.dataclass - class _Named: - name: str - - converter = UniversalTypeConverter() - result = dataclass_to_arrow_struct_type(_Named, converter) - assert result.field("name").type == pa.large_string() - - -def test_struct_type_non_dataclass_raises(): - converter = UniversalTypeConverter() - with pytest.raises(TypeError, match="not a dataclass"): - dataclass_to_arrow_struct_type(int, converter) - - -def _build_field_converters(cls: type, converter: UniversalTypeConverter) -> dict: - """Helper: build per-field Python-to-Arrow converters for a dataclass.""" - hints = typing.get_type_hints(cls) - return { - f.name: converter.get_python_to_arrow_converter(hints[f.name]) - for f in dataclasses.fields(cls) - } - - -def test_struct_dict_simple(): - @dataclasses.dataclass - class _Box: - width: int - label: str - - converter = UniversalTypeConverter() - field_converters = _build_field_converters(_Box, converter) - obj = _Box(width=10, label="big") - result = dataclass_to_struct_dict(obj, field_converters) - - fqcn = f"{_Box.__module__}.{_Box.__qualname__}" - assert result[DATACLASS_TYPE_FIELD] == f"dataclass:{fqcn}" - assert result["width"] == 10 - assert result["label"] == "big" - - -def test_struct_dict_type_error_on_class(): - with pytest.raises(TypeError, match="not a dataclass instance"): - dataclass_to_struct_dict(_Simple, {}) - - -def test_struct_dict_type_error_on_non_dataclass(): - with pytest.raises(TypeError, match="not a dataclass instance"): - dataclass_to_struct_dict(42, {}) - - -@dataclasses.dataclass -class _TierOne: - value: int - - -def test_tier1_import(): - """Tier 1: class is importable via importlib.""" - fqcn = f"{_TierOne.__module__}.{_TierOne.__qualname__}" - struct_dict = { - "__dataclass.": f"dataclass:{fqcn}", - "value": 7, - } - field_converters = {"value": lambda v: v} - cache: dict = {} - - # Patch importlib so tier 1 returns _TierOne - module_path, _, class_attr = fqcn.rpartition(".") - with patch("orcapod.semantic_types.dataclass_encoding.importlib.import_module") as mock_import: - mock_mod = MagicMock() - setattr(mock_mod, class_attr, _TierOne) - mock_import.return_value = mock_mod - - result = struct_dict_to_dataclass(struct_dict, field_converters, cache) - - assert isinstance(result, _TierOne) - assert result.value == 7 - # Cache should be populated - assert cache[fqcn] is _TierOne - - -def test_tier1_cache_hit(): - """Tier 1: cache hit skips importlib entirely.""" - fqcn = "some.module.SomeClass" - cache = {fqcn: _TierOne} - struct_dict = {"__dataclass.": f"dataclass:{fqcn}", "value": 3} - field_converters = {"value": lambda v: v} - - with patch("orcapod.semantic_types.dataclass_encoding.importlib.import_module") as mock_import: - result = struct_dict_to_dataclass(struct_dict, field_converters, cache) - mock_import.assert_not_called() - - assert isinstance(result, _TierOne) - assert result.value == 3 - - -def test_tier2_registry(monkeypatch): - """Tier 2: importlib fails, class found in registry.""" - @dataclasses.dataclass - class _RegClass: - score: float - - fqcn = "fake.module.RegClass" - monkeypatch.setitem(_DATACLASS_REGISTRY, fqcn, _RegClass) - - struct_dict = {"__dataclass.": f"dataclass:{fqcn}", "score": 9.5} - field_converters = {"score": lambda v: v} - cache: dict = {} - - with patch("orcapod.semantic_types.dataclass_encoding.importlib.import_module", side_effect=ImportError("no module")): - result = struct_dict_to_dataclass(struct_dict, field_converters, cache) - - assert isinstance(result, _RegClass) - assert result.score == 9.5 - assert cache[fqcn] is _RegClass - - -def test_tier3_synthesize(): - """Tier 3: neither importable nor registered — synthesize a dataclass.""" - fqcn = "totally.unknown.Ghost" - struct_dict = {"__dataclass.": f"dataclass:{fqcn}", "name": "phantom", "age": 99} - field_converters = {"name": lambda v: v, "age": lambda v: v} - cache: dict = {} - - with patch("orcapod.semantic_types.dataclass_encoding.importlib.import_module", side_effect=ImportError("no module")): - result = struct_dict_to_dataclass(struct_dict, field_converters, cache) - - assert dataclasses.is_dataclass(result) - assert result.name == "phantom" # type: ignore[attr-defined] - assert result.age == 99 # type: ignore[attr-defined] - # Synthesized class cached under fqcn for future rows - assert fqcn in cache - - -def test_missing_type_field_tier3(): - """Struct without __type falls through to tier 3 silently.""" - struct_dict = {"value": 42} - field_converters = {"value": lambda v: v} - cache: dict = {} - - result = struct_dict_to_dataclass(struct_dict, field_converters, cache) - - assert dataclasses.is_dataclass(result) - assert result.value == 42 # type: ignore[attr-defined] - # No cache entry — no valid fqcn to cache under - assert len(cache) == 0 - - -def test_malformed_type_field_tier3(): - """Invalid __dataclass. format (fails regex) falls through to tier 3.""" - struct_dict = {"__dataclass.": "not-valid!!!", "x": 1} - field_converters = {"x": lambda v: v} - cache: dict = {} - - result = struct_dict_to_dataclass(struct_dict, field_converters, cache) - - assert dataclasses.is_dataclass(result) - assert result.x == 1 # type: ignore[attr-defined] - - -def test_utc_simple_round_trip(): - """Full encode->decode round-trip through UniversalTypeConverter.""" - @dataclasses.dataclass - class _Color: - r: int - g: int - b: int - - converter = UniversalTypeConverter() - arrow_type = converter.python_type_to_arrow_type(_Color) - assert has_dataclass_type_sentinel(arrow_type) - - obj = _Color(r=255, g=128, b=0) - encode = converter.get_python_to_arrow_converter(_Color) - encoded = encode(obj) - assert encoded["__dataclass."] == f"dataclass:{_Color.__module__}.{_Color.__qualname__}" - - decode = converter.get_arrow_to_python_converter(arrow_type) - with patch("orcapod.semantic_types.dataclass_encoding.importlib.import_module") as mock_import: - mock_mod = MagicMock() - setattr(mock_mod, "_Color", _Color) - mock_import.return_value = mock_mod - result = decode(encoded) - - assert isinstance(result, _Color) - assert result.r == 255 and result.g == 128 and result.b == 0 - - -def test_utc_nested_round_trip(): - """Nested dataclass encodes and decodes recursively.""" - @dataclasses.dataclass - class _Inner: - y: float - - @dataclasses.dataclass - class _Outer: - x: int - inner: _Inner - - converter = UniversalTypeConverter() - arrow_type = converter.python_type_to_arrow_type(_Outer) - - # Nested struct: inner field should itself be a __type-bearing struct - inner_arrow = arrow_type.field("inner").type - assert has_dataclass_type_sentinel(inner_arrow) - - obj = _Outer(x=1, inner=_Inner(y=3.14)) - encode = converter.get_python_to_arrow_converter(_Outer) - encoded = encode(obj) - - assert encoded["inner"]["__dataclass."] == f"dataclass:{_Inner.__module__}.{_Inner.__qualname__}" - assert encoded["inner"]["y"] == 3.14 - - decode = converter.get_arrow_to_python_converter(arrow_type) - - inner_fqcn = f"{_Inner.__module__}.{_Inner.__qualname__}" - outer_fqcn = f"{_Outer.__module__}.{_Outer.__qualname__}" - inner_attr = inner_fqcn.rpartition(".")[2] - outer_attr = outer_fqcn.rpartition(".")[2] - - with patch("orcapod.semantic_types.dataclass_encoding.importlib.import_module") as mock_import: - def fake_import(module_path): - mod = MagicMock() - setattr(mod, inner_attr, _Inner) - setattr(mod, outer_attr, _Outer) - return mod - mock_import.side_effect = fake_import - result = decode(encoded) - - assert isinstance(result, _Outer) - assert result.x == 1 - assert isinstance(result.inner, _Inner) - assert result.inner.y == 3.14 - - -def test_utc_clear_cache_clears_dataclass_cache(): - """clear_cache() also clears the per-instance dataclass lookup cache.""" - converter = UniversalTypeConverter() - - @dataclasses.dataclass - class _Temp: - n: int - - fqcn = f"{_Temp.__module__}.{_Temp.__qualname__}" - converter._dataclass_lookup_cache[fqcn] = _Temp - converter.clear_cache() - assert fqcn not in converter._dataclass_lookup_cache - - -def test_polymorphic_decode(): - """Two rows with different __type values each decode to their own class.""" - @dataclasses.dataclass - class _Cat: - name: str - - @dataclasses.dataclass - class _Dog: - name: str - - cat_fqcn = f"{_Cat.__module__}.{_Cat.__qualname__}" - dog_fqcn = f"{_Dog.__module__}.{_Dog.__qualname__}" - - # Both have the same Arrow schema (name: large_string) plus __dataclass. - arrow_type = pa.struct([ - pa.field("__dataclass.", pa.large_string()), - pa.field("name", pa.large_string()), - ]) - converter = UniversalTypeConverter() - decode = converter.get_arrow_to_python_converter(arrow_type) - - cat_attr = cat_fqcn.rpartition(".")[2] - dog_attr = dog_fqcn.rpartition(".")[2] - - with patch("orcapod.semantic_types.dataclass_encoding.importlib.import_module") as mock_import: - def fake_import(module_path): - mod = MagicMock() - setattr(mod, cat_attr, _Cat) - setattr(mod, dog_attr, _Dog) - return mod - mock_import.side_effect = fake_import - - row0 = decode({"__dataclass.": f"dataclass:{cat_fqcn}", "name": "Whiskers"}) - row1 = decode({"__dataclass.": f"dataclass:{dog_fqcn}", "name": "Rex"}) - - assert isinstance(row0, _Cat) and row0.name == "Whiskers" - assert isinstance(row1, _Dog) and row1.name == "Rex" - - -@pytest.mark.integration -def test_parquet_round_trip(): - """Full round-trip: python_dicts_to_arrow_table -> Parquet -> arrow_table_to_python_dicts.""" - import pyarrow.parquet as pq - - @dataclasses.dataclass - class _Record: - score: float - label: str - - converter = UniversalTypeConverter() - - python_dicts = [ - {"rec": _Record(score=0.9, label="good")}, - {"rec": _Record(score=0.1, label="bad")}, - ] - python_schema = Schema({"rec": _Record}) - table = converter.python_dicts_to_arrow_table(python_dicts, python_schema=python_schema) - - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, "test.parquet") - pq.write_table(table, path) - loaded = pq.read_table(path) - - rec_fqcn = f"{_Record.__module__}.{_Record.__qualname__}" - rec_attr = rec_fqcn.rpartition(".")[2] - - with patch("orcapod.semantic_types.dataclass_encoding.importlib.import_module") as mock_import: - mod = MagicMock() - setattr(mod, rec_attr, _Record) - mock_import.return_value = mod - results = converter.arrow_table_to_python_dicts(loaded) - - assert len(results) == 2 - assert isinstance(results[0]["rec"], _Record) - assert results[0]["rec"].score == 0.9 - assert results[0]["rec"].label == "good" - assert isinstance(results[1]["rec"], _Record) - assert results[1]["rec"].score == 0.1 - assert results[1]["rec"].label == "bad" - - -# --------------------------------------------------------------------------- -# init=False field exclusion -# --------------------------------------------------------------------------- - - -def test_struct_type_excludes_init_false_fields(): - """dataclass_to_arrow_struct_type must not include fields with init=False.""" - @dataclasses.dataclass - class _WithComputed: - value: int - cached: str = dataclasses.field(init=False, default="") - - def __post_init__(self) -> None: - self.cached = f"v={self.value}" - - converter = UniversalTypeConverter() - result = dataclass_to_arrow_struct_type(_WithComputed, converter) - - field_names = [result.field(i).name for i in range(result.num_fields)] - assert "__dataclass." in field_names - assert "value" in field_names - assert "cached" not in field_names, "init=False field must be excluded from Arrow schema" - - -def test_struct_dict_excludes_init_false_fields(): - """dataclass_to_struct_dict must not include fields with init=False.""" - @dataclasses.dataclass - class _WithComputed: - value: int - cached: str = dataclasses.field(init=False, default="") - - def __post_init__(self) -> None: - self.cached = f"v={self.value}" - - obj = _WithComputed(value=42) - result = dataclass_to_struct_dict(obj, {}) - - assert "value" in result - assert "cached" not in result, "init=False field must be excluded from encoded dict" - - -def test_utc_converter_excludes_init_false_fields(): - """UniversalTypeConverter converter closure must not include init=False fields.""" - @dataclasses.dataclass - class _WithComputed: - x: int - derived: str = dataclasses.field(init=False, default="") - - def __post_init__(self) -> None: - self.derived = str(self.x * 2) - - converter = UniversalTypeConverter() - encode = converter.get_python_to_arrow_converter(_WithComputed) - encoded = encode(_WithComputed(x=7)) - - assert "x" in encoded - assert "derived" not in encoded, "init=False field must not appear in encoded output" - - -def test_init_false_round_trip(): - """Full round-trip: init=False field is excluded from Arrow and reconstructed post-init.""" - @dataclasses.dataclass - class _Computed: - n: int - doubled: int = dataclasses.field(init=False) - - def __post_init__(self) -> None: - self.doubled = self.n * 2 - - converter = UniversalTypeConverter() - arrow_type = converter.python_type_to_arrow_type(_Computed) - - # Arrow schema must not contain 'doubled' - field_names = [arrow_type.field(i).name for i in range(arrow_type.num_fields)] - assert "doubled" not in field_names - - obj = _Computed(n=5) - encode = converter.get_python_to_arrow_converter(_Computed) - encoded = encode(obj) - assert "doubled" not in encoded - - decode = converter.get_arrow_to_python_converter(arrow_type) - fqcn = f"{_Computed.__module__}.{_Computed.__qualname__}" - attr = fqcn.rpartition(".")[2] - with patch("orcapod.semantic_types.dataclass_encoding.importlib.import_module") as m: - mod = MagicMock() - setattr(mod, attr, _Computed) - m.return_value = mod - result = decode(encoded) - - assert isinstance(result, _Computed) - assert result.n == 5 - # __post_init__ recomputes doubled - assert result.doubled == 10 - - -# --------------------------------------------------------------------------- -# Extra-field / superset-schema kwargs filtering in decoder -# --------------------------------------------------------------------------- - - -def test_decoder_extra_null_field_no_warning(caplog): - """A NULL extra field (schema evolution — column present but empty for this row) - is silently dropped without a warning.""" - @dataclasses.dataclass - class _Narrow: - name: str - - fqcn = f"{_Narrow.__module__}.{_Narrow.__qualname__}" - struct_dict = {"__dataclass.": f"dataclass:{fqcn}", "name": "Alice", "age": None} - field_converters = {"name": lambda v: v, "age": lambda v: v} - cache: dict = {} - - attr = fqcn.rpartition(".")[2] - import logging - with caplog.at_level(logging.WARNING, logger="orcapod.semantic_types.dataclass_encoding"): - with patch("orcapod.semantic_types.dataclass_encoding.importlib.import_module") as m: - mod = MagicMock() - setattr(mod, attr, _Narrow) - m.return_value = mod - result = struct_dict_to_dataclass(struct_dict, field_converters, cache) - - assert isinstance(result, _Narrow) - assert result.name == "Alice" - # No warning for a null extra field - assert not any("age" in r.message for r in caplog.records) - - -def test_decoder_extra_nonnull_field_warns(caplog): - """A non-null extra field being discarded must emit a WARNING — it signals a - schema mismatch or encoding bug, not normal schema evolution.""" - @dataclasses.dataclass - class _Narrow: - name: str - - fqcn = f"{_Narrow.__module__}.{_Narrow.__qualname__}" - # 'age' is non-null: real data being silently dropped is a bug signal - struct_dict = {"__dataclass.": f"dataclass:{fqcn}", "name": "Alice", "age": 30} - field_converters = {"name": lambda v: v, "age": lambda v: v} - cache: dict = {} - - attr = fqcn.rpartition(".")[2] - import logging - with caplog.at_level(logging.WARNING, logger="orcapod.semantic_types.dataclass_encoding"): - with patch("orcapod.semantic_types.dataclass_encoding.importlib.import_module") as m: - mod = MagicMock() - setattr(mod, attr, _Narrow) - m.return_value = mod - result = struct_dict_to_dataclass(struct_dict, field_converters, cache) - - assert isinstance(result, _Narrow) - assert result.name == "Alice" - assert not hasattr(result, "age") - # Must emit a warning mentioning the dropped field - assert any("age" in r.message and r.levelno == logging.WARNING for r in caplog.records) - - -# --------------------------------------------------------------------------- -# Tier-1 import gate (_TIER1_IMPORT_ENABLED) -# --------------------------------------------------------------------------- - - -def test_tier1_disabled_skips_to_tier2(monkeypatch): - """When _TIER1_IMPORT_ENABLED is False, tier-1 import is skipped and tier-2 is used.""" - @dataclasses.dataclass - class _GatedClass: - val: int - - fqcn = "some.module.GatedClass" - monkeypatch.setitem(_DATACLASS_REGISTRY, fqcn, _GatedClass) - monkeypatch.setattr(_dc_enc, "_TIER1_IMPORT_ENABLED", False) - - struct_dict = {"__dataclass.": f"dataclass:{fqcn}", "val": 99} - field_converters = {"val": lambda v: v} - cache: dict = {} - - with patch("orcapod.semantic_types.dataclass_encoding.importlib.import_module") as mock_import: - result = struct_dict_to_dataclass(struct_dict, field_converters, cache) - mock_import.assert_not_called() - - assert isinstance(result, _GatedClass) - assert result.val == 99 - - -def test_tier1_disabled_falls_to_tier3(monkeypatch): - """When _TIER1_IMPORT_ENABLED is False and class is unregistered, tier-3 synthesizes.""" - monkeypatch.setattr(_dc_enc, "_TIER1_IMPORT_ENABLED", False) - - fqcn = "totally.absent.UnknownClass" - struct_dict = {"__dataclass.": f"dataclass:{fqcn}", "score": 7.5} - field_converters = {"score": lambda v: v} - cache: dict = {} - - with patch("orcapod.semantic_types.dataclass_encoding.importlib.import_module") as mock_import: - result = struct_dict_to_dataclass(struct_dict, field_converters, cache) - mock_import.assert_not_called() - - assert dataclasses.is_dataclass(result) - assert result.score == 7.5 # type: ignore[attr-defined] - - -# --------------------------------------------------------------------------- -# arrow_schema_to_python_schema for dataclass structs (Item 4 fix) -# --------------------------------------------------------------------------- - - -def test_arrow_schema_to_python_schema_dataclass_returns_concrete_type(): - """arrow_schema_to_python_schema returns a concrete dataclass type for sentinel structs. - - After the fix, converting a dataclass struct Arrow type back to a Python - schema must return a proper @dataclass type rather than typing.Any, so - that python_schema_to_arrow_schema can complete the round-trip. - """ - @dataclasses.dataclass - class _Point: - x: int - y: float - - converter = UniversalTypeConverter() - arrow_type = converter.python_type_to_arrow_type(_Point) - assert has_dataclass_type_sentinel(arrow_type) - - # Build a one-field Arrow schema wrapping the struct - arrow_schema = pa.schema([pa.field("point", arrow_type, nullable=False)]) - python_schema = converter.arrow_schema_to_python_schema(arrow_schema) - - python_type = python_schema["point"] - assert dataclasses.is_dataclass(python_type), ( - f"Expected a dataclass type, got {python_type!r}" - ) - field_names = {f.name for f in dataclasses.fields(python_type)} - assert "x" in field_names - assert "y" in field_names - assert DATACLASS_TYPE_FIELD not in field_names, ( - "Sentinel field must not appear among the synthesized dataclass fields" - ) - - -def test_arrow_schema_to_python_schema_dataclass_round_trip(): - """python_schema → arrow_schema → python_schema is lossless for dataclass fields. - - After the fix, the synthesized dataclass type is itself a proper dataclass, - so python_schema_to_arrow_schema can convert it back to the original struct. - """ - @dataclasses.dataclass - class _Box: - width: int - label: str - - converter = UniversalTypeConverter() - original_arrow = converter.python_type_to_arrow_type(_Box) - - # Round-trip via python schema - schema = pa.schema([pa.field("box", original_arrow, nullable=False)]) - python_schema = converter.arrow_schema_to_python_schema(schema) - synthesized_type = python_schema["box"] - - assert dataclasses.is_dataclass(synthesized_type) - # Convert the synthesized type back to Arrow — must produce the same struct - recovered_arrow = converter.python_type_to_arrow_type(synthesized_type) - assert has_dataclass_type_sentinel(recovered_arrow) - assert recovered_arrow.field("width").type == pa.int64() - assert recovered_arrow.field("label").type == pa.large_string() - - -def test_arrow_schema_to_python_schema_dataclass_nullable_fields(): - """Nullable struct fields produce Optional[T] annotations in the synthesized dataclass. - - Regression guard for the nullability fix: when a dataclass-sentinel struct has a - nullable field, the synthesized Python dataclass must annotate it as ``Optional[T]`` - so that: - - The type correctly conveys that None is a valid value. - - Round-trips through ``python_schema_to_arrow_schema`` preserve ``nullable=True`` - (because ``Optional[T]`` triggers ``_is_optional_type``). - - Non-nullable fields must remain plain ``T`` (not Optional). - """ - converter = UniversalTypeConverter() - - # Build a raw dataclass-sentinel struct type manually with mixed nullability. - import pyarrow as _pa - struct_type = _pa.struct([ - _pa.field(DATACLASS_TYPE_FIELD, _pa.large_string()), # sentinel (excluded) - _pa.field("required_field", _pa.int64(), nullable=False), - _pa.field("optional_field", _pa.int64(), nullable=True), - ]) - assert has_dataclass_type_sentinel(struct_type) - - synthesized = converter.arrow_type_to_python_type(struct_type) - assert dataclasses.is_dataclass(synthesized) - - field_map = {f.name: f.type for f in dataclasses.fields(synthesized)} - - import types as _types - import typing as _typing - - def _is_union_with_none(t: object) -> bool: - """Return True for both T | None (types.UnionType) and Optional[T] (typing.Union).""" - return isinstance(t, _types.UnionType) or _typing.get_origin(t) is _typing.Union - - # Non-nullable field must be plain int (or equivalent), not a union-with-None. - required_type = field_map["required_field"] - assert not _is_union_with_none(required_type), ( - "required_field (nullable=False) must not be T | None" - ) - - # Nullable field must be T | None (or Optional[T]). - optional_type = field_map["optional_field"] - assert _is_union_with_none(optional_type), ( - "optional_field (nullable=True) must be T | None" - ) - non_none_args = [a for a in _typing.get_args(optional_type) if a is not type(None)] - assert len(non_none_args) == 1, "T | None must wrap exactly one non-None type" - - # Sentinel must not appear in the synthesized dataclass fields. - assert DATACLASS_TYPE_FIELD not in field_map - - -def test_two_distinct_dataclass_columns_no_collision(): - """Two dataclass columns with different schemas are synthesized as distinct types. - - Regression test for the hash-based naming fix: when an Arrow schema contains - two struct columns that both have the dataclass sentinel but different fields, - ``arrow_schema_to_python_schema`` must return two *different* Python types — - one per column — rather than returning the same cached class for both. - """ - @dataclasses.dataclass - class _Alpha: - x: int - y: float - - @dataclasses.dataclass - class _Beta: - name: str - count: int - - converter = UniversalTypeConverter() - alpha_arrow = converter.python_type_to_arrow_type(_Alpha) - beta_arrow = converter.python_type_to_arrow_type(_Beta) - - # Both columns carry the dataclass sentinel. - assert has_dataclass_type_sentinel(alpha_arrow) - assert has_dataclass_type_sentinel(beta_arrow) - - # Place both in the same Arrow schema (simulating two dataclass columns in one table). - schema = pa.schema([ - pa.field("col_a", alpha_arrow, nullable=False), - pa.field("col_b", beta_arrow, nullable=False), - ]) - python_schema = converter.arrow_schema_to_python_schema(schema) - - type_a = python_schema["col_a"] - type_b = python_schema["col_b"] - - # Both must be synthesized dataclasses … - assert dataclasses.is_dataclass(type_a), f"col_a type is not a dataclass: {type_a!r}" - assert dataclasses.is_dataclass(type_b), f"col_b type is not a dataclass: {type_b!r}" - - # … but they must be *different* types (no name collision in the lookup cache). - assert type_a is not type_b, ( - "Both dataclass columns resolved to the same synthesized class — " - "hash-based naming is required to prevent this collision." - ) - - # Verify that the field sets are correct for each synthesized type. - fields_a = {f.name for f in dataclasses.fields(type_a)} - fields_b = {f.name for f in dataclasses.fields(type_b)} - assert fields_a == {"x", "y"} - assert fields_b == {"name", "count"} - # Sentinel must not leak into either synthesized type. - assert DATACLASS_TYPE_FIELD not in fields_a - assert DATACLASS_TYPE_FIELD not in fields_b diff --git a/tests/test_semantic_types/test_path_struct_converter.py b/tests/test_semantic_types/test_path_struct_converter.py deleted file mode 100644 index 740b0c16..00000000 --- a/tests/test_semantic_types/test_path_struct_converter.py +++ /dev/null @@ -1,132 +0,0 @@ -from pathlib import Path -from typing import cast - -import pytest - -from orcapod.hashing.file_hashers import BasicFileHasher -from orcapod.semantic_types.semantic_struct_converters import PythonPathStructConverter - - -@pytest.fixture -def file_hasher(): - return BasicFileHasher(algorithm="sha256") - - -@pytest.fixture -def converter(file_hasher): - return PythonPathStructConverter(file_hasher=file_hasher) - - -def test_path_to_struct_and_back(converter): - path_obj = Path("/tmp/test.txt") - struct_dict = converter.python_to_struct_dict(path_obj) - assert struct_dict["path"] == str(path_obj) - restored = converter.struct_dict_to_python(struct_dict) - assert restored == path_obj - - -def test_path_to_struct_invalid_type(converter): - with pytest.raises(TypeError): - converter.python_to_struct_dict("not_a_path") # type: ignore - - -def test_struct_to_python_missing_field(converter): - with pytest.raises(ValueError): - converter.struct_dict_to_python({}) - - -def test_can_handle_python_type(converter): - assert converter.can_handle_python_type(Path) - assert not converter.can_handle_python_type(str) - - -def test_can_handle_struct_type(converter): - struct_type = converter.arrow_struct_type - assert converter.can_handle_struct_type(struct_type) - - # Should fail for wrong fields - class FakeField: - def __init__(self, name, type): - self.name = name - self.type = type - - class FakeStructType(list): - @property - def names(self): - return [f.name for f in self] - - pass - - import pyarrow as pa - - fake_struct = cast( - pa.StructType, FakeStructType([FakeField("wrong", struct_type[0].type)]) - ) - assert not converter.can_handle_struct_type(fake_struct) - - -def test_is_semantic_struct(converter): - assert converter.is_semantic_struct({"path": "/tmp/test.txt"}) - assert not converter.is_semantic_struct({"not_path": "value"}) - assert not converter.is_semantic_struct({"path": 123}) - - -def test_hash_struct_dict_file_not_found(converter, tmp_path): - struct_dict = {"path": str(tmp_path / "does_not_exist.txt")} - with pytest.raises(FileNotFoundError): - converter.hash_struct_dict(struct_dict) - - -def test_hash_struct_dict_is_directory(converter, tmp_path): - struct_dict = {"path": str(tmp_path)} - with pytest.raises(IsADirectoryError): - converter.hash_struct_dict(struct_dict) - - -def test_hash_struct_dict_content_based(converter, tmp_path): - """Two distinct files with identical content produce the same hash.""" - file1 = tmp_path / "file1.txt" - file2 = tmp_path / "file2.txt" - content = "identical content" - file1.write_text(content) - file2.write_text(content) - hash1 = converter.hash_struct_dict({"path": str(file1)}) - hash2 = converter.hash_struct_dict({"path": str(file2)}) - assert hash1 == hash2 - - -def test_hash_path_objects_content_based(converter, tmp_path): - """Round-trip through python_to_struct_dict then hash_struct_dict.""" - file1 = tmp_path / "fileA.txt" - file2 = tmp_path / "fileB.txt" - content = "same file content" - file1.write_text(content) - file2.write_text(content) - struct_dict1 = converter.python_to_struct_dict(Path(file1)) - struct_dict2 = converter.python_to_struct_dict(Path(file2)) - hash1 = converter.hash_struct_dict(struct_dict1) - hash2 = converter.hash_struct_dict(struct_dict2) - assert hash1 == hash2 - - -def test_hash_struct_dict_with_prefix(converter, tmp_path): - """Hash always starts with 'path:sha256:'.""" - file = tmp_path / "file.txt" - file.write_text("hello") - hash_str = converter.hash_struct_dict({"path": str(file)}) - assert hash_str.startswith("path:sha256:") - - -def test_hash_struct_dict_different_content(converter, tmp_path): - """Same path with modified content yields a different hash.""" - file = tmp_path / "mutable.txt" - file.write_text("version 1") - hash1 = converter.hash_struct_dict({"path": str(file)}) - file.write_text("version 2") - hash2 = converter.hash_struct_dict({"path": str(file)}) - assert hash1 != hash2 - - -def test_hash_struct_dict_missing_path_field(converter): - with pytest.raises(ValueError, match="Missing 'path' field"): - converter.hash_struct_dict({}) diff --git a/tests/test_semantic_types/test_pydata_utils.py b/tests/test_semantic_types/test_pydata_utils.py deleted file mode 100644 index d9716866..00000000 --- a/tests/test_semantic_types/test_pydata_utils.py +++ /dev/null @@ -1,136 +0,0 @@ -from pathlib import Path, PosixPath -from typing import Any - -import pytest - -from orcapod.semantic_types import pydata_utils - - -def test_pylist_to_pydict_typical(): - data = [{"a": 1, "b": 2}, {"a": 3, "c": 4}] - result = pydata_utils.pylist_to_pydict(data) - assert result == {"a": [1, 3], "b": [2, None], "c": [None, 4]} - - -def test_pylist_to_pydict_missing_keys(): - data = [{"a": 1}, {"b": 2}, {"a": 3, "b": 4}] - result = pydata_utils.pylist_to_pydict(data) - assert result == {"a": [1, None, 3], "b": [None, 2, 4]} - - -def test_pylist_to_pydict_empty(): - assert pydata_utils.pylist_to_pydict([]) == {} - - -def test_pylist_to_pydict_empty_dicts(): - data = [{}, {}, {}] - assert pydata_utils.pylist_to_pydict(data) == {} - - -def test_pydict_to_pylist_typical(): - data = {"a": [1, 3], "b": [2, None], "c": [None, 4]} - result = pydata_utils.pydict_to_pylist(data) - assert result == [{"a": 1, "b": 2, "c": None}, {"a": 3, "b": None, "c": 4}] - - -def test_pydict_to_pylist_uneven_lengths(): - data = {"a": [1, 2], "b": [3]} - with pytest.raises(ValueError): - pydata_utils.pydict_to_pylist(data) - - -def test_pydict_to_pylist_empty(): - assert pydata_utils.pydict_to_pylist({}) == [] - - -def test_pydict_to_pylist_empty_lists(): - data = {"a": [], "b": []} - assert pydata_utils.pydict_to_pylist(data) == [] - - -def test_infer_python_schema_from_pylist_data_typical(): - data = [{"a": 1, "b": 2.0}, {"a": 3, "b": None}] - schema = pydata_utils.infer_python_schema_from_pylist_data(data) - assert schema["a"] in (int, int | None) - assert schema["b"] in (float | None, float) - - -def test_infer_python_schema_from_pylist_data_complex(): - data = [ - {"path": Path("/tmp/file1"), "size": 123}, - {"path": Path("/tmp/file2"), "size": None}, - ] - schema = pydata_utils.infer_python_schema_from_pylist_data(data) - assert schema["path"] in (Path, PosixPath) - assert schema["size"] == int | None - - -def test_infer_python_schema_from_pylist_data_empty(): - assert pydata_utils.infer_python_schema_from_pylist_data([]) == {} - - -def test_infer_python_schema_from_pylist_data_mixed_types(): - data = [{"a": 1}, {"a": "x"}, {"a": 2.5}] - schema = pydata_utils.infer_python_schema_from_pylist_data(data) - # Should be Union[int, float, str] or Any - assert "a" in schema - - -def test_infer_python_schema_from_pydict_data_typical(): - data = {"a": [1, 2], "b": [None, 3.5]} - schema = pydata_utils.infer_python_schema_from_pydict_data(data) - assert schema["a"] in (int, int | None) - assert schema["b"] in (float | None, float) - - -def test_infer_python_schema_from_pydict_data_empty(): - assert pydata_utils.infer_python_schema_from_pydict_data({}) == {} - - -def test_infer_python_schema_from_pydict_data_empty_lists(): - data = {"a": [], "b": []} - schema = pydata_utils.infer_python_schema_from_pydict_data(data) - assert schema["a"] == str | None - assert schema["b"] == str | None - - -def test_infer_python_schema_from_pydict_data_mixed_types(): - data = {"a": [1, "x", 2.5]} - schema = pydata_utils.infer_python_schema_from_pydict_data(data) - assert "a" in schema - - -def test_round_trip_pylist_pydict(): - data = [{"a": 1, "b": 2}, {"a": 3, "c": 4}] - pydict = pydata_utils.pylist_to_pydict(data) - pylist = pydata_utils.pydict_to_pylist(pydict) - # Should be equivalent to original data (order of keys may differ) - for orig, roundtrip in zip(data, pylist): - # Compare dicts for value equality, ignoring key order and missing keys - for k in orig: - assert orig[k] == roundtrip[k] - - -def test_round_trip_pydict_pylist(): - data = {"a": [1, 3], "b": [2, None], "c": [None, 4]} - pylist = pydata_utils.pydict_to_pylist(data) - pydict = pydata_utils.pylist_to_pydict(pylist) - for k in data: - assert pydict[k] == data[k] - - -# --------------------------------------------------------------------------- -# ENG-389: empty container inference produces list[Any] / dict[Any, Any] -# --------------------------------------------------------------------------- - - -def test_infer_empty_list_schema(): - """A field whose only value is [] infers as list[Any].""" - schema = pydata_utils.infer_python_schema_from_pylist_data([{"items": []}]) - assert schema["items"] == list[Any] - - -def test_infer_empty_dict_schema(): - """A field whose only value is {} infers as dict[Any, Any].""" - schema = pydata_utils.infer_python_schema_from_pylist_data([{"meta": {}}]) - assert schema["meta"] == dict[Any, Any] diff --git a/tests/test_semantic_types/test_schema_arrow_equality.py b/tests/test_semantic_types/test_schema_arrow_equality.py deleted file mode 100644 index d004e188..00000000 --- a/tests/test_semantic_types/test_schema_arrow_equality.py +++ /dev/null @@ -1,323 +0,0 @@ -""" -Tests verifying Schema ↔ Arrow logical equality (PLT-923). - -Coverage --------- -- Python-equal schemas produce logically equal Arrow schemas -- Python-unequal schemas produce logically unequal Arrow schemas -- Field insertion order does not affect logical equality -- Nullability correspondence: T | None → nullable=True, plain T → nullable=False -- Round-trip: python_schema_to_arrow_schema ∘ arrow_schema_to_python_schema is lossless -- Nested/complex types maintain the correspondence -- Schema.as_required() strips optional_fields for Arrow-level comparison - -"Logical equality" is determined by StarfixArrowHasher.hash_schema digest equality: -column-order-independent, Utf8/LargeUtf8 and Binary/LargeBinary normalised, -nullability-sensitive. -""" - -from __future__ import annotations - -from pathlib import Path - -import pyarrow as pa - -from orcapod.contexts import get_default_context -from orcapod.hashing.arrow_hashers import StarfixArrowHasher -from orcapod.semantic_types import SemanticTypeRegistry -from orcapod.types import Schema - -# --------------------------------------------------------------------------- -# Shared infrastructure -# --------------------------------------------------------------------------- - -# SemanticTypeRegistry is empty: hash_schema operates on Arrow types only and -# never consults the semantic registry (unlike hash_table). -_hasher = StarfixArrowHasher(SemanticTypeRegistry(), hasher_id="test") - - -def _to_arrow(schema: Schema) -> pa.Schema: - """Convert a Python Schema to an Arrow schema via the default context.""" - return get_default_context().type_converter.python_schema_to_arrow_schema(schema) - - -def _arrow_logical_eq(s1: pa.Schema, s2: pa.Schema) -> bool: - """Return True if two Arrow schemas are logically equal under the starfix hash.""" - return _hasher.hash_schema(s1).digest == _hasher.hash_schema(s2).digest - - -# --------------------------------------------------------------------------- -# Positive: equal Python schemas → logically equal Arrow schemas -# --------------------------------------------------------------------------- - - -class TestEqualSchemasHaveLogicallyEqualArrowSchemas: - def test_single_int_field(self): - s1 = Schema(a=int) - s2 = Schema(a=int) - assert s1 == s2 - assert _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - def test_single_float_field(self): - s1 = Schema(a=float) - s2 = Schema(a=float) - assert _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - def test_single_str_field(self): - s1 = Schema(a=str) - s2 = Schema(a=str) - assert _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - def test_single_bool_field(self): - s1 = Schema(a=bool) - s2 = Schema(a=bool) - assert _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - def test_single_bytes_field(self): - s1 = Schema(a=bytes) - s2 = Schema(a=bytes) - assert _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - def test_multiple_primitive_fields(self): - s1 = Schema({"a": int, "b": float, "c": str}) - s2 = Schema({"a": int, "b": float, "c": str}) - assert s1 == s2 - assert _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - def test_kwargs_vs_mapping_construction(self): - """Schema(a=int, b=str) must equal Schema({"a": int, "b": str}).""" - s_kwargs = Schema(a=int, b=str) - s_mapping = Schema({"a": int, "b": str}) - assert s_kwargs == s_mapping - assert _arrow_logical_eq(_to_arrow(s_kwargs), _to_arrow(s_mapping)) - - def test_empty_schema(self): - s1 = Schema.empty() - s2 = Schema({}) - assert s1 == s2 - assert _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - def test_schema_equals_plain_dict(self): - """Schema.__eq__ accepts plain Mapping; dict → Arrow conversion must match.""" - s = Schema({"x": int}) - d = {"x": int} - # Schema.__eq__ raises NotImplementedError for non-Mapping non-Schema; plain - # dict is a Mapping so this should work. - assert s == d - assert _arrow_logical_eq( - _to_arrow(s), - get_default_context().type_converter.python_schema_to_arrow_schema(d), - ) - - -# --------------------------------------------------------------------------- -# Negative: unequal Python schemas → logically unequal Arrow schemas -# --------------------------------------------------------------------------- - - -class TestUnequalSchemasHaveLogicallyUnequalArrowSchemas: - def test_different_field_names(self): - s1 = Schema(a=int) - s2 = Schema(b=int) - assert s1 != s2 - assert not _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - def test_different_field_types(self): - s1 = Schema(a=int) - s2 = Schema(a=float) - assert s1 != s2 - assert not _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - def test_subset_schema_differs(self): - s1 = Schema({"a": int, "b": str}) - s2 = Schema({"a": int}) - assert s1 != s2 - assert not _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - -# --------------------------------------------------------------------------- -# Field ordering -# --------------------------------------------------------------------------- - - -class TestFieldOrderingDoesNotAffectLogicalEquality: - def test_two_fields_reversed_insertion_order(self): - """Both Python equality and Arrow logical equality are order-insensitive.""" - s1 = Schema({"a": int, "b": str}) - s2 = Schema({"b": str, "a": int}) - assert s1 == s2 - assert _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - def test_three_fields_permuted_order(self): - s1 = Schema({"x": int, "y": float, "z": str}) - s2 = Schema({"z": str, "x": int, "y": float}) - assert s1 == s2 - assert _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - -# --------------------------------------------------------------------------- -# Nullability correspondence -# --------------------------------------------------------------------------- - - -class TestNullabilityCorrespondence: - def test_plain_int_is_non_nullable(self): - arrow = _to_arrow(Schema(a=int)) - assert arrow.field("a").nullable is False - - def test_optional_int_is_nullable(self): - arrow = _to_arrow(Schema({"a": int | None})) - assert arrow.field("a").nullable is True - - def test_plain_primitives_all_non_nullable(self): - arrow = _to_arrow(Schema({"a": str, "b": float, "c": bool, "d": bytes})) - for name in ("a", "b", "c", "d"): - assert arrow.field(name).nullable is False, ( - f"Expected {name} to be non-nullable" - ) - - def test_optional_primitives_all_nullable(self): - arrow = _to_arrow(Schema({"a": str | None, "b": float | None})) - assert arrow.field("a").nullable is True - assert arrow.field("b").nullable is True - - def test_int_and_optional_int_are_python_unequal(self): - assert Schema(a=int) != Schema({"a": int | None}) - - def test_int_and_optional_int_are_arrow_logically_unequal(self): - s_plain = Schema(a=int) - s_optional = Schema({"a": int | None}) - assert not _arrow_logical_eq(_to_arrow(s_plain), _to_arrow(s_optional)) - - -# --------------------------------------------------------------------------- -# Round-trip: Python → Arrow → Python -# --------------------------------------------------------------------------- - - -class TestRoundTrip: - def _round_trip(self, schema: Schema) -> Schema: - converter = get_default_context().type_converter - return converter.arrow_schema_to_python_schema( - converter.python_schema_to_arrow_schema(schema) - ) - - def test_int_stays_int(self): - result = self._round_trip(Schema(a=int)) - assert result["a"] == int - - def test_optional_int_stays_optional_int(self): - result = self._round_trip(Schema({"a": int | None})) - assert result["a"] == int | None - - def test_plain_str_stays_str(self): - result = self._round_trip(Schema(a=str)) - assert result["a"] == str - - def test_optional_str_stays_optional_str(self): - result = self._round_trip(Schema({"a": str | None})) - assert result["a"] == str | None - - def test_plain_float_stays_float(self): - result = self._round_trip(Schema(a=float)) - assert result["a"] == float - - def test_plain_bool_stays_bool(self): - result = self._round_trip(Schema(a=bool)) - assert result["a"] == bool - - def test_plain_bytes_stays_bytes(self): - result = self._round_trip(Schema(a=bytes)) - assert result["a"] == bytes - - def test_optional_float_stays_optional_float(self): - result = self._round_trip(Schema({"a": float | None})) - assert result["a"] == float | None - - def test_mixed_nullable_and_non_nullable(self): - original = Schema({"req": int, "opt": str | None, "also_req": float}) - result = self._round_trip(original) - assert result["req"] == int - assert result["opt"] == str | None - assert result["also_req"] == float - - -# --------------------------------------------------------------------------- -# Nested and complex types -# --------------------------------------------------------------------------- - - -class TestNestedAndComplexTypes: - def test_list_int_is_non_nullable(self): - arrow = _to_arrow(Schema({"a": list[int]})) - assert arrow.field("a").nullable is False - - def test_list_str_is_non_nullable(self): - arrow = _to_arrow(Schema({"a": list[str]})) - assert arrow.field("a").nullable is False - - def test_optional_list_int_is_nullable(self): - arrow = _to_arrow(Schema({"a": list[int] | None})) - assert arrow.field("a").nullable is True - - def test_nested_list_is_non_nullable(self): - arrow = _to_arrow(Schema({"a": list[list[int]]})) - assert arrow.field("a").nullable is False - - def test_path_is_non_nullable(self): - """Path → Arrow struct {path: large_string}, nullable=False.""" - arrow = _to_arrow(Schema({"p": Path})) - assert arrow.field("p").nullable is False - assert pa.types.is_struct(arrow.field("p").type) - - def test_equal_list_schemas_are_logically_equal(self): - s1 = Schema({"items": list[int]}) - s2 = Schema({"items": list[int]}) - assert _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - def test_list_int_and_list_str_are_logically_unequal(self): - s1 = Schema({"items": list[int]}) - s2 = Schema({"items": list[str]}) - assert not _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - -# --------------------------------------------------------------------------- -# Schema.as_required() -# --------------------------------------------------------------------------- - - -class TestAsRequired: - def test_as_required_equals_schema_without_optional_fields(self): - """Schema with optional_fields equals a Schema without after as_required().""" - s_with_optional = Schema({"a": int, "b": str}, optional_fields=["b"]) - s_without = Schema({"a": int, "b": str}) - assert s_with_optional.as_required() == s_without - - def test_as_required_on_schema_without_optional_is_noop(self): - """as_required() on a fully required schema is idempotent.""" - s = Schema({"a": int, "b": str}) - assert s.as_required() == s - - def test_as_required_idempotent(self): - """Calling as_required() twice gives the same result as once.""" - s = Schema({"a": int}, optional_fields=["a"]) - assert s.as_required().as_required() == s.as_required() - - def test_schemas_differing_only_in_optional_fields_are_python_unequal(self): - """Two schemas with the same fields but different optional_fields are unequal.""" - s1 = Schema({"a": int, "b": str}, optional_fields=["b"]) - s2 = Schema({"a": int, "b": str}) - assert s1 != s2 - - def test_schemas_differing_only_in_optional_fields_have_equal_arrow_schemas(self): - """optional_fields has no Arrow representation — Arrow schemas must be equal.""" - s1 = Schema({"a": int, "b": str}, optional_fields=["b"]) - s2 = Schema({"a": int, "b": str}) - assert _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) - - def test_as_required_implies_arrow_logical_equality(self): - """If s1.as_required() == s2.as_required(), their Arrow schemas are logically equal.""" - s1 = Schema({"x": int, "y": float}, optional_fields=["x"]) - s2 = Schema({"x": int, "y": float}) - assert s1.as_required() == s2.as_required() - assert _arrow_logical_eq(_to_arrow(s1), _to_arrow(s2)) diff --git a/tests/test_semantic_types/test_semantic_registry.py b/tests/test_semantic_types/test_semantic_registry.py deleted file mode 100644 index 82df93e0..00000000 --- a/tests/test_semantic_types/test_semantic_registry.py +++ /dev/null @@ -1,239 +0,0 @@ -import uuid -from unittest.mock import Mock - -import pyarrow as pa -import pytest - -from orcapod.semantic_types import semantic_registry - - -def test_registry_initialization(): - registry = semantic_registry.SemanticTypeRegistry() - assert registry.list_semantic_types() == [] - assert registry.list_python_types() == [] - assert registry.list_struct_signatures() == [] - - -def test_register_and_retrieve_converter(): - registry = semantic_registry.SemanticTypeRegistry() - python_type = Mock(name="PythonType") - struct_type = Mock(name="StructType") - converter = Mock() - converter.python_type = python_type - converter.arrow_struct_type = struct_type - registry.register_converter("mock_type", converter) - # Retrieve by semantic type name - assert registry.get_converter_for_semantic_type("mock_type") is converter - # Retrieve by python type - assert registry.get_converter_for_python_type(python_type) is converter - # Retrieve by struct signature - assert registry.get_converter_for_struct_signature(struct_type) is converter - - -def test_register_duplicate_semantic_type_raises(): - registry = semantic_registry.SemanticTypeRegistry() - python_type = Mock(name="PythonType") - struct_type = Mock(name="StructType") - converter1 = Mock() - converter1.python_type = python_type - converter1.arrow_struct_type = struct_type - registry.register_converter("mock_type", converter1) - converter2 = Mock() - converter2.python_type = python_type - converter2.arrow_struct_type = struct_type - with pytest.raises(ValueError): - registry.register_converter("mock_type", converter2) - - -def test_register_conflicting_python_type_raises(): - registry = semantic_registry.SemanticTypeRegistry() - python_type = Mock(name="PythonType") - struct_type1 = Mock(name="StructType1") - struct_type2 = Mock(name="StructType2") - converter1 = Mock() - converter1.python_type = python_type - converter1.arrow_struct_type = struct_type1 - registry.register_converter("mock_type1", converter1) - converter2 = Mock() - converter2.python_type = python_type - converter2.arrow_struct_type = struct_type2 - with pytest.raises(ValueError): - registry.register_converter("mock_type2", converter2) - - -def test_register_conflicting_struct_signature_raises(): - registry = semantic_registry.SemanticTypeRegistry() - python_type1 = Mock(name="PythonType1") - python_type2 = Mock(name="PythonType2") - struct_type = Mock(name="StructType") - converter1 = Mock() - converter1.python_type = python_type1 - converter1.arrow_struct_type = struct_type - registry.register_converter("mock_type1", converter1) - converter2 = Mock() - converter2.python_type = python_type2 - converter2.arrow_struct_type = struct_type - with pytest.raises(ValueError): - registry.register_converter("mock_type2", converter2) - - -def test_get_nonexistent_returns_none(): - registry = semantic_registry.SemanticTypeRegistry() - python_type = Mock(name="PythonType") - struct_type = Mock(name="StructType") - assert registry.get_converter_for_semantic_type("not_present") is None - assert registry.get_converter_for_python_type(python_type) is None - assert registry.get_converter_for_struct_signature(struct_type) is None - - -def test_list_registered_types(): - registry = semantic_registry.SemanticTypeRegistry() - python_type1 = Mock(name="PythonType1") - struct_type1 = Mock(name="StructType1") - converter1 = Mock() - converter1.python_type = python_type1 - converter1.arrow_struct_type = struct_type1 - registry.register_converter("mock_type1", converter1) - - python_type2 = Mock(name="PythonType2") - struct_type2 = Mock(name="StructType2") - converter2 = Mock() - converter2.python_type = python_type2 - converter2.arrow_struct_type = struct_type2 - registry.register_converter("mock_type2", converter2) - - assert set(registry.list_semantic_types()) == {"mock_type1", "mock_type2"} - assert set(registry.list_python_types()) == {python_type1, python_type2} - assert set(registry.list_struct_signatures()) == {struct_type1, struct_type2} - - -def test_has_methods(): - registry = semantic_registry.SemanticTypeRegistry() - python_type = Mock(name="PythonType") - struct_type = Mock(name="StructType") - converter = Mock() - converter.python_type = python_type - converter.arrow_struct_type = struct_type - registry.register_converter("mock_type", converter) - assert registry.has_semantic_type("mock_type") - assert registry.has_python_type(python_type) - assert registry.has_semantic_struct_signature(struct_type) - - -def test_integration_with_converter(): - registry = semantic_registry.SemanticTypeRegistry() - python_type = Mock(name="PythonType") - struct_type = Mock(name="StructType") - converter = Mock() - converter.python_type = python_type - converter.arrow_struct_type = struct_type - registry.register_converter("mock_type", converter) - retrieved = registry.get_converter_for_semantic_type("mock_type") - assert retrieved is converter - - -def test_uuid_type_registered_in_default_registry(): - """uuid.UUID should be registered and map to pa.struct([pa.field('uuid', pa.binary(16))]).""" - from orcapod.hashing.versioned_hashers import get_versioned_semantic_arrow_hasher - - hasher = get_versioned_semantic_arrow_hasher() - registry = hasher.semantic_registry - converter = registry.get_converter_for_python_type(uuid.UUID) - assert converter is not None - assert converter.arrow_struct_type == pa.struct([pa.field("uuid", pa.binary(16))]) - - -def test_uuid_struct_resolves_to_converter(): - """pa.struct([pa.field('uuid', pa.binary(16))]) should resolve back to a converter for uuid.UUID.""" - from orcapod.hashing.versioned_hashers import get_versioned_semantic_arrow_hasher - - hasher = get_versioned_semantic_arrow_hasher() - registry = hasher.semantic_registry - converter = registry.get_converter_for_struct_signature( - pa.struct([pa.field("uuid", pa.binary(16))]) - ) - assert converter is not None - assert converter.python_type is uuid.UUID - - -def test_uuid_semantic_type_name_registered(): - """Converter registered under the name 'uuid'.""" - from orcapod.hashing.versioned_hashers import get_versioned_semantic_arrow_hasher - - hasher = get_versioned_semantic_arrow_hasher() - registry = hasher.semantic_registry - converter = registry.get_converter_for_semantic_type("uuid") - assert converter is not None - assert converter.python_type is uuid.UUID - - -# Comprehensive unregister tests for future implementation -# Uncomment when unregister methods are implemented -# -# def test_unregister_by_semantic_type_name(): -# registry = semantic_registry.SemanticTypeRegistry() -# python_type = Mock(name="PythonType") -# struct_type = Mock(name="StructType") -# converter = Mock() -# converter.python_type = python_type -# converter.arrow_struct_type = struct_type -# registry.register_converter("mock_type", converter) -# result = registry.unregister_by_semantic_type_name("mock_type") -# assert result == {"mock_type": converter} -# assert not registry.has_semantic_type("mock_type") -# assert not registry.has_python_type(python_type) -# assert not registry.has_semantic_struct_signature(struct_type) -# assert registry.get_converter_for_semantic_type("mock_type") is None -# assert registry.get_converter_for_python_type(python_type) is None -# assert registry.get_converter_for_struct_signature(struct_type) is None -# -# def test_unregister_by_converter(): -# registry = semantic_registry.SemanticTypeRegistry() -# python_type = Mock(name="PythonType") -# struct_type = Mock(name="StructType") -# converter = Mock() -# converter.python_type = python_type -# converter.arrow_struct_type = struct_type -# registry.register_converter("mock_type", converter) -# result = registry.unregister_by_converter(converter) -# assert result == {"mock_type": converter} -# assert not registry.has_semantic_type("mock_type") -# assert not registry.has_python_type(python_type) -# assert not registry.has_semantic_struct_signature(struct_type) -# assert registry.get_converter_for_semantic_type("mock_type") is None -# assert registry.get_converter_for_python_type(python_type) is None -# assert registry.get_converter_for_struct_signature(struct_type) is None -# -# def test_unregister_by_python_type(): -# registry = semantic_registry.SemanticTypeRegistry() -# python_type = Mock(name="PythonType") -# struct_type = Mock(name="StructType") -# converter = Mock() -# converter.python_type = python_type -# converter.arrow_struct_type = struct_type -# registry.register_converter("mock_type", converter) -# result = registry.unregister_by_python_type(python_type) -# assert result == {"mock_type": converter} -# assert not registry.has_semantic_type("mock_type") -# assert not registry.has_python_type(python_type) -# assert not registry.has_semantic_struct_signature(struct_type) -# assert registry.get_converter_for_semantic_type("mock_type") is None -# assert registry.get_converter_for_python_type(python_type) is None -# assert registry.get_converter_for_struct_signature(struct_type) is None -# -# def test_unregister_by_struct_signature(): -# registry = semantic_registry.SemanticTypeRegistry() -# python_type = Mock(name="PythonType") -# struct_type = Mock(name="StructType") -# converter = Mock() -# converter.python_type = python_type -# converter.arrow_struct_type = struct_type -# registry.register_converter("mock_type", converter) -# result = registry.unregister_by_struct_signature(struct_type) -# assert result == {"mock_type": converter} -# assert not registry.has_semantic_type("mock_type") -# assert not registry.has_python_type(python_type) -# assert not registry.has_semantic_struct_signature(struct_type) -# assert registry.get_converter_for_semantic_type("mock_type") is None -# assert registry.get_converter_for_python_type(python_type) is None -# assert registry.get_converter_for_struct_signature(struct_type) is None diff --git a/tests/test_semantic_types/test_semantic_struct_converters.py b/tests/test_semantic_types/test_semantic_struct_converters.py deleted file mode 100644 index 168f1a45..00000000 --- a/tests/test_semantic_types/test_semantic_struct_converters.py +++ /dev/null @@ -1,107 +0,0 @@ -from orcapod.semantic_types.semantic_struct_converters import ( - SemanticStructConverterBase, -) - - -class DummyConverter(SemanticStructConverterBase): - def __init__(self): - super().__init__("dummy") - self._python_type = dict - self._arrow_struct_type = "dummy_struct" - - @property - def python_type(self): - return self._python_type - - @property - def arrow_struct_type(self): - return self._arrow_struct_type - - def python_to_struct_dict(self, value): - return value - - def struct_dict_to_python(self, struct_dict): - return struct_dict - - def can_handle_python_type(self, python_type): - return python_type is dict - - def can_handle_struct_type(self, struct_type): - return struct_type == "dummy_struct" - - def is_semantic_struct(self, struct_dict): - return isinstance(struct_dict, dict) - - def hash_struct_dict(self, struct_dict): - return "dummyhash" - - -# --- SemanticStructConverterBase tests --- -def test_semantic_struct_converter_base_properties(): - converter = DummyConverter() - assert converter.semantic_type_name == "dummy" - assert converter.hasher_id == "dummy_content_sha256" - - - -def test_compute_content_hash(): - converter = DummyConverter() - data = b"abc" - result = converter._compute_content_hash(data) - import hashlib - - assert result.digest == hashlib.sha256(data).digest() - - -# --- PythonPathStructConverter tests --- - - -def test_extensibility_with_new_converter(): - class NewConverter(SemanticStructConverterBase): - def __init__(self): - super().__init__("newtype") - self._python_type = list - self._arrow_struct_type = "new_struct" - - @property - def python_type(self): - return self._python_type - - @property - def arrow_struct_type(self): - return self._arrow_struct_type - - def python_to_struct_dict(self, value): - return {"data": value} - - def struct_dict_to_python(self, struct_dict): - return struct_dict["data"] - - def can_handle_python_type(self, python_type): - return python_type is list - - def can_handle_struct_type(self, struct_type): - return struct_type == "new_struct" - - def is_semantic_struct(self, struct_dict): - return "data" in struct_dict - - def hash_struct_dict(self, struct_dict): - return "newhash" - - converter = NewConverter() - assert converter.semantic_type_name == "newtype" - assert converter.python_to_struct_dict([1, 2, 3]) == {"data": [1, 2, 3]} - assert converter.struct_dict_to_python({"data": [1, 2, 3]}) == [1, 2, 3] - assert converter.can_handle_python_type(list) - assert converter.can_handle_struct_type("new_struct") - assert converter.is_semantic_struct({"data": [1, 2, 3]}) - assert converter.hash_struct_dict({"data": [1, 2, 3]}) == "newhash" - - -# --- Edge cases --- -def test_dummy_converter_edge_cases(): - converter = DummyConverter() - assert converter.is_semantic_struct({}) - assert not converter.is_semantic_struct(None) - assert converter.hash_struct_dict({}) == "dummyhash" diff --git a/tests/test_semantic_types/test_universal_converter.py b/tests/test_semantic_types/test_universal_converter.py deleted file mode 100644 index 94f0edc8..00000000 --- a/tests/test_semantic_types/test_universal_converter.py +++ /dev/null @@ -1,630 +0,0 @@ -from datetime import datetime, timezone -from pathlib import Path -from typing import Any, cast - -import numpy as np -import pyarrow as pa -import pytest - -from orcapod.contexts import get_default_context -from orcapod.semantic_types import universal_converter -from orcapod.semantic_types.universal_converter import UniversalTypeConverter - - -def test_python_type_to_arrow_type_basic(): - assert universal_converter.python_type_to_arrow_type(int) == pa.int64() - assert universal_converter.python_type_to_arrow_type(float) == pa.float64() - assert universal_converter.python_type_to_arrow_type(str) == pa.large_string() - assert universal_converter.python_type_to_arrow_type(bool) == pa.bool_() - assert universal_converter.python_type_to_arrow_type(bytes) == pa.large_binary() - - -def test_python_type_to_arrow_type_datetime(): - assert universal_converter.python_type_to_arrow_type(datetime) == pa.timestamp( - "us", tz="UTC" - ) - - -def test_arrow_type_to_python_type_timestamp_with_tz(): - assert ( - universal_converter.arrow_type_to_python_type(pa.timestamp("us", tz="UTC")) - is datetime - ) - - -def test_arrow_type_to_python_type_timestamp_no_tz(): - assert universal_converter.arrow_type_to_python_type(pa.timestamp("us")) is datetime - - -def test_datetime_converter_rejects_naive(): - to_arrow, _ = universal_converter.get_conversion_functions(datetime) - naive = datetime(2024, 1, 15, 12, 30, 45, 123456) # no tzinfo - with pytest.raises(ValueError, match="Naive datetime"): - to_arrow(naive) - - -def test_datetime_converter_rejects_stub_tzinfo(): - """Rejects datetimes whose tzinfo.utcoffset() returns None (effectively naive).""" - import datetime as dt_mod - - class StubTzInfo(dt_mod.tzinfo): - def utcoffset(self, d): - return None # technically set but semantically naive - - def tzname(self, d): - return "Stub" - - def dst(self, d): - return None - - to_arrow, _ = universal_converter.get_conversion_functions(datetime) - stub_aware = datetime(2024, 1, 15, 12, 30, 45, tzinfo=StubTzInfo()) - with pytest.raises(ValueError, match="Naive datetime"): - to_arrow(stub_aware) - - -def test_datetime_converter_accepts_aware(): - to_arrow, _ = universal_converter.get_conversion_functions(datetime) - aware = datetime(2024, 1, 15, 12, 30, 45, 123456, tzinfo=timezone.utc) - result = to_arrow(aware) - assert result == aware - - -def test_datetime_converter_accepts_non_utc_aware(): - """Non-UTC timezone-aware datetimes pass through the converter unchanged. - - PyArrow normalises the value to UTC when writing to a pa.timestamp("us", tz="UTC") - column; the converter itself does not normalise — it only enforces the timezone - policy for naive datetimes. - """ - import zoneinfo - - to_arrow, _ = universal_converter.get_conversion_functions(datetime) - eastern = zoneinfo.ZoneInfo("America/New_York") - non_utc = datetime(2024, 1, 15, 12, 30, 45, tzinfo=eastern) - result = to_arrow(non_utc) - assert result == non_utc # converter passes through unchanged - - -def test_datetime_converter_passes_none_through(): - """None passes through the datetime converter unchanged (PyArrow enforces nullability).""" - to_arrow, _ = universal_converter.get_conversion_functions(datetime) - assert to_arrow(None) is None - - -def test_tz_less_arrow_timestamp_reads_as_naive(): - """Reading a tz-less Arrow timestamp column produces naive (timezone-less) datetimes. - - PyArrow's ``.as_py()`` on a tz-less timestamp returns a naive datetime. The - converter passes it through unchanged — no UTC attachment. To write these values - back via the converter use the ``"coerce_utc"`` timezone policy, or attach timezone - info manually before calling ``python_dicts_to_arrow_table``. - """ - converter = get_default_context().type_converter - naive_ts = datetime(2024, 5, 1, 9, 0, 0) - table = pa.table({"ts": pa.array([naive_ts], type=pa.timestamp("us"))}) - - rows_out = converter.arrow_table_to_python_dicts(table) - result = rows_out[0]["ts"] - - assert result.tzinfo is None - assert result == datetime(2024, 5, 1, 9, 0, 0) - - -def test_datetime_coerce_utc_converts_naive(): - """coerce_utc policy attaches timezone.utc to naive datetimes instead of raising.""" - converter = UniversalTypeConverter(datetime_timezone="coerce_utc") - to_arrow = converter.get_python_to_arrow_converter(datetime) - naive = datetime(2024, 1, 15, 12, 30, 45, 123456) - result = to_arrow(naive) - assert result == datetime(2024, 1, 15, 12, 30, 45, 123456, tzinfo=timezone.utc) - - -def test_datetime_coerce_utc_preserves_aware(): - """coerce_utc policy leaves already-aware datetimes unchanged.""" - converter = UniversalTypeConverter(datetime_timezone="coerce_utc") - to_arrow = converter.get_python_to_arrow_converter(datetime) - aware = datetime(2024, 1, 15, 12, 30, 45, 123456, tzinfo=timezone.utc) - result = to_arrow(aware) - assert result == aware - - -def test_datetime_round_trip(): - converter = get_default_context().type_converter - ts = datetime(2024, 3, 15, 10, 30, 45, 123456, tzinfo=timezone.utc) - rows_in = [{"event": "launch", "ts": ts}] - - # No explicit schema — exercises schema inference from data (type(value) -> datetime) - table = converter.python_dicts_to_arrow_table(rows_in) - - # Arrow schema must use timestamp(us, UTC) and be non-nullable for a plain datetime field - assert table.schema.field("ts").type == pa.timestamp("us", tz="UTC") - assert table.schema.field("ts").nullable is False - - rows_out = converter.arrow_table_to_python_dicts(table) - assert len(rows_out) == 1 - assert rows_out[0]["event"] == "launch" - assert rows_out[0]["ts"] == ts - - -def test_optional_datetime_round_trip(): - converter = get_default_context().type_converter - ts = datetime(2024, 6, 1, 0, 0, 0, tzinfo=timezone.utc) - rows_in = [ - {"label": "a", "ts": ts}, - {"label": "b", "ts": None}, - ] - python_schema = {"label": str, "ts": datetime | None} - - table = converter.python_dicts_to_arrow_table(rows_in, python_schema=python_schema) - - assert table.schema.field("ts").type == pa.timestamp("us", tz="UTC") - assert table.schema.field("ts").nullable is True - - rows_out = converter.arrow_table_to_python_dicts(table) - assert rows_out[0]["ts"] == ts - assert rows_out[1]["ts"] is None - - -def test_python_type_to_arrow_type_numpy(): - assert universal_converter.python_type_to_arrow_type(np.int32) == pa.int32() - assert universal_converter.python_type_to_arrow_type(np.float64) == pa.float64() - assert universal_converter.python_type_to_arrow_type(np.bool_) == pa.bool_() - - -def test_python_type_to_arrow_type_custom(): - arrow_type = universal_converter.python_type_to_arrow_type(Path) - # Should be a StructType with field 'path' of type large_string - assert isinstance(arrow_type, pa.StructType) - assert len(arrow_type) == 1 - field = arrow_type[0] - assert field.name == "path" - assert field.type == pa.large_string() - - -def test_python_type_to_arrow_type_upath(): - from upath import UPath - - arrow_type = universal_converter.python_type_to_arrow_type(UPath) - # Should be a StructType with field 'upath' of type large_string - assert isinstance(arrow_type, pa.StructType) - assert len(arrow_type) == 1 - field = arrow_type[0] - assert field.name == "upath" - assert field.type == pa.large_string() - - -def test_optional_upath_converter(): - """Test that Optional[UPath] correctly converts UPath values.""" - from upath import UPath - - to_arrow, to_python = universal_converter.get_conversion_functions(UPath | None) - - # Test with UPath value - path = UPath("/tmp/test.txt") - result = to_arrow(path) - assert result == {"upath": "/tmp/test.txt"} - - # Test with None - assert to_arrow(None) is None - - -def test_complex_union_raises_error(): - """Test that complex unions (multiple non-None types) raise ValueError.""" - from upath import UPath - - with pytest.raises(ValueError, match="Complex unions"): - universal_converter.get_conversion_functions(UPath | Path) - - with pytest.raises(ValueError, match="Complex unions"): - universal_converter.python_type_to_arrow_type(UPath | Path) - - -def test_python_type_to_arrow_type_context(): - ctx = get_default_context() - assert universal_converter.python_type_to_arrow_type(int, ctx) == pa.int64() - - -def test_python_type_to_arrow_type_unsupported(): - class CustomType: - pass - - with pytest.raises(Exception): - universal_converter.python_type_to_arrow_type(CustomType) - - -def test_arrow_type_to_python_type_basic(): - assert universal_converter.arrow_type_to_python_type(pa.int64()) is int - assert universal_converter.arrow_type_to_python_type(pa.float64()) is float - assert universal_converter.arrow_type_to_python_type(pa.large_string()) is str - assert universal_converter.arrow_type_to_python_type(pa.bool_()) is bool - assert universal_converter.arrow_type_to_python_type(pa.large_binary()) is bytes - - -def test_arrow_type_to_python_type_context(): - ctx = get_default_context() - assert universal_converter.arrow_type_to_python_type(pa.int64(), ctx) is int - - -def test_arrow_type_to_python_type_unsupported(): - class FakeArrowType: - pass - - with pytest.raises(Exception): - universal_converter.arrow_type_to_python_type( - cast(pa.DataType, FakeArrowType()) - ) - - -def test_get_conversion_functions_basic(): - to_arrow, to_python = universal_converter.get_conversion_functions(int) - assert callable(to_arrow) - assert callable(to_python) - assert to_arrow(42) == 42 - assert to_python(42) == 42 - - -def test_get_conversion_functions_custom(): - to_arrow, to_python = universal_converter.get_conversion_functions(str) - assert to_arrow("abc") == "abc" - assert to_python("abc") == "abc" - - -def test_get_conversion_functions_context(): - ctx = get_default_context() - to_arrow, to_python = universal_converter.get_conversion_functions(float, ctx) - assert to_arrow(1.5) == 1.5 - assert to_python(1.5) == 1.5 - - -def test_python_type_to_arrow_type_list(): - # Unparameterized list should raise ValueError - with pytest.raises(ValueError): - universal_converter.python_type_to_arrow_type(list) - - -def test_python_type_to_arrow_type_dict(): - # Unparameterized dict should raise ValueError - with pytest.raises(ValueError): - universal_converter.python_type_to_arrow_type(dict) - - -def test_python_type_to_arrow_type_list_of_dict(): - # For list[dict[str, int]], expect LargeListType of LargeListType of StructType - arrow_type = universal_converter.python_type_to_arrow_type(list[dict[str, int]]) - # Should be LargeListType - assert arrow_type.__class__.__name__.endswith("ListType") - # Next level should also be LargeListType - arrow_type = cast(pa.ListType, arrow_type) - inner_list = arrow_type.value_type - assert inner_list.__class__.__name__.endswith("ListType") - # Innermost should be StructType - struct_type = inner_list.value_type - assert isinstance(struct_type, pa.StructType) - assert struct_type[0].name == "key" - assert struct_type[0].type == pa.large_string() - assert struct_type[1].name == "value" - assert struct_type[1].type == pa.int64() - - -def test_python_type_to_arrow_type_dict_of_list(): - # dict[str, list[int]] should be a LargeListType of StructType, with value field as LargeListType - arrow_type = universal_converter.python_type_to_arrow_type(dict[str, list[int]]) - assert arrow_type.__class__.__name__.endswith("ListType") - arrow_type = cast(pa.ListType, arrow_type) - struct_type = arrow_type.value_type - assert isinstance(struct_type, pa.StructType) - assert struct_type[0].name == "key" - assert struct_type[0].type == pa.large_string() - assert struct_type[1].name == "value" - value_type = struct_type[1].type - assert value_type.__class__.__name__.endswith("ListType") - assert value_type.value_type == pa.int64() - - -def test_python_type_to_arrow_type_list_of_list(): - arrow_type = universal_converter.python_type_to_arrow_type(list[list[int]]) - assert arrow_type.__class__.__name__.endswith("ListType") - arrow_type = cast(pa.ListType, arrow_type) - inner_list = arrow_type.value_type - assert inner_list.__class__.__name__.endswith("ListType") - assert inner_list.value_type == pa.int64() - - -def test_python_type_to_arrow_type_deeply_nested(): - # dict[str, list[list[dict[str, float]]]] - complex_type = dict[str, list[list[dict[str, float]]]] - arrow_type = universal_converter.python_type_to_arrow_type(complex_type) - # Should be a LargeListType of StructType - assert arrow_type.__class__.__name__.endswith("ListType") - arrow_type = cast(pa.ListType, arrow_type) - struct_type = arrow_type.value_type - assert isinstance(struct_type, pa.StructType) - assert struct_type[0].name == "key" - assert struct_type[0].type == pa.large_string() - assert struct_type[1].name == "value" - outer_list = struct_type[1].type - assert outer_list.__class__.__name__.endswith("ListType") - inner_list = outer_list.value_type - assert inner_list.__class__.__name__.endswith("ListType") - inner_struct_list = inner_list.value_type - assert inner_struct_list.__class__.__name__.endswith("ListType") - inner_struct = inner_struct_list.value_type - assert isinstance(inner_struct, pa.StructType) - assert inner_struct[0].name == "key" - assert inner_struct[0].type == pa.large_string() - assert inner_struct[1].name == "value" - assert inner_struct[1].type == pa.float64() - - -# Roundtrip tests for complex types -def test_roundtrip_list_of_int(): - py_val = [1, 2, 3, 4] - to_arrow, to_python = universal_converter.get_conversion_functions(list[int]) - arr = to_arrow(py_val) - py_val2 = to_python(arr) - assert py_val == py_val2 - - -def test_roundtrip_dict_str_int(): - py_val = {"a": 1, "b": 2} - to_arrow, to_python = universal_converter.get_conversion_functions(dict[str, int]) - arr = to_arrow(py_val) - py_val2 = to_python(arr) - # dict roundtrip may come back as dict or list of pairs - if isinstance(py_val2, dict): - assert py_val == py_val2 - else: - # Accept list of pairs - assert sorted(py_val.items()) == sorted( - [(d["key"], d["value"]) for d in py_val2] - ) - - -def test_roundtrip_list_of_list_of_float(): - py_val = [[1.1, 2.2], [3.3, 4.4]] - to_arrow, to_python = universal_converter.get_conversion_functions( - list[list[float]] - ) - arr = to_arrow(py_val) - py_val2 = to_python(arr) - assert py_val == py_val2 - - -def test_roundtrip_set_of_int(): - py_val = {1, 2, 3} - to_arrow, to_python = universal_converter.get_conversion_functions(set[int]) - arr = to_arrow(py_val) - py_val2 = to_python(arr) - # set will come back as list - assert py_val != py_val2 - assert set(py_val) == set(py_val2) - - -def test_roundtrip_various_complex_types(): - cases = [ - ([1, 2, 3], list[int]), - ([["a", "b"], ["c"]], list[list[str]]), - ({"a": 1, "b": 2}, dict[str, int]), - ([{"x": 1.1, "y": 2.2}, {"x": 3.3, "y": 4.4}], list[dict[str, float]]), - ({"a": [1, 2], "b": [3]}, dict[str, list[int]]), - ( - [{"a": [1, 2]}, {"b": [3], "c": [4, 5, 6]}], - list[dict[str, list[int]]], - ), - ( - [[{"k": "a", "v": 1.1}, {"k": "b", "v": 2.2}], [{"k": "c", "v": 3.3}]], - list[list[dict[str, float]]], - ), - ( - {"outer": [{"inner": [1, 2]}, {"inner": [3, 4]}]}, - dict[str, list[dict[str, list[int]]]], - ), - ({"a": {"b": {"c": 42}}}, dict[str, dict[str, dict[str, int]]]), - ({"a": None, "b": 2}, dict[str, int]), - ( - [{"x": [1, 2], "y": [3, 4]}, {"x": [5], "y": [6, 7]}], - list[dict[str, list[int]]], - ), - ] - for py_val, typ in cases: - to_arrow, to_python = universal_converter.get_conversion_functions(typ) - arr = to_arrow(py_val) - py_val2 = to_python(arr) - assert py_val == py_val2, f"Failed roundtrip for type {typ} with value {py_val}" - - -def test_incomplete_roundtrip_types(): - cases = [({"a": {1, 2}, "b": {3}}, dict[str, set[int]], {"a": [1, 2], "b": [3]})] - - for py_val, typ, expected_return in cases: - to_arrow, to_python = universal_converter.get_conversion_functions(typ) - arr = to_arrow(py_val) - py_val2 = to_python(arr) - assert py_val2 == expected_return, ( - f"Failed roundtrip for type {typ} with value {py_val}" - ) - - -def test_roundtrip_minimal_key_list_issue(): - py_val = [{"test": [1, 2, 3], "next": [3, 4]}] - typ = list[dict[str, list[int]]] - to_arrow, to_python = universal_converter.get_conversion_functions(typ) - arr = to_arrow(py_val) - py_val2 = to_python(arr) - print("Original:", py_val) - print("Roundtrip:", py_val2) - assert py_val == py_val2 - - -def test_roundtrip_simpler_key_issue_dict_str_list(): - py_val = {"a": [1, 2]} - typ = dict[str, list[int]] - to_arrow, to_python = universal_converter.get_conversion_functions(typ) - arr = to_arrow(py_val) - py_val2 = to_python(arr) - print("Original dict[str, list[int]]:", py_val) - print("Roundtrip:", py_val2) - assert py_val == py_val2 - - -def test_roundtrip_simpler_key_issue_list_dict_str_int(): - py_val = [{"key": "a", "value": 1}] - typ = list[dict[str, int]] - to_arrow, to_python = universal_converter.get_conversion_functions(typ) - arr = to_arrow(py_val) - py_val2 = to_python(arr) - print("Original list[dict[str, int]]:", py_val) - print("Roundtrip:", py_val2) - assert py_val == py_val2 - - -def test_inspect_arrow_schema_dict_str_list(): - py_val = {"test": [1, 2]} - typ = dict[str, list[int]] - arrow_type = universal_converter.python_type_to_arrow_type(typ) - print("Arrow type for dict[str, list[int]]:", arrow_type) - to_arrow_struct, to_python = universal_converter.get_conversion_functions(typ) - arr = to_arrow_struct(py_val) - assert arr == [{"key": "test", "value": [1, 2]}] - - -def test_schema_as_required_strips_optional_fields(): - from orcapod.types import Schema - - s = Schema({"a": int, "b": str}, optional_fields=["b"]) - result = s.as_required() - assert result == Schema({"a": int, "b": str}) - assert result.optional_fields == frozenset() - - -def test_schema_as_required_idempotent(): - from orcapod.types import Schema - - s = Schema({"a": int, "b": str}, optional_fields=["a", "b"]) - once = s.as_required() - twice = s.as_required().as_required() - assert once == twice - - -def test_python_schema_to_arrow_non_nullable(): - """Plain types (no | None) must produce nullable=False Arrow fields.""" - from orcapod.types import Schema - - ctx = get_default_context() - schema = ctx.type_converter.python_schema_to_arrow_schema( - Schema({"a": int, "b": str, "c": float, "d": bool, "e": bytes}) - ) - for name in ("a", "b", "c", "d", "e"): - assert schema.field(name).nullable is False, ( - f"Field '{name}' should be nullable=False for a plain type" - ) - - -def test_python_schema_to_arrow_optional_nullable(): - """Optional types (T | None) must produce nullable=True Arrow fields.""" - from orcapod.types import Schema - - ctx = get_default_context() - schema = ctx.type_converter.python_schema_to_arrow_schema( - Schema({"x": int | None, "y": str | None}) - ) - assert schema.field("x").nullable is True - assert schema.field("y").nullable is True - - -def test_arrow_schema_to_python_nullable_becomes_optional(): - """nullable=True Arrow fields must reconstruct as T | None.""" - ctx = get_default_context() - arrow_schema = pa.schema([pa.field("x", pa.int64(), nullable=True)]) - python_schema = ctx.type_converter.arrow_schema_to_python_schema(arrow_schema) - assert python_schema["x"] == int | None - - -def test_arrow_schema_to_python_non_nullable_stays_plain(): - """nullable=False Arrow fields must reconstruct as plain T.""" - ctx = get_default_context() - arrow_schema = pa.schema([pa.field("x", pa.int64(), nullable=False)]) - python_schema = ctx.type_converter.arrow_schema_to_python_schema(arrow_schema) - assert python_schema["x"] == int - - -def test_round_trip_preserves_optionality(): - """Python schema → Arrow → Python schema is lossless for nullable/non-nullable.""" - from orcapod.types import Schema - - ctx = get_default_context() - original = Schema({"required": int, "nullable_field": int | None}) - arrow = ctx.type_converter.python_schema_to_arrow_schema(original) - recovered = ctx.type_converter.arrow_schema_to_python_schema(arrow) - - assert recovered["required"] == int - assert recovered["nullable_field"] == int | None - assert recovered == original - - -# --------------------------------------------------------------------------- -# ENG-389: Any <-> pa.null() round-trip -# --------------------------------------------------------------------------- - - -def test_any_to_arrow_type(): - """typing.Any maps to pa.null().""" - assert universal_converter.python_type_to_arrow_type(Any) == pa.null() - - -def test_list_any_to_arrow_type(): - """list[Any] maps to pa.large_list(pa.null()).""" - assert ( - universal_converter.python_type_to_arrow_type(list[Any]) - == pa.large_list(pa.null()) - ) - - -def test_dict_any_any_to_arrow_type(): - """dict[Any, Any] maps to pa.large_list(pa.struct([("key", pa.null()), ("value", pa.null())])).""" - expected = pa.large_list( - pa.struct([("key", pa.null()), ("value", pa.null())]) - ) - assert universal_converter.python_type_to_arrow_type(dict[Any, Any]) == expected - - -def test_null_arrow_to_any_python_type(): - """pa.null() maps back to typing.Any.""" - assert universal_converter.arrow_type_to_python_type(pa.null()) is Any - - -def test_list_any_round_trip(): - """list[Any] round-trips: list[Any] -> pa.large_list(pa.null()) -> list[Any].""" - arrow_type = universal_converter.python_type_to_arrow_type(list[Any]) - assert universal_converter.arrow_type_to_python_type(arrow_type) == list[Any] - - -def test_dict_any_any_round_trip(): - """dict[Any, Any] round-trips through Arrow and back to dict[Any, Any].""" - arrow_type = universal_converter.python_type_to_arrow_type(dict[Any, Any]) - assert universal_converter.arrow_type_to_python_type(arrow_type) == dict[Any, Any] - - -def test_empty_container_inference_to_arrow_no_error(): - """Inferring schema from empty containers and converting to Arrow does not raise.""" - from orcapod.semantic_types.pydata_utils import infer_python_schema_from_pylist_data - from orcapod.semantic_types.universal_converter import UniversalTypeConverter - - schema = infer_python_schema_from_pylist_data([{"items": [], "meta": {}}]) - converter = UniversalTypeConverter() - # Must not raise ValueError: Unsupported Python type: typing.Any - arrow_schema = converter.python_schema_to_arrow_schema(schema) - assert "items" in [f.name for f in arrow_schema] - assert "meta" in [f.name for f in arrow_schema] - - -def test_pyarrow_empty_list_with_null_type(): - """PyArrow accepts empty lists for pa.large_list(pa.null()) and pa.large_list(pa.struct(...)) columns.""" - schema = pa.schema([ - pa.field("items", pa.large_list(pa.null())), - pa.field("meta", pa.large_list(pa.struct([("key", pa.null()), ("value", pa.null())]))), - ]) - table = pa.Table.from_pylist([{"items": [], "meta": []}], schema=schema) - assert table.num_rows == 1 - assert table.schema.field("items").type == pa.large_list(pa.null()) diff --git a/tests/test_semantic_types/test_upath_struct_converter.py b/tests/test_semantic_types/test_upath_struct_converter.py deleted file mode 100644 index ccfe014f..00000000 --- a/tests/test_semantic_types/test_upath_struct_converter.py +++ /dev/null @@ -1,148 +0,0 @@ -from pathlib import Path -from typing import cast - -import pytest -from upath import UPath - -from orcapod.hashing.file_hashers import BasicFileHasher -from orcapod.semantic_types.semantic_struct_converters import UPathStructConverter - - -@pytest.fixture -def file_hasher(): - return BasicFileHasher(algorithm="sha256") - - -@pytest.fixture -def converter(file_hasher): - return UPathStructConverter(file_hasher=file_hasher) - - -def test_upath_to_struct_and_back(converter): - path_obj = UPath("/tmp/test.txt") - struct_dict = converter.python_to_struct_dict(path_obj) - assert struct_dict["upath"] == str(path_obj) - restored = converter.struct_dict_to_python(struct_dict) - assert isinstance(restored, UPath) - assert str(restored) == str(path_obj) - - -def test_upath_to_struct_invalid_type(converter): - with pytest.raises(TypeError): - converter.python_to_struct_dict(Path("/tmp/test.txt")) # type: ignore - - -def test_struct_to_python_missing_field(converter): - with pytest.raises(ValueError): - converter.struct_dict_to_python({}) - - -def test_can_handle_python_type(converter): - assert converter.can_handle_python_type(UPath) - assert not converter.can_handle_python_type(str) - assert not converter.can_handle_python_type(Path) - - -def test_can_handle_struct_type(converter): - struct_type = converter.arrow_struct_type - assert converter.can_handle_struct_type(struct_type) - - -def test_is_semantic_struct(converter): - assert converter.is_semantic_struct({"upath": "/tmp/test.txt"}) - assert not converter.is_semantic_struct({"path": "/tmp/test.txt"}) - assert not converter.is_semantic_struct({"upath": 123}) - - -def test_hash_struct_dict_file_not_found(converter, tmp_path): - struct_dict = {"upath": str(tmp_path / "does_not_exist.txt")} - with pytest.raises(FileNotFoundError): - converter.hash_struct_dict(struct_dict) - - -def test_hash_struct_dict_is_directory(converter, tmp_path): - struct_dict = {"upath": str(tmp_path)} - with pytest.raises(IsADirectoryError): - converter.hash_struct_dict(struct_dict) - - -def test_hash_struct_dict_content_based(converter, tmp_path): - """Two distinct files with identical content produce the same hash.""" - file1 = tmp_path / "file1.txt" - file2 = tmp_path / "file2.txt" - content = "identical content" - file1.write_text(content) - file2.write_text(content) - hash1 = converter.hash_struct_dict({"upath": str(file1)}) - hash2 = converter.hash_struct_dict({"upath": str(file2)}) - assert hash1 == hash2 - - -def test_hash_struct_dict_with_prefix(converter, tmp_path): - """Hash always starts with 'upath:sha256:'.""" - file = tmp_path / "file.txt" - file.write_text("hello") - hash_str = converter.hash_struct_dict({"upath": str(file)}) - assert hash_str.startswith("upath:sha256:") - - -def test_hash_struct_dict_different_content(converter, tmp_path): - """Same path with modified content yields a different hash.""" - file = tmp_path / "mutable.txt" - file.write_text("version 1") - hash1 = converter.hash_struct_dict({"upath": str(file)}) - file.write_text("version 2") - hash2 = converter.hash_struct_dict({"upath": str(file)}) - assert hash1 != hash2 - - -def test_hash_struct_dict_missing_field(converter): - with pytest.raises(ValueError, match="Missing 'upath' field"): - converter.hash_struct_dict({}) - - -def test_upath_arrow_struct_type(converter): - """The Arrow struct type has a single 'upath' field of large_string.""" - import pyarrow as pa - - struct_type = converter.arrow_struct_type - assert isinstance(struct_type, pa.StructType) - assert len(struct_type) == 1 - assert struct_type[0].name == "upath" - assert struct_type[0].type == pa.large_string() - - -def test_path_and_upath_struct_types_differ(): - """Path and UPath converters produce distinct Arrow struct types.""" - from orcapod.semantic_types.semantic_struct_converters import PythonPathStructConverter - - file_hasher = BasicFileHasher(algorithm="sha256") - path_conv = PythonPathStructConverter(file_hasher=file_hasher) - upath_conv = UPathStructConverter(file_hasher=file_hasher) - - assert path_conv.arrow_struct_type != upath_conv.arrow_struct_type - assert path_conv.arrow_struct_type[0].name == "path" - assert upath_conv.arrow_struct_type[0].name == "upath" - - -def test_path_converter_rejects_upath(): - """PythonPathStructConverter rejects UPath instances to avoid ambiguity.""" - from orcapod.semantic_types.semantic_struct_converters import PythonPathStructConverter - - file_hasher = BasicFileHasher(algorithm="sha256") - path_conv = PythonPathStructConverter(file_hasher=file_hasher) - - upath_val = UPath("/tmp/test.txt") - with pytest.raises(TypeError, match="not UPath"): - path_conv.python_to_struct_dict(upath_val) - - -def test_path_converter_cannot_handle_upath_type(): - """PythonPathStructConverter.can_handle_python_type returns False for UPath.""" - from orcapod.semantic_types.semantic_struct_converters import PythonPathStructConverter - - file_hasher = BasicFileHasher(algorithm="sha256") - path_conv = PythonPathStructConverter(file_hasher=file_hasher) - - assert not path_conv.can_handle_python_type(UPath) - assert path_conv.can_handle_python_type(Path) diff --git a/tests/test_semantic_types/test_uuid_struct_converter.py b/tests/test_semantic_types/test_uuid_struct_converter.py deleted file mode 100644 index c8084991..00000000 --- a/tests/test_semantic_types/test_uuid_struct_converter.py +++ /dev/null @@ -1,134 +0,0 @@ -"""Tests for UUIDStructConverter.""" -import uuid - -import pyarrow as pa -import pytest - -from orcapod.semantic_types.semantic_struct_converters import UUIDStructConverter - - -@pytest.fixture -def converter(): - return UUIDStructConverter() - - -@pytest.fixture -def sample_uuid(): - return uuid.UUID("550e8400-e29b-41d4-a716-446655440000") - - -def test_python_type(converter): - assert converter.python_type is uuid.UUID - - -def test_arrow_struct_type(converter): - assert converter.arrow_struct_type == pa.struct([pa.field("uuid", pa.binary(16))]) - - -def test_semantic_type_name(converter): - assert converter.semantic_type_name == "uuid" - - -def test_python_to_struct_dict(converter, sample_uuid): - result = converter.python_to_struct_dict(sample_uuid) - assert result == {"uuid": sample_uuid.bytes} - assert isinstance(result["uuid"], bytes) - assert len(result["uuid"]) == 16 - - -def test_python_to_struct_dict_rejects_non_uuid(converter): - with pytest.raises(TypeError): - converter.python_to_struct_dict("550e8400-e29b-41d4-a716-446655440000") # type: ignore - - -def test_struct_dict_to_python(converter, sample_uuid): - struct_dict = {"uuid": sample_uuid.bytes} - result = converter.struct_dict_to_python(struct_dict) - assert result == sample_uuid - assert isinstance(result, uuid.UUID) - - -def test_struct_dict_to_python_from_bytearray(converter, sample_uuid): - """Arrow may return binary fields as bytearray — must handle both.""" - struct_dict = {"uuid": bytearray(sample_uuid.bytes)} - result = converter.struct_dict_to_python(struct_dict) - assert result == sample_uuid - - -def test_struct_dict_to_python_missing_field(converter): - with pytest.raises(ValueError, match="Missing 'uuid' field"): - converter.struct_dict_to_python({}) - - -def test_round_trip(converter, sample_uuid): - struct_dict = converter.python_to_struct_dict(sample_uuid) - recovered = converter.struct_dict_to_python(struct_dict) - assert recovered == sample_uuid - - -def test_round_trip_all_versions(): - """Verify round-trip works for uuid4, uuid5, and uuid7 (uuid_utils). - - ``uuid_utils.UUID`` objects do not inherit from ``uuid.UUID`` and their - ``__eq__`` does not cross-compare with ``uuid.UUID``, so we compare by - the canonical string representation instead of direct equality. - """ - from uuid_utils import uuid7 - - converter = UUIDStructConverter() - for u in [uuid.uuid4(), uuid.uuid5(uuid.NAMESPACE_OID, "test"), uuid7()]: - recovered = converter.struct_dict_to_python(converter.python_to_struct_dict(u)) - assert str(recovered) == str(u) - - -def test_arrow_array_round_trip(converter, sample_uuid): - """Verify UUID survives a PyArrow array round-trip.""" - struct_dict = converter.python_to_struct_dict(sample_uuid) - arr = pa.array([struct_dict], type=pa.struct([pa.field("uuid", pa.binary(16))])) - recovered_dict = arr[0].as_py() - recovered_uuid = converter.struct_dict_to_python(recovered_dict) - assert recovered_uuid == sample_uuid - - -def test_distinct_uuids_produce_distinct_struct_dicts(converter): - u1, u2 = uuid.uuid4(), uuid.uuid4() - assert converter.python_to_struct_dict(u1) != converter.python_to_struct_dict(u2) - - -def test_can_handle_python_type_uuid(converter): - assert converter.can_handle_python_type(uuid.UUID) is True - - -def test_can_handle_python_type_rejects_str(converter): - assert converter.can_handle_python_type(str) is False - - -def test_can_handle_struct_type_uuid(converter): - assert converter.can_handle_struct_type(pa.struct([pa.field("uuid", pa.binary(16))])) is True - - -def test_can_handle_struct_type_rejects_other(converter): - import pyarrow as pa - - assert converter.can_handle_struct_type(pa.struct([pa.field("path", pa.large_string())])) is False - - -def test_hash_struct_dict_returns_string(converter, sample_uuid): - struct_dict = converter.python_to_struct_dict(sample_uuid) - result = converter.hash_struct_dict(struct_dict) - assert isinstance(result, str) - assert len(result) > 0 - - -def test_hash_struct_dict_consistent(converter, sample_uuid): - """Same UUID always produces the same hash.""" - struct_dict = converter.python_to_struct_dict(sample_uuid) - assert converter.hash_struct_dict(struct_dict) == converter.hash_struct_dict(struct_dict) - - -def test_hash_struct_dict_different_uuids(converter): - """Different UUIDs produce different hashes.""" - u1, u2 = uuid.uuid4(), uuid.uuid4() - d1 = converter.python_to_struct_dict(u1) - d2 = converter.python_to_struct_dict(u2) - assert converter.hash_struct_dict(d1) != converter.hash_struct_dict(d2) diff --git a/uv.lock b/uv.lock index 3d41afa3..9dcd253e 100644 --- a/uv.lock +++ b/uv.lock @@ -2301,6 +2301,7 @@ dependencies = [ { name = "pandas" }, { name = "polars" }, { name = "pyarrow" }, + { name = "pydantic" }, { name = "pygraphviz" }, { name = "pymongo" }, { name = "pyyaml" }, @@ -2376,10 +2377,11 @@ requires-dist = [ { name = "matplotlib", specifier = ">=3.10.3" }, { name = "networkx" }, { name = "pandas", specifier = ">=2.2.3" }, - { name = "polars", specifier = ">=1.31.0" }, + { name = "polars", specifier = ">=1.36.0" }, { name = "psycopg", extras = ["binary"], marker = "extra == 'all'", specifier = ">=3.0" }, { name = "psycopg", extras = ["binary"], marker = "extra == 'postgresql'", specifier = ">=3.0" }, { name = "pyarrow", specifier = ">=20.0.0" }, + { name = "pydantic", specifier = ">=2.0" }, { name = "pygraphviz", specifier = ">=1.14" }, { name = "pymongo", specifier = ">=4.15.5" }, { name = "pyspiral", marker = "extra == 'all'", specifier = ">=0.14.0" }, @@ -2694,16 +2696,30 @@ wheels = [ [[package]] name = "polars" -version = "1.31.0" +version = "1.41.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "polars-runtime-32" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ff/f9/aeda46259b0669247a160315d2d51269de9504b9dd2f70acadbcb22f46b7/polars-1.41.2.tar.gz", hash = "sha256:256d6731162371b77f3f29a55eacb8c0fc740ddb1a293a01d2ef5b5393c5c708", size = 737996, upload-time = "2026-05-29T17:39:15.604Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/22/28f62d24f7db56ac4343588f9362d49b7b4177e55ac47a466fe696b0099b/polars-1.41.2-py3-none-any.whl", hash = "sha256:23ce9a2910b6e3e8d4258770bf44aa17170958df7af6e85feedf4458a04d8d29", size = 833445, upload-time = "2026-05-29T17:37:05.576Z" }, +] + +[[package]] +name = "polars-runtime-32" +version = "1.41.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fd/f5/de1b5ecd7d0bd0dd87aa392937f759f9cc3997c5866a9a7f94eabf37cd48/polars-1.31.0.tar.gz", hash = "sha256:59a88054a5fc0135386268ceefdbb6a6cc012d21b5b44fed4f1d3faabbdcbf32", size = 4681224, upload-time = "2025-06-18T12:00:46.24Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/56/54e3ea0e9b64f327179049e4742241cc6b1d3e8fa414b05a057dd26df367/polars_runtime_32-1.41.2.tar.gz", hash = "sha256:7af09ec1ab053da2c9669e8d15f809a4083a29be05db57111688b8051062af56", size = 2989474, upload-time = "2026-05-29T17:39:17.257Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3d/6e/bdd0937653c1e7a564a09ae3bc7757ce83fedbf19da600c8b35d62c0182a/polars-1.31.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:ccc68cd6877deecd46b13cbd2663ca89ab2a2cb1fe49d5cfc66a9cef166566d9", size = 34511354, upload-time = "2025-06-18T11:59:40.048Z" }, - { url = "https://files.pythonhosted.org/packages/77/fe/81aaca3540c1a5530b4bc4fd7f1b6f77100243d7bb9b7ad3478b770d8b3e/polars-1.31.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:a94c5550df397ad3c2d6adc212e59fd93d9b044ec974dd3653e121e6487a7d21", size = 31377712, upload-time = "2025-06-18T11:59:45.104Z" }, - { url = "https://files.pythonhosted.org/packages/b8/d9/5e2753784ea30d84b3e769a56f5e50ac5a89c129e87baa16ac0773eb4ef7/polars-1.31.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada7940ed92bea65d5500ae7ac1f599798149df8faa5a6db150327c9ddbee4f1", size = 35050729, upload-time = "2025-06-18T11:59:48.538Z" }, - { url = "https://files.pythonhosted.org/packages/20/e8/a6bdfe7b687c1fe84bceb1f854c43415eaf0d2fdf3c679a9dc9c4776e462/polars-1.31.0-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:b324e6e3e8c6cc6593f9d72fe625f06af65e8d9d47c8686583585533a5e731e1", size = 32260836, upload-time = "2025-06-18T11:59:52.543Z" }, - { url = "https://files.pythonhosted.org/packages/6e/f6/9d9ad9dc4480d66502497e90ce29efc063373e1598f4bd9b6a38af3e08e7/polars-1.31.0-cp39-abi3-win_amd64.whl", hash = "sha256:3fd874d3432fc932863e8cceff2cff8a12a51976b053f2eb6326a0672134a632", size = 35156211, upload-time = "2025-06-18T11:59:55.805Z" }, - { url = "https://files.pythonhosted.org/packages/40/4b/0673a68ac4d6527fac951970e929c3b4440c654f994f0c957bd5556deb38/polars-1.31.0-cp39-abi3-win_arm64.whl", hash = "sha256:62ef23bb9d10dca4c2b945979f9a50812ac4ace4ed9e158a6b5d32a7322e6f75", size = 31469078, upload-time = "2025-06-18T11:59:59.242Z" }, + { url = "https://files.pythonhosted.org/packages/d6/9b/fe72a3811c0357cdb06c67bdc7695fa1623ad47948fc523195f5ac31037f/polars_runtime_32-1.41.2-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:95a08346dac337357cdb825c8076df7d36da54c4caa59a5cb41d0a30691c5edd", size = 52265283, upload-time = "2026-05-29T17:37:09.407Z" }, + { url = "https://files.pythonhosted.org/packages/0a/93/fab9da803fd80d9e83ef88c20932f637a10bc611b20415fc322eec84bc44/polars_runtime_32-1.41.2-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:dedfaeec2c7f995298da7319dd9431d662e5dd1d0ec51b1459df4a0234ceff52", size = 46571222, upload-time = "2026-05-29T17:37:13.698Z" }, + { url = "https://files.pythonhosted.org/packages/c8/2a/8843f34a8ac57acd058a39b87b03b580dd352a490e9dae0415e02033bdd4/polars_runtime_32-1.41.2-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18eea22c5cc34e27f8a60950458ad81e6a9ea75e89363ca1367e14e7e7f781fc", size = 50409372, upload-time = "2026-05-29T17:37:17.875Z" }, + { url = "https://files.pythonhosted.org/packages/6c/c6/92b352fe88cf51bd0a19fb99e1c0cbe46aa26c14dcf7995b89869cd932ae/polars_runtime_32-1.41.2-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2630540dfdfb0f36f9b04a07c7c2e3f50bf2ad384113263c1c812007ee9141e0", size = 56405484, upload-time = "2026-05-29T17:37:22.684Z" }, + { url = "https://files.pythonhosted.org/packages/74/c4/bae3174c3b02f6b441d2e58594387abcd509f67a098f682a83b195f08966/polars_runtime_32-1.41.2-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:20e969e08f9b137e233c04cc04de73d9795f89eb77d34854e40a025965a43763", size = 50603512, upload-time = "2026-05-29T17:37:27.422Z" }, + { url = "https://files.pythonhosted.org/packages/f4/ed/f2d26ae02d92c2689056838ed59e2a626326ad23c2831d58637d25f6c82a/polars_runtime_32-1.41.2-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e7016a3deb641b64a31447abbbee0f34bd020a6a9ae34ee6b743837def15e2a4", size = 54328561, upload-time = "2026-05-29T17:37:32.587Z" }, + { url = "https://files.pythonhosted.org/packages/9b/c4/9c3831cc885dc7769e59abf8f583821a5fb4403fd0e4eba0ccc6d47a3d4b/polars_runtime_32-1.41.2-cp310-abi3-win_amd64.whl", hash = "sha256:1e5e5377c315e0dcafdfb2a31adc546abbaeb3f9cb1864e6536523d2af473265", size = 51978643, upload-time = "2026-05-29T17:37:37.443Z" }, + { url = "https://files.pythonhosted.org/packages/cd/c6/79e9f3f270270d7ed5575d92b7bfef49f01abd9275447161275b23b553a8/polars_runtime_32-1.41.2-cp310-abi3-win_arm64.whl", hash = "sha256:843d96f69d18eca53429c1198e58891db7f18111f83b9c419bb45ad9d73eaed5", size = 46006901, upload-time = "2026-05-29T17:37:42.522Z" }, ] [[package]]