Scoring rules for filters #259

Open
mattlevine22 wants to merge 12 commits into main from ml-filter-scoring

mattlevine22 commented Jun 15, 2026

Collaborator

Summary

This PR adds first-class support for observation scoring on filter-predicted observation distributions, focused on the continuous-time cd_dynamax Gaussian filter path. Users can now attach a scoring_config to Filter(...) to compute proper scoring rules at each observation time, while keeping predictive-summary recording as a separate concern.

Dependency

Relies on CD Dynamax PR.

What’s included

Behavior / API notes

  • Scoring is defined on the one-step-ahead predictive observation distribution, not on the filtered state posterior.
  • Scoring and predicted-observation recording are separate:
    • users can score without recording predicted means/covariances/ensembles
    • users can record predicted means/covariances/ensembles without scoring
  • In the Filter(...) handler path, scores are computed only when the score arrays will actually be surfaced as NumPyro sites. If record_as_numpyro_sites=False, the handler path skips score computation entirely.
  • ObservationScoringConfig.sample_seed is now the single scoring-level seed for any synthetic predictive sampling performed by Dynestyx during scoring, including:
    • adding observation noise to a latent predictive ensemble
    • drawing predictive observation samples from Gaussian moments for EnergyScore
  • Rules that only need predictive moments (GaussianLogProbScore, DawidSebastianiScore, ObservationWiseCRPSScore) do not depend on ensemble availability or sample_source.
  • Rules that need predictive samples (EnergyScore) use sample_source to choose between:
    • a backend-provided predictive observation ensemble
    • a predictive observation ensemble synthesized by adding observation noise to a latent ensemble
    • predictive samples drawn from Gaussian predictive moments
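
For reviewers less familiar with the moment-based rules, here is a minimal JAX sketch of what such scores typically compute from a predictive mean and covariance. It uses the standard textbook formulas; the function names, the sign conventions, and the per-dimension reading of ObservationWiseCRPSScore are assumptions, not the Dynestyx implementation.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def gaussian_log_prob(y, mean, cov):
    """Log density of y under N(mean, cov). Higher is better."""
    d = y.shape[-1]
    resid = y - mean
    chol = jnp.linalg.cholesky(cov)
    # Solve L z = resid, so that sum(z**2) is the squared Mahalanobis distance.
    z = jnp.linalg.solve(chol, resid[..., None])[..., 0]
    maha = jnp.sum(z**2, axis=-1)
    logdet = 2.0 * jnp.sum(
        jnp.log(jnp.diagonal(chol, axis1=-2, axis2=-1)), axis=-1
    )
    return -0.5 * (d * jnp.log(2.0 * jnp.pi) + logdet + maha)


def dawid_sebastiani(y, mean, cov):
    """log det(cov) + squared Mahalanobis distance. Lower is better."""
    d = y.shape[-1]
    return -2.0 * gaussian_log_prob(y, mean, cov) - d * jnp.log(2.0 * jnp.pi)


def gaussian_crps_per_dim(y, mean, cov):
    """Closed-form CRPS of each Gaussian marginal. Lower is better."""
    sigma = jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1))
    z = (y - mean) / sigma
    return sigma * (
        z * (2.0 * norm.cdf(z) - 1.0)
        + 2.0 * norm.pdf(z)
        - 1.0 / jnp.sqrt(jnp.pi)
    )


# Toy check: a standard-normal prediction scored at its own mean.
y = jnp.zeros(2)
mean = jnp.zeros(2)
cov = jnp.eye(2)
lp = gaussian_log_prob(y, mean, cov)
ds = dawid_sebastiani(y, mean, cov)
crps = gaussian_crps_per_dim(y, mean, cov)
```

All three functions consume only a mean and a covariance, which is consistent with the note above that these rules do not depend on ensemble availability or sample_source.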

Performance / implementation details

  • Added a structured fast path for fixed observation noise covariances:
    • if the observation model is LinearGaussianObservation or GaussianObservation
    • and R is non-callable
    • the covariance is broadcast across time directly instead of recomputed per observation
  • Fallback per-time covariance construction uses jax.lax.map rather than a Python loop.
  • EnergyScore supports both:
    • fast vectorized pairwise computation
    • lower-memory scan-based pairwise computation via vectorized_pairwise=False
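
The covariance fast path and its fallback can be pictured roughly as follows. This is an illustrative sketch only: the helper name and the R(t) callable signature are hypothetical, and the real observation models presumably also take state and control inputs.

```python
import jax
import jax.numpy as jnp


def noise_covs_over_time(R, obs_times):
    """Return a (T, d, d) stack of observation noise covariances."""
    t_len = obs_times.shape[0]
    if not callable(R):
        # Fast path: a fixed R is broadcast across time, not recomputed.
        R = jnp.asarray(R)
        return jnp.broadcast_to(R, (t_len, *R.shape))
    # Fallback: evaluate a time-dependent R(t) once per observation time,
    # using jax.lax.map instead of a Python loop.
    return jax.lax.map(R, obs_times)


obs_times = jnp.linspace(0.0, 1.0, 5)
fixed = noise_covs_over_time(0.1 * jnp.eye(2), obs_times)
varying = noise_covs_over_time(lambda t: (1.0 + t) * jnp.eye(2), obs_times)
```

Both branches return the same (T, d, d) layout, so downstream scoring code does not need to know which path produced the covariances.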

Docs and tutorials

Testing

  • Added a dedicated scoring test suite in tests/test_filter_scoring.py
  • Coverage includes:
    • score-site correctness against backend outputs
    • predicted-observation recording correctness
    • Gaussian vs. ensemble-based scoring paths
    • unsupported/skip behavior
    • synthetic sampling from Gaussian moments
    • backend observation ensemble precedence
    • fast-path/fallback observation covariance behavior
    • vectorized vs. scan equivalence for EnergyScore

Dependency / build notes

  • cd-dynamax is pinned to the required upstream commit SHA in pyproject.toml while waiting on the upstream release path.

Copilot AI left a comment

Contributor

Pull request overview

Adds first-class support for computing and (optionally) recording proper scoring rules for one-step-ahead predicted observation distributions produced by continuous-time CD-Dynamax Gaussian filters, including integration into the Filter handler, tests, and documentation updates.

Changes:

  • Introduces dynestyx.inference.scoring (score definitions + ObservationScoringConfig) and dynestyx.inference.observation_predictions (backend-to-canonical prediction enrichment + trace recording).
  • Wires scoring/enrichment through the continuous-time CD-Dynamax filter path and the Filter handler (including batched/plate execution).
  • Adds a comprehensive test suite for scoring/recording behavior and updates tutorials/API docs navigation to include the new scoring topic.

Reviewed changes

Copilot reviewed 15 out of 16 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| `tests/test_filter_scoring.py` | New tests covering scoring rule outputs and trace recording behavior across continuous-time filter configs. |
| `pyproject.toml` | Switches cd-dynamax dependency to a Git branch reference needed for the new backend outputs. |
| `mkdocs.yml` | Adds tutorial + API nav entries for observation scoring and related missing-observations tutorials. |
| `dynestyx/inference/scoring.py` | New scoring-rule implementations and scoring configuration dataclass. |
| `dynestyx/inference/observation_predictions.py` | New canonicalization/enrichment layer to derive prediction summaries and scores from backend outputs and record them into the trace. |
| `dynestyx/inference/integrations/cd_dynamax/continuous_filter.py` | Plumbs scoring_config into the continuous-time CD-Dynamax filter run and records prediction/score sites. |
| `dynestyx/inference/filters.py` | Adds scoring_config to the Filter handler and enforces current support constraints; wires scoring through continuous-time paths and plate/batched execution. |
| `dynestyx/inference/filter_configs.py` | Adds record_predicted_observations_* fields to filter configs and includes them in recording kwargs. |
| `dynestyx/inference/__init__.py` | Exposes the new scoring module at the package level. |
| `docs/tutorials/gentle_intro/11c_missing_observations_hmms.ipynb` | Updates tutorial “Next” navigation to point to the new scoring tutorial. |
| `docs/tutorials/gentle_intro/00_index.ipynb` | Adds the Part 12 scoring tutorial to the gentle intro index. |
| `docs/api_reference/public/inference/filters.md` | Documents Filter scoring support and links to the Scoring page. |
| `docs/api_reference/public/inference/filter_configs.md` | Mentions predicted-observation recording fields and links to the Scoring page. |
| `docs/api_reference/developer/inference/filters.md` | Developer-facing note about scoring entry point and where backend translation lives. |
| `docs/api_reference/developer/inference/filter_configs.md` | Developer-facing note about predicted-observation recording fields and scoring linkage. |

Comment on lines +421 to +438
if scoring_config.target == "latent_predictive":
    if scoring_config.sample_source in {"auto", "backend_ensemble"}:
        return predictions.ensemble
    if scoring_config.sample_source == "latent_ensemble_plus_noise":
        if predictions.ensemble is None or predictions.noise_cov is None:
            raise NotImplementedError(
                "Sampling a data-predictive ensemble from a latent ensemble "
                "requires both a latent predictive ensemble and observation "
                "noise covariance."
            )
        return _sample_data_predictive_ensemble(
            predictions.ensemble,
            predictions.noise_cov,
            sample_seed=scoring_config.sample_seed,
        )
    raise NotImplementedError(
        f"Unsupported scoring sample source: {scoring_config.sample_source}."
    )
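
As a rough illustration of what adding observation noise to a latent ensemble could look like: the helper name, the (T, N, d) layout, and the seed handling below are assumptions made for this sketch, and the actual _sample_data_predictive_ensemble may differ.

```python
import jax
import jax.numpy as jnp


def add_observation_noise(latent_ensemble, noise_cov, sample_seed=0):
    """Turn noise-free predicted observations into predictive samples.

    latent_ensemble: (T, N, d) ensemble of noise-free predicted observations.
    noise_cov:       (T, d, d) observation noise covariance per time step.
    """
    key = jax.random.PRNGKey(sample_seed)
    chol = jnp.linalg.cholesky(noise_cov)  # (T, d, d)
    eps = jax.random.normal(
        key, latent_ensemble.shape, dtype=latent_ensemble.dtype
    )
    # Colour the standard-normal draws: noise[t, n] = chol[t] @ eps[t, n].
    noise = jnp.einsum("tij,tnj->tni", chol, eps)
    return latent_ensemble + noise


latent = jnp.zeros((3, 4000, 2))
noise_cov = jnp.broadcast_to(0.04 * jnp.eye(2), (3, 2, 2))
noisy = add_observation_noise(latent, noise_cov, sample_seed=0)
```

Fixing the seed makes the synthesized ensemble reproducible across calls, which is the role sample_seed plays in the PR description.
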
Comment thread mkdocs.yml
Comment thread pyproject.toml
Comment on lines 33 to 37
  "effectful>=0.2.0",
  "cuthbert>=0.0.10",
  "cuthbertlib>=0.0.10",
- "cd-dynamax>=0.3.3",
+ "cd-dynamax @ git+https://github.com/hd-UQ/cd_dynamax.git@ml-return-ypreds",
  "matplotlib>=3.10.7",
Comment thread dynestyx/inference/scoring.py Outdated
Comment on lines +248 to +252
pairwise = pred_ensemble[..., :, None, :] - pred_ensemble[..., None, :, :]
second_term = 0.5 * jnp.mean(
    jnp.linalg.norm(pairwise, axis=-1) ** self.beta,
    axis=(-2, -1),
)
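
For context, the quoted term is the pairwise half of the energy score, ES = E‖X − y‖^β − ½ E‖X − X′‖^β. Below is a hedged sketch of a vectorized and a scan-based estimator for a single observation time; the function names are illustrative and not taken from this PR.

```python
import jax
import jax.numpy as jnp


def energy_score_vectorized(ens, y, beta=1.0):
    """ens: (N, d) predictive samples; y: (d,) observation. Lower is better."""
    first = jnp.mean(jnp.linalg.norm(ens - y, axis=-1) ** beta)
    # All N * N pairwise differences are materialized at once: O(N^2 * d) memory.
    pairwise = ens[:, None, :] - ens[None, :, :]
    second = 0.5 * jnp.mean(jnp.linalg.norm(pairwise, axis=-1) ** beta)
    return first - second


def energy_score_scan(ens, y, beta=1.0):
    """Same estimator, accumulating one row of the pairwise matrix at a time."""
    n = ens.shape[0]
    first = jnp.mean(jnp.linalg.norm(ens - y, axis=-1) ** beta)

    def add_row(total, x_i):
        # Distances from one member to every member: O(N * d) memory.
        row = jnp.linalg.norm(x_i - ens, axis=-1) ** beta
        return total + jnp.sum(row), None

    total, _ = jax.lax.scan(add_row, jnp.zeros((), dtype=ens.dtype), ens)
    return first - 0.5 * total / (n * n)


# Toy case: members at 0 and 2, observation at 1.
toy = jnp.array([[0.0], [2.0]])
es_vec = energy_score_vectorized(toy, jnp.array([1.0]))
es_scan = energy_score_scan(toy, jnp.array([1.0]))
```

Both versions average over all N² ordered pairs, including the zero self-distances, which matches the mean over the last two axes in the quoted snippet. The scan variant trades speed for peak memory.
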
Comment thread dynestyx/inference/filter_configs.py Outdated
Comment on lines +51 to +60
record_predicted_observations_mean (bool): Save the predicted
    observation mean at each observation time, before conditioning on
    that observation. Defaults to `False`, and scoring can be used
    without automatically recording predictive summaries.
record_predicted_observations_cov (bool): Save the predicted
    observation covariance at each observation time, before
    conditioning on that observation. Defaults to `False`.
record_predicted_observations_ensemble (bool): Save the
    predicted observation ensemble at each observation time
    (ensemble-based filters only). Defaults to `False`.

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 15 out of 16 changed files in this pull request and generated 5 comments.

Comment on lines +389 to +401
assert predictions.mean is not None
if predictions.obs_cov is None:
    raise NotImplementedError(
        "Observation scoring requires predictive observation covariance."
    )
return (
    predictions.mean,
    predictions.obs_cov,
    _select_scoring_ensemble(
        predictions,
        scoring_config=scoring_config,
    ),
)
Comment thread mkdocs.yml
Comment thread mkdocs.yml
Comment thread pyproject.toml Outdated
  "cuthbert>=0.0.10",
  "cuthbertlib>=0.0.10",
- "cd-dynamax>=0.3.3",
+ "cd-dynamax @ git+https://github.com/hd-UQ/cd_dynamax.git@ml-return-ypreds",
Comment on lines +127 to +133
t_len = _time_len_from_array(obs_times, plate_shapes)
state_shape = (*plate_shapes, dynamics.state_dim)
x_probe = jnp.zeros(state_shape, dtype=jnp.asarray(obs_times).dtype)
covs = []
for t_idx in range(t_len):
    t = _slice_time_axis(obs_times, t_idx, plate_shapes)
    u_t = (

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 17 out of 18 changed files in this pull request and generated 2 comments.

Comment on lines +128 to +130
t_len = _time_len_from_array(obs_times, plate_shapes)
state_shape = (*plate_shapes, dynamics.state_dim)
x_probe = jnp.zeros(state_shape, dtype=jnp.asarray(obs_times).dtype)
Comment thread pyproject.toml
Comment on lines 33 to 37
  "effectful>=0.2.0",
  "cuthbert>=0.0.10",
  "cuthbertlib>=0.0.10",
- "cd-dynamax>=0.3.3",
+ "cd-dynamax @ git+https://github.com/hd-UQ/cd_dynamax.git@ml-return-ypreds",
  "matplotlib>=3.10.7",

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 17 out of 18 changed files in this pull request and generated 2 comments.

Comment thread dynestyx/inference/scoring.py Outdated
Comment on lines +306 to +312
sample_source: Strategy for obtaining predictive observation
    ensembles when a rule needs samples. `"auto"` prefers a
    backend-provided predictive observation ensemble, then falls back
    to adding observation noise to a latent predictive ensemble, and
    finally to Gaussian moments if the rule supports that path.
sample_seed: PRNG seed used when Dynestyx needs to synthesize
    predictive ensembles from moments or latent ensembles plus noise.
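
The precedence described in this docstring can be sketched as a simple fallback chain for one observation time. The function, its arguments, and the returned labels are hypothetical stand-ins for illustration, not Dynestyx option strings or internals.

```python
import jax
import jax.numpy as jnp


def resolve_samples(
    backend_ens, latent_ens, noise_cov, mean, obs_cov,
    n_samples=64, sample_seed=0,
):
    """Pick predictive observation samples in "auto"-style order."""
    key = jax.random.PRNGKey(sample_seed)
    if backend_ens is not None:
        # 1. A backend-provided predictive observation ensemble wins.
        return backend_ens, "backend"
    if latent_ens is not None and noise_cov is not None:
        # 2. Otherwise add observation noise to the latent ensemble (N, d).
        noise = jax.random.multivariate_normal(
            key,
            jnp.zeros(noise_cov.shape[-1]),
            noise_cov,
            shape=latent_ens.shape[:-1],
        )
        return latent_ens + noise, "latent_plus_noise"
    # 3. Finally, draw directly from the Gaussian predictive moments.
    draws = jax.random.multivariate_normal(
        key, mean, obs_cov, shape=(n_samples,)
    )
    return draws, "moments"


mean = jnp.zeros(2)
obs_cov = jnp.eye(2)
from_moments, label_moments = resolve_samples(
    None, None, None, mean, obs_cov, n_samples=16
)
```

Only the second and third branches consume randomness, so a single scoring-level seed is enough to make either synthetic path reproducible.
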
Comment thread pyproject.toml
Comment on lines 33 to 37
  "effectful>=0.2.0",
  "cuthbert>=0.0.10",
  "cuthbertlib>=0.0.10",
- "cd-dynamax>=0.3.3",
+ "cd-dynamax @ git+https://github.com/hd-UQ/cd_dynamax.git@0fd1bbf9dba5154af70d9dae9b925e572c023368",
  "matplotlib>=3.10.7",

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 17 out of 18 changed files in this pull request and generated 2 comments.

Comment on lines +529 to +538
posterior, predictions, score_arrays = enrich_continuous_filter_output(
    posterior,
    dynamics=dynamics,
    filter_config=filter_config,
    obs_times=obs_times,
    obs_values=obs_values,
    ctrl_values=ctrl_values,
    scoring_config=scoring_config,
    plate_shapes=plate_shapes,
)
Comment thread pyproject.toml
  "cuthbert>=0.0.10",
  "cuthbertlib>=0.0.10",
- "cd-dynamax>=0.3.3",
+ "cd-dynamax @ git+https://github.com/hd-UQ/cd_dynamax.git@0fd1bbf9dba5154af70d9dae9b925e572c023368",
mattlevine22 marked this pull request as ready for review June 16, 2026 21:05
mattlevine22 requested a review from LukeSnow0 June 16, 2026 21:08

Collaborator

Will take a closer look later, but we currently have dynestyx/diagnostics --- I think we should consider renaming dynestyx/diagnostics to dynestyx/evaluation and putting this there.

Collaborator Author

  1. Agreed with moving out of inference and into dynestyx/diagnostics; I'm thinking we pull out ObservationScoringConfig and put it into a new inference/scoring_configs.py. Would be a short file, but mimics filter_configs and smoother_configs.
  2. I'm open to renaming to evaluation, but it does create solid churn across notebooks etc. Worth it?

Collaborator

> Agreed with moving out of inference and into dynestyx/diagnostics; I'm thinking we pull out ObservationScoringConfig and put it into a new inference/scoring_configs.py. Would be a short file, but mimics filter_configs and smoother_configs.

Great! I can picture it being not-so-short down the line, anyways.

> I'm open to renaming to evaluation, but it does create solid churn across notebooks etc. Worth it?

I think so! But I'm biased. Maybe we should focus-group it...

Collaborator Author

Okay sounds good! I'll change the name, I like evaluation a bit better too...just wanted to keep the PR from changing too many files. But if you like it let's do it.

3 participants