From b705184f3a41733998c9d0e2a4014f1ac9ef11d4 Mon Sep 17 00:00:00 2001 From: Lukas Heumos Date: Wed, 10 Jun 2026 15:56:21 +0200 Subject: [PATCH 1/4] fix: pin tascCODA spike-and-slab theta to stop posterior collapse tascCODA returned zero credible effects because the global spike-and-slab mixing weight theta collapses to its Beta(1, d) prior (~0.01) under numpyro NUTS, which sends the selection threshold delta to infinity and zeroes out every node effect (issue #1015). This is the model's true marginal posterior, not a sampler bug: a single global theta gates a high-dimensional slab, so the low-theta funnel mouth carries almost all the marginal volume. The reference TFP implementation only avoids the collapse because its fixed identity-mass, short-trajectory HMC stays pinned near the theta=0.5 init -- i.e. the published results were computed with theta effectively fixed, never inferred. Hold theta fixed via numpyro.deterministic at pen_args["theta"] (default 0.5, the reference's operating point). samples["theta"], the delta credibility rule, arviz dims and param_names are all unchanged. On the tutorial data this recovers the expected credible effects (Immune, B cells, TA cells) and is stable across seeds and across theta in [0.34, 0.5]. Co-Authored-By: Claude Opus 4.8 --- pertpy/tools/_coda/_tasccoda.py | 16 +++++++++------- tests/tools/_coda/test_tasccoda.py | 27 +++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/pertpy/tools/_coda/_tasccoda.py b/pertpy/tools/_coda/_tasccoda.py index ce416006..fc0c5abe 100644 --- a/pertpy/tools/_coda/_tasccoda.py +++ b/pertpy/tools/_coda/_tasccoda.py @@ -149,7 +149,9 @@ def prepare( pen_args: Dictionary with penalty arguments. With `reg="scaled_3"`, the parameters phi (aggregation bias), lambda_1, lambda_0 can be set here. See the tascCODA paper for an explanation of these parameters. - Default: lambda_0 = 50, lambda_1 = 5, phi = 0. 
+ `theta` is the global spike-and-slab mixing weight that sets the selection threshold; it is held fixed rather than sampled, since its posterior is weakly identified and collapses toward 0 under HMC/NUTS (issue #1015). + Lower `theta` raises the threshold and selects fewer effects; 0.5 matches the reference implementation's operating point. + Default: lambda_0 = 50, lambda_1 = 5, phi = 0, theta = 0.5. modality_key: If data is a MuData object, specify key to the aggregated sample-level AnnData object in the MuData object. Returns: @@ -269,6 +271,9 @@ def prepare( pen_args["lambda_1"] = 5 if "phi" not in pen_args: pen_args["phi"] = 0 + # Global spike-and-slab mixing weight, held fixed because its posterior collapses under HMC/NUTS (issue #1015). + if "theta" not in pen_args: + pen_args["theta"] = 0.5 adata.uns["scCODA_params"]["sslasso_pen_args"] = pen_args @@ -360,7 +365,6 @@ def set_init_mcmc_states(self, rng_key: None, ref_index: np.ndarray, sample_adat "b_raw_0": rng.normal(0.0, 1.0, beta_nobl_size), "a_1": np.ones(dtype=np.float64, shape=beta_nobl_size) * 1 / lambda_1, "b_raw_1": rng.normal(0.0, 1.0, beta_nobl_size), - "theta": np.ones(dtype=np.float64, shape=1) * 0.5, "alpha": rng.normal(0.0, 1.0, alpha_size), } @@ -399,9 +403,6 @@ def model( # type: ignore ref_index = jnp.sort(ref_index) num_ref_nodes = len(ref_index) - # Size of inferred parameter matrix - d = D * (T - num_ref_nodes) - # numpyro plates for all dimensions covariate_axis = npy.plate("covs", D, dim=-2) node_axis = npy.plate("ct", T, dim=-1) @@ -409,8 +410,9 @@ def model( # type: ignore cell_type_axis = npy.plate("ct", P, dim=-1) sample_axis = npy.plate("sample", N, dim=-2) - # Spike-and-slab LASSO effects - theta = npy.sample("theta", npd.Beta(concentration1=1.0, concentration0=d)) + # Global spike-and-slab mixing weight, held fixed at pen_args["theta"] rather than sampled. 
+ # With the original Beta(1, d) prior (d = size of the inferred parameter matrix), the posterior of theta is weakly identified and collapses toward 0 under HMC/NUTS, which inflates the selection threshold delta until no effect is credible (issue #1015). + theta = npy.deterministic("theta", jnp.array(sample_adata.uns["scCODA_params"]["sslasso_pen_args"]["theta"])) with covariate_axis, node_axis_nobl: a_0 = npy.sample("a_0", npd.Exponential((lambda_0**2) / 2)) diff --git a/tests/tools/_coda/test_tasccoda.py b/tests/tools/_coda/test_tasccoda.py index 3bf96ce3..ef981a46 100644 --- a/tests/tools/_coda/test_tasccoda.py +++ b/tests/tools/_coda/test_tasccoda.py @@ -56,6 +56,7 @@ def test_prepare(smillie_adata): assert "sample_counts" in mdata["coda"].obsm assert isinstance(mdata["coda"].obsm["sample_counts"], np.ndarray) assert np.sum(mdata["coda"].obsm["covariate_matrix"]) == 8 + assert mdata["coda"].uns["scCODA_params"]["sslasso_pen_args"]["theta"] == 0.5 def test_load_invalid_type_raises_error(smillie_adata): @@ -80,3 +81,29 @@ def test_run_nuts(smillie_adata): assert "effect_df_Health[T.Non-inflamed]" in mdata["coda"].varm assert mdata["coda"].varm["effect_df_Health[T.Inflamed]"].shape == (51, 7) assert mdata["coda"].varm["effect_df_Health[T.Non-inflamed]"].shape == (51, 7) + + +def test_theta_fixed_not_collapsed(smillie_adata): + """Regression test for #1015. + + The global spike-and-slab mixing weight theta is held fixed rather than sampled. + A sampled theta collapses toward its Beta(1, d) prior (~0) under NUTS, which sends the selection threshold delta to infinity and makes every node non-credible. + Pinning theta keeps it at the configured value and the threshold finite. 
+ """ + mdata = tasccoda.load( + smillie_adata, + type="sample_level", + levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"], + key_added="lineage", + add_level_name=True, + ) + mdata = tasccoda.prepare( + mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0, "theta": 0.3} + ) + tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) + + theta_samples = np.asarray(mdata["coda"].uns["scCODA_params"]["mcmc"]["samples"]["theta"]) + assert np.allclose(theta_samples, 0.3) + + node_df = tasccoda.get_node_df(mdata) + assert np.isfinite(node_df["Delta"]).all() From b09ee694428e903dd040520fd3f669f5d80fcef1 Mon Sep 17 00:00:00 2001 From: LuisHeinzlmeier Date: Tue, 16 Jun 2026 11:40:27 +0200 Subject: [PATCH 2/4] rerun tasccoda --- docs/tutorials/notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 5c93d2c2..3029a518 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 5c93d2c2dc686da7c78d1c79f7b299d35e22edf2 +Subproject commit 3029a518de1d24668e7c95b8e33b2a6dc3143222 From e90e16818e1a3ebec321f94554241447ff65e131 Mon Sep 17 00:00:00 2001 From: LuisHeinzlmeier Date: Tue, 16 Jun 2026 11:43:28 +0200 Subject: [PATCH 3/4] Revert "rerun tasccoda" This reverts commit b09ee694428e903dd040520fd3f669f5d80fcef1. 
--- docs/tutorials/notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 3029a518..5c93d2c2 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 3029a518de1d24668e7c95b8e33b2a6dc3143222 +Subproject commit 5c93d2c2dc686da7c78d1c79f7b299d35e22edf2 From 34d6bfaea4cc45a2e2d80496b30dba319d7dfbd9 Mon Sep 17 00:00:00 2001 From: LuisHeinzlmeier Date: Tue, 16 Jun 2026 12:22:31 +0200 Subject: [PATCH 4/4] rerun tasccoda --- docs/tutorials/notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 5c93d2c2..5c4c0108 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 5c93d2c2dc686da7c78d1c79f7b299d35e22edf2 +Subproject commit 5c4c010820cbaaeac57e7eafe8de92ee2e0d5da1