Scoring rules for filters#259
Conversation
There was a problem hiding this comment.
Pull request overview
Adds first-class support for computing and (optionally) recording proper scoring rules for one-step-ahead predicted observation distributions produced by continuous-time CD-Dynamax Gaussian filters, including integration into the Filter handler, tests, and documentation updates.
Changes:
- Introduces
dynestyx.inference.scoring(score definitions +ObservationScoringConfig) anddynestyx.inference.observation_predictions(backend-to-canonical prediction enrichment + trace recording). - Wires scoring/enrichment through the continuous-time CD-Dynamax filter path and the
Filterhandler (including batched/plate execution). - Adds a comprehensive test suite for scoring/recording behavior and updates tutorials/API docs navigation to include the new scoring topic.
Reviewed changes
Copilot reviewed 15 out of 16 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/test_filter_scoring.py | New tests covering scoring rule outputs and trace recording behavior across continuous-time filter configs. |
| pyproject.toml | Switches cd-dynamax dependency to a Git branch reference needed for the new backend outputs. |
| mkdocs.yml | Adds tutorial + API nav entries for observation scoring and related missing-observations tutorials. |
| dynestyx/inference/scoring.py | New scoring-rule implementations and scoring configuration dataclass. |
| dynestyx/inference/observation_predictions.py | New canonicalization/enrichment layer to derive prediction summaries and scores from backend outputs and record them into the trace. |
| dynestyx/inference/integrations/cd_dynamax/continuous_filter.py | Plumbs scoring_config into the continuous-time CD-Dynamax filter run and records prediction/score sites. |
| dynestyx/inference/filters.py | Adds scoring_config to the Filter handler and enforces current support constraints; wires scoring through continuous-time paths and plate/batched execution. |
| dynestyx/inference/filter_configs.py | Adds record_predicted_observations_* fields to filter configs and includes them in recording kwargs. |
| dynestyx/inference/init.py | Exposes the new scoring module at the package level. |
| docs/tutorials/gentle_intro/11c_missing_observations_hmms.ipynb | Updates tutorial “Next” navigation to point to the new scoring tutorial. |
| docs/tutorials/gentle_intro/00_index.ipynb | Adds the Part 12 scoring tutorial to the gentle intro index. |
| docs/api_reference/public/inference/filters.md | Documents Filter scoring support and links to the Scoring page. |
| docs/api_reference/public/inference/filter_configs.md | Mentions predicted-observation recording fields and links to the Scoring page. |
| docs/api_reference/developer/inference/filters.md | Developer-facing note about scoring entry point and where backend translation lives. |
| docs/api_reference/developer/inference/filter_configs.md | Developer-facing note about predicted-observation recording fields and scoring linkage. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if scoring_config.target == "latent_predictive": | ||
| if scoring_config.sample_source in {"auto", "backend_ensemble"}: | ||
| return predictions.ensemble | ||
| if scoring_config.sample_source == "latent_ensemble_plus_noise": | ||
| if predictions.ensemble is None or predictions.noise_cov is None: | ||
| raise NotImplementedError( | ||
| "Sampling a data-predictive ensemble from a latent ensemble " | ||
| "requires both a latent predictive ensemble and observation " | ||
| "noise covariance." | ||
| ) | ||
| return _sample_data_predictive_ensemble( | ||
| predictions.ensemble, | ||
| predictions.noise_cov, | ||
| sample_seed=scoring_config.sample_seed, | ||
| ) | ||
| raise NotImplementedError( | ||
| f"Unsupported scoring sample source: {scoring_config.sample_source}." | ||
| ) |
| "effectful>=0.2.0", | ||
| "cuthbert>=0.0.10", | ||
| "cuthbertlib>=0.0.10", | ||
| "cd-dynamax>=0.3.3", | ||
| "cd-dynamax @ git+https://github.com/hd-UQ/cd_dynamax.git@ml-return-ypreds", | ||
| "matplotlib>=3.10.7", |
| pairwise = pred_ensemble[..., :, None, :] - pred_ensemble[..., None, :, :] | ||
| second_term = 0.5 * jnp.mean( | ||
| jnp.linalg.norm(pairwise, axis=-1) ** self.beta, | ||
| axis=(-2, -1), | ||
| ) |
| record_predicted_observations_mean (bool): Save the predicted | ||
| observation mean at each observation time, before conditioning on | ||
| that observation. Defaults to `False`, and scoring can be used | ||
| without automatically recording predictive summaries. | ||
| record_predicted_observations_cov (bool): Save the predicted | ||
| observation covariance at each observation time, before | ||
| conditioning on that observation. Defaults to `False`. | ||
| record_predicted_observations_ensemble (bool): Save the | ||
| predicted observation ensemble at each observation time | ||
| (ensemble-based filters only). Defaults to `False`. |
…tions, address copilot comments
| assert predictions.mean is not None | ||
| if predictions.obs_cov is None: | ||
| raise NotImplementedError( | ||
| "Observation scoring requires predictive observation covariance." | ||
| ) | ||
| return ( | ||
| predictions.mean, | ||
| predictions.obs_cov, | ||
| _select_scoring_ensemble( | ||
| predictions, | ||
| scoring_config=scoring_config, | ||
| ), | ||
| ) |
| "cuthbert>=0.0.10", | ||
| "cuthbertlib>=0.0.10", | ||
| "cd-dynamax>=0.3.3", | ||
| "cd-dynamax @ git+https://github.com/hd-UQ/cd_dynamax.git@ml-return-ypreds", |
| t_len = _time_len_from_array(obs_times, plate_shapes) | ||
| state_shape = (*plate_shapes, dynamics.state_dim) | ||
| x_probe = jnp.zeros(state_shape, dtype=jnp.asarray(obs_times).dtype) | ||
| covs = [] | ||
| for t_idx in range(t_len): | ||
| t = _slice_time_axis(obs_times, t_idx, plate_shapes) | ||
| u_t = ( |
| t_len = _time_len_from_array(obs_times, plate_shapes) | ||
| state_shape = (*plate_shapes, dynamics.state_dim) | ||
| x_probe = jnp.zeros(state_shape, dtype=jnp.asarray(obs_times).dtype) |
| "effectful>=0.2.0", | ||
| "cuthbert>=0.0.10", | ||
| "cuthbertlib>=0.0.10", | ||
| "cd-dynamax>=0.3.3", | ||
| "cd-dynamax @ git+https://github.com/hd-UQ/cd_dynamax.git@ml-return-ypreds", | ||
| "matplotlib>=3.10.7", |
| sample_source: Strategy for obtaining predictive observation | ||
| ensembles when a rule needs samples. `"auto"` prefers a | ||
| backend-provided predictive observation ensemble, then falls back | ||
| to adding observation noise to a latent predictive ensemble, and | ||
| finally to Gaussian moments if the rule supports that path. | ||
| sample_seed: PRNG seed used when Dynestyx needs to synthesize | ||
| predictive ensembles from moments or latent ensembles plus noise. |
| "effectful>=0.2.0", | ||
| "cuthbert>=0.0.10", | ||
| "cuthbertlib>=0.0.10", | ||
| "cd-dynamax>=0.3.3", | ||
| "cd-dynamax @ git+https://github.com/hd-UQ/cd_dynamax.git@0fd1bbf9dba5154af70d9dae9b925e572c023368", | ||
| "matplotlib>=3.10.7", |
| posterior, predictions, score_arrays = enrich_continuous_filter_output( | ||
| posterior, | ||
| dynamics=dynamics, | ||
| filter_config=filter_config, | ||
| obs_times=obs_times, | ||
| obs_values=obs_values, | ||
| ctrl_values=ctrl_values, | ||
| scoring_config=scoring_config, | ||
| plate_shapes=plate_shapes, | ||
| ) |
| "cuthbert>=0.0.10", | ||
| "cuthbertlib>=0.0.10", | ||
| "cd-dynamax>=0.3.3", | ||
| "cd-dynamax @ git+https://github.com/hd-UQ/cd_dynamax.git@0fd1bbf9dba5154af70d9dae9b925e572c023368", |
There was a problem hiding this comment.
Will take a closer look later, but we currently have dynestyx/diagnostics --- I think we should consider renaming dynestyx/diagnostics to dynestyx/evaluation and putting this there.
There was a problem hiding this comment.
- Agreed with moving out of inference and into
dynestyx/diagnostics; I'm thinking we pull outObservationScoringConfigand put it into a newinference/scoring_configs.py. Would be a short file, but mimics filter_configs and smoother_configs. - I'm open to renaming to
evaluation, but it does create solid churn across notebooks etc. Worth it?
There was a problem hiding this comment.
Agreed with moving out of inference and into dynestyx/diagnostics; I'm thinking we pull out ObservationScoringConfig and put it into a new inference/scoring_configs.py. Would be a short file, but mimics filter_configs and smoother_configs.
Great! I can picture it being not-so-short down the line, anyways.
I'm open to renaming to evaluation, but it does create solid churn across notebooks etc. Worth it?
I think so! But I'm biased. Maybe we should focus-group it...
There was a problem hiding this comment.
Okay sounds good! I'll change the name, I like evaluation a bit better too...just wanted to keep the PR from changing too many files. But if you like it let's do it.
Summary
This PR adds first-class support for observation scoring on filter-predicted observation distributions, focused on the continuous-time
cd_dynamaxGaussian filter path. Users can now attach ascoring_configtoFilter(...)to compute proper scoring rules at each observation time, while keeping predictive-summary recording as a separate concern.Dependency
Relies on CD Dynamax PR.
What’s included
dynestyx/inference/scoring.py:ObservationScoringConfigGaussianLogProbScoreDawidSebastianiScoreObservationWiseCRPSScoreEnergyScoredynestyx/inference/observation_predictions.pythat:dynestyx/inference/filters.pydynestyx/inference/integrations/cd_dynamax/continuous_filter.pyBaseFilterConfigoptions indynestyx/inference/filter_configs.pyfor predicted-observation recording:record_predicted_observations_meanrecord_predicted_observations_covrecord_predicted_observations_ensembleBehavior / API notes
Filter(...)handler path, scoring only does work when the score arrays will actually be surfaced as NumPyro sites. Ifrecord_as_numpyro_sites=False, the handler path skips score computation entirely.ObservationScoringConfig.sample_seedis now the single scoring-level seed for any synthetic predictive sampling performed by Dynestyx during scoring, including:EnergyScoreGaussianLogProbScore,DawidSebastianiScore,ObservationWiseCRPSScore) do not depend on ensemble availability orsample_source.EnergyScore) usesample_sourceto choose between:Performance / implementation details
LinearGaussianObservationorGaussianObservationRis non-callablejax.lax.maprather than a Python loop.EnergyScoresupports both:vectorized_pairwise=FalseDocs and tutorials
docs/api_reference/public/inference/scoring.mddocs/api_reference/developer/inference/scoring.mddocs/tutorials/gentle_intro/12_observation_scoring_with_filters.ipynbTesting
tests/test_filter_scoring.pyEnergyScoreDependency / build notes
cd-dynamaxis pinned to the required upstream commit SHA inpyproject.tomlwhile waiting on the upstream release path.