feat: add model serialization API with version metadata (#892) #929

Open: genrichez wants to merge 1 commit into uber:master from genrichez:feature/model-serialization

@genrichez

Closes #892.

Summary

Adds save() and load() methods to all causal meta-learners (T, S, X, R, DR) via a SerializableLearner mixin on BaseLearner. Models are persisted with joblib together with metadata (causalml version, Python version, class name, timestamp), so that a stale or mismatched model can be detected at load time instead of running silently in production.
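To illustrate the idea of a metadata wrapper around the persisted model, here is a minimal sketch. It is not the code in serialization.py: it uses the standard-library pickle in place of joblib so it runs without extra dependencies, it stores the learner's attribute dict rather than the object itself, and the names SketchSerializable, ToyLearner and LIB_VERSION are hypothetical.

```python
import pickle
import platform
import tempfile
from datetime import datetime, timezone
from pathlib import Path

LIB_VERSION = "0.0.0-sketch"  # hypothetical stand-in for the causalml version


class SketchSerializable:
    """Hypothetical mixin: writes the learner state inside a metadata envelope."""

    def save(self, path):
        envelope = {
            "metadata": {
                "library_version": LIB_VERSION,
                "python_version": platform.python_version(),
                "class_name": type(self).__name__,
                "saved_at": datetime.now(timezone.utc).isoformat(),
            },
            # Assumption: the attribute dict is enough to rebuild the learner.
            "state": dict(self.__dict__),
        }
        with open(path, "wb") as fh:
            pickle.dump(envelope, fh)

    @classmethod
    def load(cls, path):
        with open(path, "rb") as fh:
            envelope = pickle.load(fh)
        learner = cls.__new__(cls)  # rebuild without calling __init__
        learner.__dict__.update(envelope["state"])
        return learner, envelope["metadata"]


class ToyLearner(SketchSerializable):
    """Toy class standing in for a fitted meta-learner."""

    def __init__(self, coef):
        self.coef = coef


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "model.sketch"
    ToyLearner(coef=1.5).save(target)
    restored, metadata = ToyLearner.load(target)
```

Keeping the metadata next to the model in one file means the loader can inspect the class name and versions before the caller ever uses the restored learner.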

Usage

from xgboost import XGBRegressor

from causalml.inference.meta import BaseTRegressor, load_learner

# Train
learner = BaseTRegressor(learner=XGBRegressor())
learner.fit(X=X, treatment=treatment, y=y)

# Save
learner.save("model.causalml")

# Load (type-safe, checks class matches)
loaded = BaseTRegressor.load("model.causalml")

# Load (generic, no class check)
loaded = load_learner("model.causalml")

Safety features

  • Version mismatch warning: if the model was saved with a different causalml version, a CausalMLVersionMismatchWarning is issued on load
  • Class mismatch error: loading a T-learner file as an S-learner raises ValueError
  • Unfitted guard: calling save() before fit() raises ValueError
  • Missing file: loading from a nonexistent path raises FileNotFoundError
  • Backwards compatible: loading a raw joblib dump (no metadata wrapper) works with a warning
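The load-time checks in the list above could be sketched roughly as follows. This is an illustration only, not the PR's implementation: it works on an already-loaded payload, the envelope layout (a dict with "metadata" and "model" keys) is an assumption, and unwrap, VersionMismatchWarning and CURRENT_VERSION are hypothetical names.

```python
import warnings

CURRENT_VERSION = "0.2.0"  # hypothetical stand-in for the installed causalml version


class VersionMismatchWarning(UserWarning):
    """Hypothetical analogue of CausalMLVersionMismatchWarning."""


def unwrap(payload, expected_class=None):
    """Return the model from a loaded payload after applying the checks."""
    is_envelope = (
        isinstance(payload, dict) and "metadata" in payload and "model" in payload
    )
    if not is_envelope:
        # Raw dump without a metadata wrapper: accept it, but warn.
        warnings.warn("no metadata wrapper found; checks skipped", UserWarning)
        return payload

    meta = payload["metadata"]
    # Class check applies only to the type-safe path (expected_class given).
    if expected_class is not None and meta["class_name"] != expected_class:
        raise ValueError(
            f"file holds a {meta['class_name']}, expected {expected_class}"
        )
    if meta["library_version"] != CURRENT_VERSION:
        warnings.warn(
            f"saved with version {meta['library_version']}, "
            f"running {CURRENT_VERSION}",
            VersionMismatchWarning,
        )
    return payload["model"]


saved = {
    "metadata": {"class_name": "BaseTRegressor", "library_version": "0.1.0"},
    "model": "fitted-model-placeholder",
}

# Type-safe load with an older version: succeeds, but records a warning.
with warnings.catch_warnings(record=True) as version_warnings:
    warnings.simplefilter("always")
    model = unwrap(saved, expected_class="BaseTRegressor")

# Loading a T-learner file as an S-learner: rejected.
try:
    unwrap(saved, expected_class="BaseSRegressor")
    class_mismatch_rejected = False
except ValueError:
    class_mismatch_rejected = True

# Raw payload without a wrapper: returned as-is with a warning.
with warnings.catch_warnings(record=True) as raw_warnings:
    warnings.simplefilter("always")
    raw_model = unwrap(["raw", "dump"])
```

Passing expected_class=None mirrors the generic load path, which skips the class check but still compares versions.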

Files changed

  • causalml/inference/meta/serialization.py (new) - the SerializableLearner mixin
  • causalml/inference/meta/base.py - added mixin to BaseLearner inheritance
  • causalml/inference/meta/__init__.py - exported load_learner and CausalMLVersionMismatchWarning
  • tests/test_serialization.py (new) - 19 unit tests

Testing

  • 19 unit tests covering round-trips for all learner types, safety checks, metadata validation, and edge cases
  • Verified existing meta-learner tests pass without modification
  • Tested with production scenarios: train-serve separation, champion-challenger deployment, model versioning for drift detection, compliance audit trail extraction, and checkpoint rollback
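For readers who want a feel for what the round-trip and guard tests might look like, here is a self-contained sketch against a toy class. It does not reproduce tests/test_serialization.py: ToyLearner, its is_fitted flag and the expect_raises helper are hypothetical, and pickle from the standard library stands in for joblib.

```python
import pickle
import tempfile
from pathlib import Path


class ToyLearner:
    """Hypothetical stand-in for a meta-learner; is_fitted is an assumed flag."""

    def __init__(self):
        self.is_fitted = False
        self.effect = None

    def fit(self, outcomes):
        self.effect = sum(outcomes) / len(outcomes)
        self.is_fitted = True
        return self

    def save(self, path):
        if not self.is_fitted:
            raise ValueError("cannot save an unfitted learner; call fit() first")
        with open(path, "wb") as fh:
            pickle.dump({"effect": self.effect}, fh)

    @classmethod
    def load(cls, path):
        path = Path(path)
        if not path.is_file():
            raise FileNotFoundError(f"no saved model at {path}")
        with open(path, "rb") as fh:
            state = pickle.load(fh)
        learner = cls()
        learner.effect = state["effect"]
        learner.is_fitted = True
        return learner


def expect_raises(exc_type, func, *args):
    """Return True if func(*args) raises exc_type, False if it returns normally."""
    try:
        func(*args)
    except exc_type:
        return True
    return False


with tempfile.TemporaryDirectory() as tmp:
    tmp_dir = Path(tmp)

    # Round-trip: a fitted learner survives save() followed by load().
    ToyLearner().fit([1.0, 2.0, 3.0]).save(tmp_dir / "fitted.sketch")
    round_trip_effect = ToyLearner.load(tmp_dir / "fitted.sketch").effect

    # Unfitted guard: save() before fit() is rejected.
    unfitted_rejected = expect_raises(
        ValueError, ToyLearner().save, tmp_dir / "unfitted.sketch"
    )

    # Missing file: load() on a path that was never written is rejected.
    missing_rejected = expect_raises(
        FileNotFoundError, ToyLearner.load, tmp_dir / "absent.sketch"
    )
```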

Adds save()/load() methods to all meta-learners via a SerializableLearner
mixin on BaseLearner. Models are persisted with joblib and include metadata
(causalml version, python version, class name, timestamp) to prevent stale
model execution in production.

API:
  learner.save('model.causalml')
  loaded = BaseTRegressor.load('model.causalml')
  loaded = load_learner('model.causalml')  # generic, no class check

Safety:
  - Version mismatch warning on load
  - Class mismatch raises ValueError
  - Unfitted model save raises ValueError
  - Missing file raises FileNotFoundError
  - Raw joblib fallback with warning

Signed-off-by: genrichez <genrichez@users.noreply.github.com>
@CLAassistant

CLAassistant commented Jul 4, 2026

CLA assistant check
All committers have signed the CLA.

Development

Successfully merging this pull request may close these issues.

[Enhancement] Model Serialization & Persistence for Trained Learners via save/load API