diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 42d7ee0f..5c4c0108 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 42d7ee0fb8ede8167d65b577526af895e78a2436 +Subproject commit 5c4c010820cbaaeac57e7eafe8de92ee2e0d5da1 diff --git a/pertpy/tools/_coda/_tasccoda.py b/pertpy/tools/_coda/_tasccoda.py index ce416006..fc0c5abe 100644 --- a/pertpy/tools/_coda/_tasccoda.py +++ b/pertpy/tools/_coda/_tasccoda.py @@ -149,7 +149,9 @@ def prepare( pen_args: Dictionary with penalty arguments. With `reg="scaled_3"`, the parameters phi (aggregation bias), lambda_1, lambda_0 can be set here. See the tascCODA paper for an explanation of these parameters. - Default: lambda_0 = 50, lambda_1 = 5, phi = 0. + `theta` is the global spike-and-slab mixing weight that sets the selection threshold; it is held fixed rather than sampled, since its posterior is weakly identified and collapses toward 0 under HMC/NUTS (issue #1015). + Lower `theta` raises the threshold and selects fewer effects; 0.5 matches the reference implementation's operating point. + Default: lambda_0 = 50, lambda_1 = 5, phi = 0, theta = 0.5. modality_key: If data is a MuData object, specify key to the aggregated sample-level AnnData object in the MuData object. Returns: @@ -269,6 +271,9 @@ def prepare( pen_args["lambda_1"] = 5 if "phi" not in pen_args: pen_args["phi"] = 0 + # Global spike-and-slab mixing weight, held fixed because its posterior collapses under HMC/NUTS (issue #1015). 
+ if "theta" not in pen_args: + pen_args["theta"] = 0.5 adata.uns["scCODA_params"]["sslasso_pen_args"] = pen_args @@ -360,7 +365,6 @@ def set_init_mcmc_states(self, rng_key: None, ref_index: np.ndarray, sample_adat "b_raw_0": rng.normal(0.0, 1.0, beta_nobl_size), "a_1": np.ones(dtype=np.float64, shape=beta_nobl_size) * 1 / lambda_1, "b_raw_1": rng.normal(0.0, 1.0, beta_nobl_size), - "theta": np.ones(dtype=np.float64, shape=1) * 0.5, "alpha": rng.normal(0.0, 1.0, alpha_size), } @@ -399,9 +403,6 @@ def model( # type: ignore ref_index = jnp.sort(ref_index) num_ref_nodes = len(ref_index) - # Size of inferred parameter matrix - d = D * (T - num_ref_nodes) - # numpyro plates for all dimensions covariate_axis = npy.plate("covs", D, dim=-2) node_axis = npy.plate("ct", T, dim=-1) @@ -409,8 +410,9 @@ def model( # type: ignore cell_type_axis = npy.plate("ct", P, dim=-1) sample_axis = npy.plate("sample", N, dim=-2) - # Spike-and-slab LASSO effects - theta = npy.sample("theta", npd.Beta(concentration1=1.0, concentration0=d)) + # Global spike-and-slab mixing weight, held fixed at pen_args["theta"] rather than sampled. + # Under the previous Beta(1, d) prior, with d = D * (T - num_ref_nodes), its posterior is weakly identified and collapses toward 0 under HMC/NUTS, which inflates the selection threshold delta until no effect is credible (issue #1015). 
+ theta = npy.deterministic("theta", jnp.array(sample_adata.uns["scCODA_params"]["sslasso_pen_args"]["theta"])) with covariate_axis, node_axis_nobl: a_0 = npy.sample("a_0", npd.Exponential((lambda_0**2) / 2)) diff --git a/tests/tools/_coda/test_tasccoda.py b/tests/tools/_coda/test_tasccoda.py index 3bf96ce3..ef981a46 100644 --- a/tests/tools/_coda/test_tasccoda.py +++ b/tests/tools/_coda/test_tasccoda.py @@ -56,6 +56,7 @@ def test_prepare(smillie_adata): assert "sample_counts" in mdata["coda"].obsm assert isinstance(mdata["coda"].obsm["sample_counts"], np.ndarray) assert np.sum(mdata["coda"].obsm["covariate_matrix"]) == 8 + assert mdata["coda"].uns["scCODA_params"]["sslasso_pen_args"]["theta"] == 0.5 def test_load_invalid_type_raises_error(smillie_adata): @@ -80,3 +81,29 @@ def test_run_nuts(smillie_adata): assert "effect_df_Health[T.Non-inflamed]" in mdata["coda"].varm assert mdata["coda"].varm["effect_df_Health[T.Inflamed]"].shape == (51, 7) assert mdata["coda"].varm["effect_df_Health[T.Non-inflamed]"].shape == (51, 7) + + +def test_theta_fixed_not_collapsed(smillie_adata): + """Regression test for #1015. + + The global spike-and-slab mixing weight theta is held fixed rather than sampled. + A sampled theta collapses toward its Beta(1, d) prior (~0) under NUTS, which sends the selection threshold delta to infinity and makes every node non-credible. + Pinning theta keeps it at the configured value and the threshold finite. 
+ """ + mdata = tasccoda.load( + smillie_adata, + type="sample_level", + levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"], + key_added="lineage", + add_level_name=True, + ) + mdata = tasccoda.prepare( + mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0, "theta": 0.3} + ) + tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) + + theta_samples = np.asarray(mdata["coda"].uns["scCODA_params"]["mcmc"]["samples"]["theta"]) + assert np.allclose(theta_samples, 0.3) + + node_df = tasccoda.get_node_df(mdata) + assert np.isfinite(node_df["Delta"]).all()