Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies = [
"polars>=1.31.0",
"beartype>=0.21.0",
"deltalake>=1.0.2",
"pydantic>=2",
"graphviz>=0.21",
"gitpython>=3.1.45",
"universal-pathlib>=0.3.8",
Expand Down
4 changes: 4 additions & 0 deletions src/orcapod/contexts/data/v0.1.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
"_config": {
"file_hasher": {"_ref": "file_hasher"}
}
},
"pydantic": {
"_class": "orcapod.pydantic_config.PydanticModelConverter",
"_config": {}
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/orcapod/hashing/versioned_hashers.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,16 @@ def get_versioned_semantic_arrow_hasher(
registry: Any = SemanticTypeRegistry()
file_hasher = BasicFileHasher(algorithm="sha256")
path_converter: Any = PythonPathStructConverter(file_hasher=file_hasher)
# NOTE: keep this converter list in sync with the production registry in
# src/orcapod/contexts/data/v0.1.json (semantic_registry._config.converters).
registry.register_converter("path", path_converter)
uuid_converter: Any = UUIDStructConverter()
registry.register_converter("uuid", uuid_converter)

from orcapod.pydantic_config import PydanticModelConverter

registry.register_converter("pydantic", PydanticModelConverter())

logger.debug(
"get_versioned_semantic_arrow_hasher: creating StarfixArrowHasher "
"(hasher_id=%r)",
Expand Down
193 changes: 193 additions & 0 deletions src/orcapod/pydantic_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
"""Pydantic-backed config loading for orcapod pipelines (ENG-607).

Provides `load_pydantic_config` (validate a YAML file against a pydantic model)
and `OrcapodBaseConfig` (a strict base for config schemas). A companion
`PydanticModelConverter` (also in this module) makes a validated model a
first-class, content-hashed orcapod value.
"""

from __future__ import annotations

import importlib
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar

import pydantic
import yaml
from upath import UPath

from orcapod.semantic_types.semantic_struct_converters import SemanticStructConverterBase
from orcapod.utils.lazy_module import LazyModule

if TYPE_CHECKING:
import pyarrow as pa
else:
pa = LazyModule("pyarrow")

M = TypeVar("M", bound=pydantic.BaseModel)


class OrcapodBaseConfig(pydantic.BaseModel):
"""Recommended base for pipeline config schemas.

Defaults to strict validation: unknown keys are rejected and instances are
immutable. Subclass this for pipeline configs; subclass `pydantic.BaseModel`
directly only when different semantics are required.
"""

model_config = pydantic.ConfigDict(extra="forbid", frozen=True)


def load_pydantic_config(path: str | Path | UPath, model_cls: type[M]) -> M:
"""Read a YAML file and validate it against a pydantic model.

The path is resolved through ``UPath``, so local paths and remote object
storage (e.g. ``s3://``, ``gs://``) are both supported.

Args:
path: Path to the YAML config file. A local path or any ``UPath``-supported
URI (e.g. an object-storage location).
model_cls: The pydantic model class to validate against.

Returns:
A validated instance of `model_cls`.

Raises:
ValueError: If the file cannot be read, the YAML cannot be parsed, or
validation fails. The error message includes the file path and the
underlying detail.
"""
path = UPath(path)
try:
text = path.read_text(encoding="utf-8")
except OSError as e:
raise ValueError(f"Could not read YAML config {path}: {e}") from e

try:
data = yaml.safe_load(text)
except yaml.YAMLError as e:
raise ValueError(f"Could not parse YAML config {path}: {e}") from e

try:
return model_cls.model_validate(data)
except pydantic.ValidationError as e:
raise ValueError(f"Config validation failed for {path}:\n{e}") from e


# Arrow struct field names for the serialized config.
_MODEL_FIELD = "__pydantic_model__" # fully-qualified "module:QualName"
_JSON_FIELD = "__pydantic_json__" # canonical JSON of the model


def _qualified_name(cls: type) -> str:
return f"{cls.__module__}:{cls.__qualname__}"


def _import_model(qualified_name: str) -> type[pydantic.BaseModel]:
module_path, _, qualname = qualified_name.partition(":")
try:
module = importlib.import_module(module_path)
except ImportError as e:
raise ImportError(
f"Cannot import module '{module_path}' for pydantic model "
f"'{qualified_name}': {e}"
) from e
obj: Any = module
for part in qualname.split("."):
try:
obj = getattr(obj, part)
except AttributeError as e:
raise ImportError(
f"Cannot resolve '{part}' in '{qualified_name}': {e}"
) from e
return obj


class PydanticModelConverter(SemanticStructConverterBase):
"""Semantic-type converter for pydantic models.

Maps any `pydantic.BaseModel` instance to an Arrow struct holding the
model's fully-qualified class name and its canonical JSON, and back. Content
is hashed over (class name + sorted-key canonical JSON), so identity tracks
the config's meaning rather than source-file formatting or dict key order.
Modeled on `PythonPathStructConverter`.
"""

def __init__(self) -> None:
super().__init__("pydantic")
self._arrow_struct_type = pa.struct(
[
pa.field(_MODEL_FIELD, pa.large_string()),
pa.field(_JSON_FIELD, pa.large_string()),
]
)

@property
def python_type(self) -> type:
return pydantic.BaseModel

@property
def arrow_struct_type(self) -> "pa.StructType":
return self._arrow_struct_type

def can_handle_python_type(self, python_type: type) -> bool:
return isinstance(python_type, type) and issubclass(
python_type, pydantic.BaseModel
)

def can_handle_struct_type(self, struct_type: Any) -> bool:
if not pa.types.is_struct(struct_type):
return False
for field in self._arrow_struct_type:
if (
field.name not in struct_type.names
or struct_type[field.name].type != field.type
):
return False
return True

def is_semantic_struct(self, struct_dict: dict[str, Any]) -> bool:
return (
set(struct_dict.keys()) == {_MODEL_FIELD, _JSON_FIELD}
and isinstance(struct_dict[_MODEL_FIELD], str)
and isinstance(struct_dict[_JSON_FIELD], str)
)

def python_to_struct_dict(self, value: Any) -> dict[str, Any]:
if not isinstance(value, pydantic.BaseModel):
raise TypeError(f"Expected a pydantic BaseModel, got {type(value)}")
return {
_MODEL_FIELD: _qualified_name(type(value)),
# model_dump_json() serialises fields in definition order (pydantic v2),
# so equal models produce identical JSON -> stable content hash.
_JSON_FIELD: value.model_dump_json(),
}

def struct_dict_to_python(self, struct_dict: dict[str, Any]) -> Any:
qualified_name = struct_dict.get(_MODEL_FIELD)
json_str = struct_dict.get(_JSON_FIELD)
if qualified_name is None or json_str is None:
raise ValueError(
f"Missing '{_MODEL_FIELD}'/'{_JSON_FIELD}' in struct dict"
)
model_cls = _import_model(qualified_name)
return model_cls.model_validate_json(json_str)

def hash_struct_dict(
self, struct_dict: dict[str, Any], add_prefix: bool = False
) -> str:
qualified_name = struct_dict.get(_MODEL_FIELD)
json_str = struct_dict.get(_JSON_FIELD)
if qualified_name is None or json_str is None:
raise ValueError(
f"Missing '{_MODEL_FIELD}'/'{_JSON_FIELD}' in struct dict"
)
# Canonicalize (sorted keys) so semantically-equal configs that differ only
# in dict key order hash equal -- identity tracks meaning, not formatting.
canonical_json = json.dumps(
json.loads(json_str), sort_keys=True, separators=(",", ":")
)
content = f"{qualified_name}\n{canonical_json}".encode("utf-8")
content_hash = self._compute_content_hash(content)
return self._format_hash_string(content_hash.digest, add_prefix=add_prefix)
Loading
Loading