Skip to content

State-Centered Temporal Processes#828

Open
cdc-mitzimorris wants to merge 79 commits into
mainfrom
mem_810_centered_parameterization
Open

State-Centered Temporal Processes#828
cdc-mitzimorris wants to merge 79 commits into
mainfrom
mem_810_centered_parameterization

Conversation

@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator

Added state-centered parameterizations for all three temporal-process
classes in pyrenew.latent:

  • AR1 — stationary AR(1) on log-Rt levels
  • DifferencedAR1 — AR(1) on first differences of log-Rt (the production
    process)
  • RandomWalk — unconstrained drift on log-Rt

Each class now takes a constructor argument
parameterization: Literal["innovation", "state"], defaulting to
"innovation" to preserve current behavior. Setting "state" switches
the internal sampling from standardized increments to the latent state
path directly.

The state-centered variants are implemented via:

  • For RandomWalk: NumPyro's built-in dist.GaussianRandomWalk, shifted
    by the initial value.
  • For AR1 and DifferencedAR1: two new custom NumPyro Distribution
    subclasses (StateAR1, StateDifferencedAR1) in
    pyrenew/latent/state_centered_distributions.py. Both have vectorized
    log_prob using slice arithmetic (no scan during MCMC) and
    lax.scan-based sample (only called for prior/posterior predictive,
    not on the MCMC gradient path).

Both parameterizations encode the same prior distribution over the
state path. They differ only in sampler geometry — which latent
variables HMC sees and operates on.

Code added

File Type Purpose
pyrenew/latent/state_centered_distributions.py new StateAR1, StateDifferencedAR1
pyrenew/latent/temporal_processes.py modified parameterization flag on all three classes; _prepare_initial_value helper
test/test_temporal_processes.py modified +31 unit tests (parameterization flag, state-centered shape/site/prior-equivalence)
test/test_helpers.py modified fixed_ar1_state, fixed_differenced_ar1_state factories
test/integration/conftest.py modified he_model_state_centered, he_weekly_rt_model_state_centered, he_weekly_model_state_centered fixtures
test/integration/test_population_infections_he_state_centered.py new 5 end-to-end tests, daily Rt
test/integration/test_population_infections_he_weekly_rt_state_centered.py new 5 end-to-end tests, weekly Rt via WeeklyTemporalProcess
_typos.toml modified Whitelist reparametrized_params (NumPyro upstream attribute name)

@codecov
Copy link
Copy Markdown

codecov Bot commented May 19, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 98.71%. Comparing base (0f223f5) to head (543135b).
⚠️ Report is 2 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #828      +/-   ##
==========================================
+ Coverage   98.61%   98.71%   +0.10%     
==========================================
  Files          55       56       +1     
  Lines        2023     2182     +159     
==========================================
+ Hits         1995     2154     +159     
  Misses         28       28              
