Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
Submodule notebooks updated 1 files
+295 −287 tasccoda.ipynb
16 changes: 9 additions & 7 deletions pertpy/tools/_coda/_tasccoda.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ def prepare(
pen_args: Dictionary with penalty arguments.
With `reg="scaled_3"`, the parameters phi (aggregation bias), lambda_1, lambda_0 can be set here.
See the tascCODA paper for an explanation of these parameters.
Default: lambda_0 = 50, lambda_1 = 5, phi = 0.
`theta` is the global spike-and-slab mixing weight that sets the selection threshold; it is held fixed rather than sampled, since its posterior is weakly identified and collapses toward 0 under HMC/NUTS (issue #1015).
Lower `theta` raises the threshold and selects fewer effects; 0.5 matches the reference implementation's operating point.
Default: lambda_0 = 50, lambda_1 = 5, phi = 0, theta = 0.5.
modality_key: If data is a MuData object, specify key to the aggregated sample-level AnnData object in the MuData object.

Returns:
Expand Down Expand Up @@ -269,6 +271,9 @@ def prepare(
pen_args["lambda_1"] = 5
if "phi" not in pen_args:
pen_args["phi"] = 0
# Global spike-and-slab mixing weight, held fixed because its posterior collapses under HMC/NUTS (issue #1015).
if "theta" not in pen_args:
pen_args["theta"] = 0.5

adata.uns["scCODA_params"]["sslasso_pen_args"] = pen_args

Expand Down Expand Up @@ -360,7 +365,6 @@ def set_init_mcmc_states(self, rng_key: None, ref_index: np.ndarray, sample_adat
"b_raw_0": rng.normal(0.0, 1.0, beta_nobl_size),
"a_1": np.ones(dtype=np.float64, shape=beta_nobl_size) * 1 / lambda_1,
"b_raw_1": rng.normal(0.0, 1.0, beta_nobl_size),
"theta": np.ones(dtype=np.float64, shape=1) * 0.5,
"alpha": rng.normal(0.0, 1.0, alpha_size),
}

Expand Down Expand Up @@ -399,18 +403,16 @@ def model( # type: ignore
ref_index = jnp.sort(ref_index)
num_ref_nodes = len(ref_index)

# Size of inferred parameter matrix
d = D * (T - num_ref_nodes)

# numpyro plates for all dimensions
covariate_axis = npy.plate("covs", D, dim=-2)
node_axis = npy.plate("ct", T, dim=-1)
node_axis_nobl = npy.plate("ctnb", T - num_ref_nodes, dim=-1)
cell_type_axis = npy.plate("ct", P, dim=-1)
sample_axis = npy.plate("sample", N, dim=-2)

# Spike-and-slab LASSO effects
theta = npy.sample("theta", npd.Beta(concentration1=1.0, concentration0=d))
# Global spike-and-slab mixing weight, held fixed at pen_args["theta"] rather than sampled.
# Its Beta(1, d) posterior is weakly identified and collapses toward 0 under HMC/NUTS, which inflates the selection threshold delta until no effect is credible (issue #1015).
theta = npy.deterministic("theta", jnp.array(sample_adata.uns["scCODA_params"]["sslasso_pen_args"]["theta"]))

with covariate_axis, node_axis_nobl:
a_0 = npy.sample("a_0", npd.Exponential((lambda_0**2) / 2))
Expand Down
27 changes: 27 additions & 0 deletions tests/tools/_coda/test_tasccoda.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def test_prepare(smillie_adata):
assert "sample_counts" in mdata["coda"].obsm
assert isinstance(mdata["coda"].obsm["sample_counts"], np.ndarray)
assert np.sum(mdata["coda"].obsm["covariate_matrix"]) == 8
assert mdata["coda"].uns["scCODA_params"]["sslasso_pen_args"]["theta"] == 0.5


def test_load_invalid_type_raises_error(smillie_adata):
Expand All @@ -80,3 +81,29 @@ def test_run_nuts(smillie_adata):
assert "effect_df_Health[T.Non-inflamed]" in mdata["coda"].varm
assert mdata["coda"].varm["effect_df_Health[T.Inflamed]"].shape == (51, 7)
assert mdata["coda"].varm["effect_df_Health[T.Non-inflamed]"].shape == (51, 7)


def test_theta_fixed_not_collapsed(smillie_adata):
"""Regression test for #1015.

The global spike-and-slab mixing weight theta is held fixed rather than sampled.
A sampled theta collapses toward its Beta(1, d) prior (~0) under NUTS, which sends the selection threshold delta to infinity and makes every node non-credible.
Pinning theta keeps it at the configured value and the threshold finite.
"""
mdata = tasccoda.load(
smillie_adata,
type="sample_level",
levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
key_added="lineage",
add_level_name=True,
)
mdata = tasccoda.prepare(
mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0, "theta": 0.3}
)
tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)

theta_samples = np.asarray(mdata["coda"].uns["scCODA_params"]["mcmc"]["samples"]["theta"])
assert np.allclose(theta_samples, 0.3)

node_df = tasccoda.get_node_df(mdata)
assert np.isfinite(node_df["Delta"]).all()
Loading