Flag Coverage Δ
unittests 98.71% <100.00%> (+0.10%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator Author

ran the benchmarks on my machine - here are the results:

time python -m benchmarks.suites.rt_params --candidate he --prior both --repeats 3
rt_params suite: 4 candidate(s) x 2 prior(s) x 3 repeat(s) = 24 fits
>> fitting he_daily_innovation@sd=0.01,ar=0.9 (repeat 1/3) ...
   done he_daily_innovation@sd=0.01,ar=0.9 (repeat 1/3): 62.9s, divergences=0, min ESS/s=0.15
>> fitting he_daily_innovation@sd=0.01,ar=0.9 (repeat 2/3) ...
   done he_daily_innovation@sd=0.01,ar=0.9 (repeat 2/3): 66.4s, divergences=0, min ESS/s=0.25
>> fitting he_daily_innovation@sd=0.01,ar=0.9 (repeat 3/3) ...
   done he_daily_innovation@sd=0.01,ar=0.9 (repeat 3/3): 68.4s, divergences=0, min ESS/s=0.14
>> fitting he_daily_state@sd=0.01,ar=0.9 (repeat 1/3) ...
   done he_daily_state@sd=0.01,ar=0.9 (repeat 1/3): 63.5s, divergences=0, min ESS/s=5.79
>> fitting he_daily_state@sd=0.01,ar=0.9 (repeat 2/3) ...
   done he_daily_state@sd=0.01,ar=0.9 (repeat 2/3): 62.7s, divergences=0, min ESS/s=5.59
>> fitting he_daily_state@sd=0.01,ar=0.9 (repeat 3/3) ...
   done he_daily_state@sd=0.01,ar=0.9 (repeat 3/3): 63.4s, divergences=0, min ESS/s=6.61
>> fitting he_weekly_innovation@sd=0.01,ar=0.9 (repeat 1/3) ...
   done he_weekly_innovation@sd=0.01,ar=0.9 (repeat 1/3): 68.9s, divergences=0, min ESS/s=1.05
>> fitting he_weekly_innovation@sd=0.01,ar=0.9 (repeat 2/3) ...
   done he_weekly_innovation@sd=0.01,ar=0.9 (repeat 2/3): 69.3s, divergences=0, min ESS/s=0.12
>> fitting he_weekly_innovation@sd=0.01,ar=0.9 (repeat 3/3) ...
   done he_weekly_innovation@sd=0.01,ar=0.9 (repeat 3/3): 70.7s, divergences=0, min ESS/s=0.47
>> fitting he_weekly_state@sd=0.01,ar=0.9 (repeat 1/3) ...
   done he_weekly_state@sd=0.01,ar=0.9 (repeat 1/3): 17.9s, divergences=0, min ESS/s=28.92
>> fitting he_weekly_state@sd=0.01,ar=0.9 (repeat 2/3) ...
   done he_weekly_state@sd=0.01,ar=0.9 (repeat 2/3): 16.6s, divergences=0, min ESS/s=30.93
>> fitting he_weekly_state@sd=0.01,ar=0.9 (repeat 3/3) ...
   done he_weekly_state@sd=0.01,ar=0.9 (repeat 3/3): 16.8s, divergences=0, min ESS/s=32.74
>> fitting he_daily_innovation@sd=0.1,ar=0.5 (repeat 1/3) ...
   done he_daily_innovation@sd=0.1,ar=0.5 (repeat 1/3): 79.7s, divergences=0, min ESS/s=0.03
>> fitting he_daily_innovation@sd=0.1,ar=0.5 (repeat 2/3) ...
   done he_daily_innovation@sd=0.1,ar=0.5 (repeat 2/3): 79.4s, divergences=0, min ESS/s=0.03
>> fitting he_daily_innovation@sd=0.1,ar=0.5 (repeat 3/3) ...
   done he_daily_innovation@sd=0.1,ar=0.5 (repeat 3/3): 80.2s, divergences=0, min ESS/s=0.03
>> fitting he_daily_state@sd=0.1,ar=0.5 (repeat 1/3) ...
   done he_daily_state@sd=0.1,ar=0.5 (repeat 1/3): 30.0s, divergences=0, min ESS/s=10.49
>> fitting he_daily_state@sd=0.1,ar=0.5 (repeat 2/3) ...
   done he_daily_state@sd=0.1,ar=0.5 (repeat 2/3): 31.4s, divergences=0, min ESS/s=9.56
>> fitting he_daily_state@sd=0.1,ar=0.5 (repeat 3/3) ...
   done he_daily_state@sd=0.1,ar=0.5 (repeat 3/3): 29.4s, divergences=0, min ESS/s=10.88
>> fitting he_weekly_innovation@sd=0.1,ar=0.5 (repeat 1/3) ...
   done he_weekly_innovation@sd=0.1,ar=0.5 (repeat 1/3): 72.2s, divergences=0, min ESS/s=0.03
>> fitting he_weekly_innovation@sd=0.1,ar=0.5 (repeat 2/3) ...
   done he_weekly_innovation@sd=0.1,ar=0.5 (repeat 2/3): 72.8s, divergences=0, min ESS/s=0.04
>> fitting he_weekly_innovation@sd=0.1,ar=0.5 (repeat 3/3) ...
   done he_weekly_innovation@sd=0.1,ar=0.5 (repeat 3/3): 73.8s, divergences=0, min ESS/s=0.04
>> fitting he_weekly_state@sd=0.1,ar=0.5 (repeat 1/3) ...
   done he_weekly_state@sd=0.1,ar=0.5 (repeat 1/3): 22.6s, divergences=0, min ESS/s=42.72
>> fitting he_weekly_state@sd=0.1,ar=0.5 (repeat 2/3) ...
   done he_weekly_state@sd=0.1,ar=0.5 (repeat 2/3): 22.3s, divergences=0, min ESS/s=44.03
>> fitting he_weekly_state@sd=0.1,ar=0.5 (repeat 3/3) ...
   done he_weekly_state@sd=0.1,ar=0.5 (repeat 3/3): 22.3s, divergences=0, min ESS/s=52.17

--- synthetic_he_weekly_hospital | cadence=daily | innovation_sd=0.01 ---
metric                   innovation        state  state/innov
--------------------------------------------------------------
Wall time (s)                  65.9         63.2        0.96x
ESS/s Rt (median)             0.748       27.329     36.53x *
ESS/s Rt (min)                0.183        5.997     32.78x *
Divergences                       0            0          n/a
Tree depth (mean)             10.00         9.91        0.99x
Tree depth (max)                 10           10        1.00x
E-BFMI (min)                  0.888        0.943      1.06x *
R-hat Rt (max)                1.275        1.006      0.79x *

--- synthetic_he_weekly_hospital | cadence=weekly | innovation_sd=0.01 ---
metric                   innovation        state  state/innov
--------------------------------------------------------------
Wall time (s)                  69.7         17.1      0.25x *
ESS/s Rt (median)             1.856       98.404     53.02x *
ESS/s Rt (min)                0.546       30.864     56.49x *
Divergences                       0            0          n/a
Tree depth (mean)             10.00         7.17      0.72x *
Tree depth (max)                 10            9      0.90x *
E-BFMI (min)                  0.896        0.925        1.03x
R-hat Rt (max)                1.150        1.005      0.87x *

--- synthetic_he_weekly_hospital | cadence=daily | innovation_sd=0.1 ---
metric                   innovation        state  state/innov
--------------------------------------------------------------
Wall time (s)                  79.8         30.2      0.38x *
ESS/s Rt (median)             0.078       72.302    928.36x *
ESS/s Rt (min)                0.032       10.311    322.37x *
Divergences                       0            0          n/a
Tree depth (mean)             10.00         8.06      0.81x *
Tree depth (max)                 10           10        1.00x
E-BFMI (min)                  0.901        0.920        1.02x
R-hat Rt (max)                2.350        1.014      0.43x *

--- synthetic_he_weekly_hospital | cadence=weekly | innovation_sd=0.1 ---
metric                   innovation        state  state/innov
--------------------------------------------------------------
Wall time (s)                  73.0         22.4      0.31x *
ESS/s Rt (median)             0.098       75.878    772.58x *
ESS/s Rt (min)                0.038       46.309   1226.16x *
Divergences                       0            0          n/a
Tree depth (mean)             10.00         7.58      0.76x *
Tree depth (max)                 10            9      0.90x *
E-BFMI (min)                  0.980        0.941        0.96x
R-hat Rt (max)                2.165        1.004      0.46x *

(* marks an improvement over innovation; ratios are state / innovation)

Wrote results to benchmarks/results

real	21m15.997s
user	80m23.152s
sys	0m7.630s

cdc-mitzimorris and others added 21 commits May 19, 2026 17:47
…e time 0) (#827)

* bug fix and unit tests

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* update unit test to match code

* revert changes, apply simpler fix

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
…v/PyRenew into mem_810_centered_parameterization
…v/PyRenew into mem_810_centered_parameterization
@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator Author

@dylanhmorris @sbidari ready for code review

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants