diff --git a/.github/workflows/test_tutorials.yml b/.github/workflows/test_tutorials.yml index 3fc66c52..de4f8a89 100644 --- a/.github/workflows/test_tutorials.yml +++ b/.github/workflows/test_tutorials.yml @@ -34,7 +34,7 @@ jobs: "enrichment.ipynb", "guide_rna_assignment.ipynb", "milo.ipynb", - "mixscape.ipynb", + "perturbation_efficacy.ipynb", "sccoda.ipynb", # "perturbation_space.ipynb", seems to run OOM as of Jax implementation # "tasccoda.ipynb", a pain to get running because of the required QT dependency. The QT action leads to a dead kernel diff --git a/docs/_static/tutorials/mixscape.png b/docs/_static/tutorials/perturbation_efficacy.png similarity index 100% rename from docs/_static/tutorials/mixscape.png rename to docs/_static/tutorials/perturbation_efficacy.png diff --git a/docs/api/tools_index.md b/docs/api/tools_index.md index 9d83abf1..792ac471 100644 --- a/docs/api/tools_index.md +++ b/docs/api/tools_index.md @@ -84,7 +84,34 @@ ms.lda(adata=mdata["rna"], labels="gene_target", layer="X_pert", control="NT") ms.plot_lda(adata=mdata["rna"], control="NT") ``` -See [mixscape tutorial](https://pertpy.readthedocs.io/en/latest/tutorials/notebooks/mixscape.html). +See [perturbation efficacy tutorial](https://pertpy.readthedocs.io/en/latest/tutorials/notebooks/perturbation_efficacy.html). + +### Perturbation scoring - Mixscale + +[Mixscale](https://doi.org/10.1038/s41556-025-01622-z) extends Mixscape with a continuous perturbation score instead of a binary perturbed/non-perturbed call {cite}`Jiang2025`. +Where Mixscape assigns each cell a discrete KO/NP label, Mixscale quantifies how strongly each cell responded to its perturbation. +This is useful for CRISPRi/CRISPRa screens where cells show a gradient of responses, and as input to downstream weighted differential expression and pathway analyses. +Mixscale shares the perturbation signature and differential expression steps with Mixscape, so `perturbation_signature` is available on both classes. + +```{eval-rst} +.. autosummary:: + :toctree: tools + + tools.Mixscale +``` + +Example implementation: + +```python +import pertpy as pt + +mdata = pt.dt.papalexi_2021() +ms = pt.tl.Mixscale() +ms.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") +ms.mixscale(mdata["rna"], "gene_target", "NT", layer="X_pert") +``` + +See [perturbation efficacy tutorial](https://pertpy.readthedocs.io/en/latest/tutorials/notebooks/perturbation_efficacy.html). ## Compositional analysis diff --git a/docs/conf.py b/docs/conf.py index eba90be8..e1e29d40 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -151,7 +151,7 @@ sphinx_gallery_conf = {"nested_sections=": False} nbsphinx_thumbnails = { "tutorials/notebooks/guide_rna_assignment": "_static/tutorials/guide_rna_assignment.png", - "tutorials/notebooks/mixscape": "_static/tutorials/mixscape.png", + "tutorials/notebooks/perturbation_efficacy": "_static/tutorials/perturbation_efficacy.png", "tutorials/notebooks/augur": "_static/tutorials/augur.png", "tutorials/notebooks/sccoda": "_static/tutorials/sccoda.png", "tutorials/notebooks/sccoda_extended": "_static/tutorials/sccoda_extended.png", diff --git a/docs/references.bib b/docs/references.bib index bc2a8b78..0fba5289 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -73,6 +73,18 @@ @article{Papalexi2021 issn = {1546-1718} } +@article{Jiang2025, + author = {Jiang, Longda and Dalgarno, Carol and Papalexi, Efthymia and others}, + title = {Systematic reconstruction of molecular pathway signatures using scalable single-cell perturbation screens}, + journal = {Nature Cell Biology}, + year = {2025}, + volume = {27}, + pages = {505--517}, + doi = {10.1038/s41556-025-01622-z}, + url = {https://doi.org/10.1038/s41556-025-01622-z}, + issn = {1476-4679} +} + @article{Dann2022, author = {Dann, Emma and Henderson, Neil C. and Teichmann, Sarah A. and Morgan, Michael D. and Marioni, John C.}, title = {Differential abundance testing on single-cell data using k-nearest neighbor graphs}, diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index d2ab4bfb..5c93d2c2 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit d2ab4bfbd1bcc18eaf977766117ba5fff62804f8 +Subproject commit 5c93d2c2dc686da7c78d1c79f7b299d35e22edf2 diff --git a/docs/tutorials/tools.md b/docs/tutorials/tools.md index 2dcef229..4675562b 100644 --- a/docs/tutorials/tools.md +++ b/docs/tutorials/tools.md @@ -13,7 +13,7 @@ ```{eval-rst} .. nbgallery:: - notebooks/mixscape + notebooks/perturbation_efficacy ``` ## Compositional analysis diff --git a/pertpy/tools/__init__.py b/pertpy/tools/__init__.py index 06a44cd7..2a602a5c 100644 --- a/pertpy/tools/__init__.py +++ b/pertpy/tools/__init__.py @@ -8,7 +8,8 @@ from pertpy.tools._distances._distances import Distance from pertpy.tools._enrichment import Enrichment from pertpy.tools._milo import Milo -from pertpy.tools._mixscape import Mixscape +from pertpy.tools._perturbation_efficacy._mixscale import Mixscale +from pertpy.tools._perturbation_efficacy._mixscape import Mixscape from pertpy.tools._perturbation_space._clustering import ClusteringSpace from pertpy.tools._perturbation_space._comparison import PerturbationComparison from pertpy.tools._perturbation_space._discriminator_classifiers import ( @@ -70,6 +71,7 @@ def __dir__(): "Enrichment", "Milo", "Mixscape", + "Mixscale", "ClusteringSpace", "PerturbationComparison", "LRClassifierSpace", diff --git a/pertpy/tools/_perturbation_efficacy/__init__.py b/pertpy/tools/_perturbation_efficacy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pertpy/tools/_perturbation_efficacy/_base.py b/pertpy/tools/_perturbation_efficacy/_base.py new file mode 100644 index 00000000..730b9f7c --- /dev/null +++ b/pertpy/tools/_perturbation_efficacy/_base.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Literal + +import numpy as np +import scanpy as sc +from fast_array_utils.stats import mean, mean_var +from pandas.errors import PerformanceWarning +from scanpy.tools._utils import _choose_representation +from scipy.sparse import csr_array, csr_matrix, issparse, sparray + +if TYPE_CHECKING: + from anndata import AnnData + + +class PerturbationEfficacyAnalyzer: + """Shared substrate for the perturbation efficacy analysis tools. + + It holds the steps that both the binary :class:`~pertpy.tools.Mixscape` classification and the continuous :class:`~pertpy.tools.Mixscale` scoring build on. + These are computing the perturbation signature and detecting the differentially expressed marker genes per perturbation. + """ + + def __init__(self): + pass + + def perturbation_signature( + self, + adata: AnnData, + pert_key: str, + control: str, + *, + ref_selection_mode: Literal["nn", "split_by"] = "nn", + split_by: str | None = None, + n_neighbors: int = 20, + use_rep: str | None = None, + n_dims: int | None = 15, + n_pcs: int | None = None, + batch_size: int | None = None, + copy: bool = False, + **kwargs, + ): + """Calculate perturbation signature. + + The perturbation signature is calculated by subtracting the mRNA expression profile of each cell from the averaged mRNA expression profile of the control cells (selected according to `ref_selection_mode`). + The implementation resembles https://satijalab.org/seurat/reference/runmixscape. + Note that in the original implementation, the perturbation signature is calculated on unscaled data by default, and we therefore recommend to do the same. + + Args: + adata: The annotated data object. + pert_key: The column of `.obs` with perturbation categories, should also contain `control`. + control: Name of the control category from the `pert_key` column. + ref_selection_mode: Method to select reference cells for the perturbation signature calculation. + If `nn`, the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. + If `split_by`, the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature. + split_by: Provide the column `.obs` if multiple biological replicates exist to calculate the perturbation signature for every replicate separately. + n_neighbors: Number of neighbors from the control to use for the perturbation signature. + use_rep: Use the indicated representation. `'X'` or any key for `.obsm` is valid. + If `None`, the representation is chosen automatically: + For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used. + If 'X_pca' is not present, it's computed with default parameters. + n_dims: Number of dimensions to use from the representation to calculate the perturbation signature. + If `None`, use all dimensions. + n_pcs: If PCA representation is used, the number of principal components to compute. + If `n_pcs==0` use `.X` if `use_rep is None`. + batch_size: Size of batch to calculate the perturbation signature. + If 'None', the perturbation signature is calcuated in the full mode, requiring more memory. + The batched mode is very inefficient for sparse data. + copy: Determines whether a copy of the `adata` is returned. + **kwargs: Additional arguments for the `NNDescent` class from `pynndescent`. + + Returns: + If `copy=True`, returns the copy of `adata` with the perturbation signature in `.layers["X_pert"]`. + Otherwise, writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`. + + Examples: + Calcutate perturbation signature for each cell in the dataset: + + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> ms_pt = pt.tl.Mixscape() + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") + """ + if ref_selection_mode not in ["nn", "split_by"]: + raise ValueError("ref_selection_mode must be either 'nn' or 'split_by'.") + if ref_selection_mode == "split_by" and split_by is None: + raise ValueError("split_by must be provided if ref_selection_mode is 'split_by'.") + + if copy: + adata = adata.copy() + + # pynndescent and the LIL workflow below only support legacy scipy sparse matrices, so a sparse array + # input is computed on as a csr_matrix and converted back to a sparse array at the end. + input_is_sparray = isinstance(adata.X, sparray) + X = csr_matrix(adata.X) if input_is_sparray else adata.X + adata.layers["X_pert"] = X.copy() + + # Work with LIL for efficient indexing but don't store it in AnnData as LIL is not supported anymore + X_pert_lil = adata.layers["X_pert"].tolil() if issparse(adata.layers["X_pert"]) else adata.layers["X_pert"] + + control_mask = adata.obs[pert_key] == control + + if ref_selection_mode == "split_by": + for split in adata.obs[split_by].unique(): + split_mask = adata.obs[split_by] == split + control_mask_group = control_mask & split_mask + control_mean_expr = mean(X[control_mask_group], axis=0) + X_pert_lil[split_mask] = ( + np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0) - X_pert_lil[split_mask] + ) + else: + if split_by is None: + split_masks = [np.full(adata.n_obs, True, dtype=bool)] + else: + split_obs = adata.obs[split_by] + split_masks = [split_obs == cat for cat in split_obs.unique()] + + representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) + if isinstance(representation, sparray): + representation = csr_matrix(representation) + if n_dims is not None and n_dims < representation.shape[1]: + representation = representation[:, :n_dims] + + from pynndescent import NNDescent + + for split_mask in split_masks: + control_mask_split = control_mask & split_mask + R_split = representation[split_mask] + R_control = representation[np.asarray(control_mask_split)] + eps = kwargs.pop("epsilon", 0.1) + nn_index = NNDescent(R_control, **kwargs) + indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps) + X_control = np.expm1(X[np.asarray(control_mask_split)]) + n_split = split_mask.sum() + n_control = X_control.shape[0] + + if batch_size is None: + col_indices = np.ravel(indices) + row_indices = np.repeat(np.arange(n_split), n_neighbors) + neigh_matrix = csr_matrix( + (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)), + shape=(n_split, n_control), + ) + neigh_matrix /= n_neighbors + X_pert_lil[np.asarray(split_mask)] = ( + sc.pp.log1p(neigh_matrix @ X_control) - X_pert_lil[np.asarray(split_mask)] + ) + else: + split_indices = np.where(split_mask)[0] + for i in range(0, n_split, batch_size): + size = min(i + batch_size, n_split) + select = slice(i, size) + batch = np.ravel(indices[select]) + split_batch = split_indices[select] + size = size - i + means_batch = X_control[batch] + batch_reshaped = means_batch.reshape(size, n_neighbors, -1) + means_batch, _ = mean_var(batch_reshaped, axis=1) + X_pert_lil[split_batch] = np.log1p(means_batch) - X_pert_lil[split_batch] + + if issparse(X_pert_lil): + x_pert = X_pert_lil.tocsr() + adata.layers["X_pert"] = csr_array(x_pert) if input_is_sparray else x_pert + else: + adata.layers["X_pert"] = X_pert_lil + + if copy: + return adata + + def _get_perturbation_markers( + self, + adata: AnnData, + *, + split_masks: list[np.ndarray], + categories: list[str], + pert_key: str, + control: str, + layer: str, + pval_cutoff: float, + min_de_genes: float, + logfc_threshold: float, + test_method: str, + ) -> dict[tuple, np.ndarray]: + """Determine gene sets across all splits/groups through differential gene expression. + + Args: + adata: :class:`~anndata.AnnData` object + split_masks: List of boolean masks for each split/group. + categories: List of split/group names. + pert_key: The column of `.obs` with target gene labels. + control: Control category from the `labels` column. + layer: Key from adata.layers whose value will be used to compare gene expression. + pval_cutoff: P-value cut-off for selection of significantly DE genes. + min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells. + logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells. + test_method: Method to use for differential expression testing. + + Returns: + Set of column indices. + """ + perturbation_markers: dict[tuple, np.ndarray] = {} # type: ignore + for split, split_mask in enumerate(split_masks): + category = categories[split] + # get gene sets for each split + gene_targets = list(set(adata[split_mask].obs[pert_key]).difference([control])) + adata_split = adata[split_mask].copy() + # find top DE genes between cells with targeting and non-targeting gRNAs + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + warnings.simplefilter("ignore", PerformanceWarning) + sc.tl.rank_genes_groups( + adata_split, + layer=layer, + groupby=pert_key, + groups=gene_targets, + reference=control, + method=test_method, + use_raw=False, + ) + # get DE genes for each target gene + for gene in gene_targets: + logfc_threshold_mask = ( + np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold + ) + de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask] + pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask] + de_genes = de_genes[pvals_adj < pval_cutoff] + if len(de_genes) < min_de_genes: + de_genes = np.array([]) + perturbation_markers[(category, gene)] = de_genes + + return perturbation_markers diff --git a/pertpy/tools/_perturbation_efficacy/_mixscale.py b/pertpy/tools/_perturbation_efficacy/_mixscale.py new file mode 100644 index 00000000..a632d352 --- /dev/null +++ b/pertpy/tools/_perturbation_efficacy/_mixscale.py @@ -0,0 +1,433 @@ +from __future__ import annotations + +import warnings +from functools import singledispatch +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd +import scanpy as sc +from pandas.errors import PerformanceWarning +from scipy.sparse import issparse, sparray + +from pertpy.tools._perturbation_efficacy._base import PerturbationEfficacyAnalyzer + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + from anndata import AnnData + + +@singledispatch +def _subset_column_mean(matrix: np.ndarray, row_mask: np.ndarray) -> np.ndarray: + """Per-column mean over the rows selected by `row_mask`.""" + return matrix[row_mask].mean(axis=0) + + +@singledispatch +def _project(matrix: np.ndarray, direction: np.ndarray) -> np.ndarray: + """Project each row onto `direction` (the unnormalized scalar projection numerator).""" + return matrix @ direction + + +@singledispatch +def _leave_one_out_numerators(matrix: np.ndarray, numerator: np.ndarray, direction: np.ndarray) -> np.ndarray: + """For each gene (column) j, the projection numerator recomputed with gene j left out.""" + return numerator[:, None] - matrix * direction[None, :] + + +@_subset_column_mean.register(sparray) +def _(matrix, row_mask: np.ndarray) -> np.ndarray: + return np.asarray(matrix[row_mask].mean(axis=0)).ravel() + + +@_project.register(sparray) +def _(matrix, direction: np.ndarray) -> np.ndarray: + return np.asarray(matrix @ direction).ravel() + + +@_leave_one_out_numerators.register(sparray) +def _(matrix, numerator: np.ndarray, direction: np.ndarray) -> np.ndarray: + return numerator[:, None] - matrix.multiply(direction[None, :]).toarray() + + +class Mixscale(PerturbationEfficacyAnalyzer): + """Continuous perturbation scoring for pooled CRISPR screens. + + Where :class:`~pertpy.tools.Mixscape` assigns each cell a binary perturbed/non-perturbed label, Mixscale assigns a continuous perturbation score that reflects how strongly each cell responded. + This is useful for CRISPRi/CRISPRa screens where cells show a gradient of responses rather than a clean knockout, and as input to downstream weighted differential expression and pathway analyses. + + The method is described in Jiang, Dalgarno et al., "Systematic reconstruction of molecular pathway signatures using scalable single-cell perturbation screens", Nature Cell Biology (2025) {cite}`Jiang2025`. + It reproduces the reference implementation from the satijalab/Mixscale R package (https://github.com/satijalab/Mixscale). + """ + + def mixscale( + self, + adata: AnnData, + pert_key: str, + control: str, + *, + new_class_name: str = "mixscale_score", + layer: str | None = None, + min_de_genes: int = 5, + max_de_genes: int = 100, + logfc_threshold: float = 0.25, + de_layer: str | None = None, + test_method: str = "wilcoxon", + scale: bool = True, + split_by: str | None = None, + pval_cutoff: float = 5e-2, + fine_mode: bool = False, + fine_mode_labels: str = "guide_id", + de_genes_by_target: Mapping[str, Sequence[str]] | None = None, + harmonize: bool = False, + harmonize_min_proportion: float = 0.1, + random_state: int = 0, + copy: bool = False, + ): + """Calculate a continuous perturbation score per cell with the Mixscale method. + + For every target gene the large-effect differentially expressed (DE) genes between its cells and the control cells are determined. + The perturbation direction vector (mean perturbed minus mean control over those genes) is computed, and each cell's perturbation signature is projected onto that vector. + The per-cell projection is then standardized against the control distribution. + DE genes are detected on all cells pooled, while the direction vector and standardization are computed within each `split_by` group. + The automatic DE detection relies on :func:`scanpy.tl.rank_genes_groups` and may select a slightly different gene set than the reference implementation; pass `de_genes_by_target` to score against a fixed gene set instead. + + Run :meth:`perturbation_signature` first to populate `.layers["X_pert"]`. + + Args: + adata: The annotated data object. + pert_key: The column of `.obs` with target gene labels. + control: Control category from the `pert_key` column. + new_class_name: Name of the score column to be stored in `.obs`. + layer: Key from `adata.layers` whose value is used for scoring. If `None`, `.layers["X_pert"]` is used. + min_de_genes: Required number of DE genes for scoring a perturbation. Perturbations with fewer DE genes are not scored and their cells receive the fallback score of 1. + max_de_genes: Maximum number of top DE genes (by adjusted p-value) used for scoring. + logfc_threshold: Minimum absolute log fold-change for a gene to be considered a large-effect DE gene. + de_layer: Layer used for the DE test. If `None`, `adata.X` is used. + test_method: Method passed to :func:`scanpy.tl.rank_genes_groups` for DE testing. + scale: Whether to z-score each gene's perturbation signature (mean-centered and scaled to unit variance, then clipped at 10) before scoring. + split_by: `.obs` column with a condition/cell-type annotation. The direction vector and standardization are computed separately within each group, while DE genes are still detected on all cells. + pval_cutoff: Adjusted p-value cut-off for selecting significant DE genes. + fine_mode: If `True`, DE genes are computed per gRNA (`fine_mode_labels`) and pooled per target gene, rather than once per target gene. + fine_mode_labels: `.obs` column with gRNA identifiers, used when `fine_mode` is `True`. + de_genes_by_target: Optional mapping from target gene to a user-defined list of DE genes. When given, the DE test is skipped entirely and targets absent from the mapping are not scored. + harmonize: If `True` and `split_by` resolves to more than one group, control cells are subsampled so that their per-group composition matches the perturbed cells before the DE test. + harmonize_min_proportion: Minimum fraction of control cells that must be retained during harmonization. Groups are dropped until the constraint is met. + random_state: Seed for the control subsampling performed during harmonization. + copy: Determines whether a copy of `adata` is returned. + + Returns: + If `copy=True`, returns the copy of `adata` with the scores in `.obs`. + Otherwise, writes the scores directly to `.obs` of the provided `adata`. + + The following fields are added: + + - `adata.obs[new_class_name]`: Continuous perturbation score per cell. Control cells receive 0, cells of perturbations that could not be scored receive 1, and all other cells receive the projection standardized against the control distribution. Higher values indicate a stronger response. + - `adata.uns["mixscale"]`: Per target gene and split, a :class:`~pandas.DataFrame` with the raw projection (`pvec`), the cell labels, and the leave-one-out projections (one column per DE gene). + - `adata.uns["mixscale_de_genes"]`: The DE genes used for each target gene. + + Examples: + Compute continuous perturbation scores: + + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> ms = pt.tl.Mixscale() + >>> ms.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") + >>> ms.mixscale(mdata["rna"], "gene_target", "NT", layer="X_pert") + """ + if copy: + adata = adata.copy() + + if layer is not None: + X = adata.layers[layer] + else: + try: + X = adata.layers["X_pert"] + except KeyError: + raise KeyError( + "No 'X_pert' found in .layers! Please run perturbation_signature first to calculate the perturbation signature!" + ) from None + + if split_by is None: + split_masks = [np.full(adata.n_obs, True, dtype=bool)] + categories = ["all"] + else: + split_obs = adata.obs[split_by] + categories = list(split_obs.unique()) + split_masks = [(split_obs == category).to_numpy() for category in categories] + + # DE genes are detected on all cells pooled (the direction vector and standardization are per split). + perturbation_markers = self._get_mixscale_markers( + adata, + pert_key=pert_key, + control=control, + de_layer=de_layer, + test_method=test_method, + logfc_threshold=logfc_threshold, + pval_cutoff=pval_cutoff, + min_de_genes=min_de_genes, + max_de_genes=max_de_genes, + fine_mode=fine_mode, + fine_mode_labels=fine_mode_labels, + de_genes_by_target=de_genes_by_target, + harmonize=harmonize, + harmonize_min_proportion=harmonize_min_proportion, + split_by=split_by, + random_state=random_state, + ) + + var_loc = {name: i for i, name in enumerate(adata.var_names)} + pert_values = adata.obs[pert_key].to_numpy() + gv_list: dict[str, dict] = {} + scores = np.zeros(adata.n_obs) + scored = np.zeros(adata.n_obs, dtype=bool) + + for split, split_mask in enumerate(split_masks): + category = categories[split] + nt_mask = (pert_values == control) & split_mask + genes_in_split = set(pert_values[split_mask]).difference([control]) + for gene in genes_in_split: + de_genes = [g for g in perturbation_markers.get(gene, ()) if g in var_loc] + if len(de_genes) == 0: + continue + + guide_mask = (pert_values == gene) & split_mask + all_mask = guide_mask | nt_mask + guide_in_dat = guide_mask[all_mask] + nt_in_dat = nt_mask[all_mask] + if not guide_in_dat.any() or not nt_in_dat.any(): + continue + + de_indices = [var_loc[g] for g in de_genes] + dat = X[all_mask][:, de_indices].astype(np.float64) + if scale: + dat = self._scale_features(dat) + + vec = _subset_column_mean(dat, guide_in_dat) - _subset_column_mean(dat, nt_in_dat) + vec_norm_sq = float(vec @ vec) + if not vec_norm_sq > 0: + continue + + numerator = _project(dat, vec) + pvec = numerator / vec_norm_sq + vec_sq = vec * vec + with np.errstate(divide="ignore", invalid="ignore"): + loo = _leave_one_out_numerators(dat, numerator, vec) / (vec_norm_sq - vec_sq[None, :]) + + nt_pvec = pvec[nt_in_dat] + std_nt = nt_pvec.std(ddof=1) + if std_nt == 0 or np.isnan(std_nt): + std_nt = 1.0 + guide_positions = np.flatnonzero(guide_mask) + scores[guide_positions] = (pvec[guide_in_dat] - nt_pvec.mean()) / std_nt + scored[guide_positions] = True + + all_names = adata.obs_names[all_mask] + gv = pd.DataFrame(index=all_names) + gv["pvec"] = pvec + gv[pert_key] = control + gv.loc[all_names[guide_in_dat], pert_key] = gene + gv = pd.concat([gv, pd.DataFrame(loo, index=all_names, columns=de_genes)], axis=1) + gv_list.setdefault(gene, {})[category] = gv + + scores[(~scored) & (pert_values != control)] = 1.0 + + adata.obs[new_class_name] = scores + adata.uns["mixscale"] = gv_list + adata.uns["mixscale_de_genes"] = {gene: np.asarray(genes) for gene, genes in perturbation_markers.items()} + + if copy: + return adata + + @staticmethod + def _scale_features(dat, *, scale_max: float = 10.0) -> np.ndarray: + """Z-score each gene (column) and clip, mirroring Seurat's `ScaleData`. + + Zero-centering necessarily densifies the (cells x DE-gene) submatrix, exactly as Seurat's `ScaleData` and :meth:`~pertpy.tools.Mixscape.mixscape` do. + """ + dat = dat.toarray() if issparse(dat) else np.asarray(dat, dtype=np.float64) + dat = dat.astype(np.float64, copy=False) + mean = dat.mean(axis=0) + std = dat.std(axis=0, ddof=1) + std[std == 0] = 1.0 + scaled = (dat - mean) / std + np.clip(scaled, None, scale_max, out=scaled) + return scaled + + def _get_mixscale_markers( + self, + adata: AnnData, + *, + pert_key: str, + control: str, + de_layer: str | None, + test_method: str, + logfc_threshold: float, + pval_cutoff: float, + min_de_genes: int, + max_de_genes: int, + fine_mode: bool, + fine_mode_labels: str, + de_genes_by_target: Mapping[str, Sequence[str]] | None, + harmonize: bool, + harmonize_min_proportion: float, + split_by: str | None, + random_state: int, + ) -> dict[str, np.ndarray]: + """Determine the large-effect DE genes for each target gene, pooling across all cells. + + Returns a mapping from target gene to the ordered array of DE gene names, empty when fewer than `min_de_genes` survive the filters. + """ + var_names = set(adata.var_names) + gene_targets = set(adata.obs[pert_key]).difference([control]) + nt_cells = adata.obs_names[adata.obs[pert_key] == control] + markers: dict[str, np.ndarray] = {} + + for gene in gene_targets: + if de_genes_by_target is not None: + supplied = list(dict.fromkeys(g for g in de_genes_by_target.get(gene, ()) if g in var_names)) + if len(supplied) == 0: + warnings.warn( + f"No DE genes provided for perturbation {gene!r} in de_genes_by_target; it will not be scored.", + stacklevel=2, + ) + de_genes = np.array(supplied, dtype=object) + else: + guide_cells = adata.obs_names[adata.obs[pert_key] == gene] + ref_cells = nt_cells + if harmonize and split_by is not None: + ref_cells = self._harmonize_control_cells( + adata, + target_cells=guide_cells, + control_cells=nt_cells, + split_by=split_by, + min_proportion=harmonize_min_proportion, + random_state=random_state, + ) + + if fine_mode: + pooled: list[str] = [] + for guide in adata.obs.loc[guide_cells, fine_mode_labels].unique(): + guide_subset = guide_cells[adata.obs.loc[guide_cells, fine_mode_labels] == guide] + for g in self._de_for_pair( + adata, + guide_subset, + ref_cells, + de_layer=de_layer, + test_method=test_method, + logfc_threshold=logfc_threshold, + pval_cutoff=pval_cutoff, + ): + if g not in pooled: + pooled.append(g) + de_genes = np.array(pooled, dtype=object) + else: + de_genes = self._de_for_pair( + adata, + guide_cells, + ref_cells, + de_layer=de_layer, + test_method=test_method, + logfc_threshold=logfc_threshold, + pval_cutoff=pval_cutoff, + ) + + if len(de_genes) > max_de_genes: + de_genes = de_genes[:max_de_genes] + if len(de_genes) < min_de_genes: + de_genes = np.array([], dtype=object) + markers[gene] = de_genes + + return markers + + @staticmethod + def _de_for_pair( + adata: AnnData, + group_cells, + reference_cells, + *, + de_layer: str | None, + test_method: str, + logfc_threshold: float, + pval_cutoff: float, + ) -> np.ndarray: + """Wilcoxon-style DE between two cell sets; returns gene names passing the filters, sorted by raw p-value.""" + group_cells = pd.Index(group_cells) + reference_cells = pd.Index(reference_cells) + if len(group_cells) == 0 or len(reference_cells) == 0: + return np.array([], dtype=object) + + sub = adata[adata.obs_names.isin(group_cells.union(reference_cells))].copy() + groups = np.where(sub.obs_names.isin(group_cells), "perturbed", "control") + sub.obs["_mixscale_de"] = pd.Categorical(groups, categories=["control", "perturbed"]) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + warnings.simplefilter("ignore", PerformanceWarning) + sc.tl.rank_genes_groups( + sub, + layer=de_layer, + groupby="_mixscale_de", + groups=["perturbed"], + reference="control", + method=test_method, + use_raw=False, + ) + result = sub.uns["rank_genes_groups"] + names = np.asarray(result["names"]["perturbed"]) + logfoldchanges = np.asarray(result["logfoldchanges"]["perturbed"]) + pvals = np.asarray(result["pvals"]["perturbed"]) + pvals_adj = np.asarray(result["pvals_adj"]["perturbed"]) + + keep = (np.abs(logfoldchanges) >= logfc_threshold) & (pvals_adj < pval_cutoff) + names, pvals = names[keep], pvals[keep] + return names[np.argsort(pvals, kind="stable")] + + @staticmethod + def _harmonize_control_cells( + adata: AnnData, + *, + target_cells, + control_cells, + split_by: str, + min_proportion: float, + random_state: int, + ) -> pd.Index: + """Subsample control cells so their per-group composition matches the perturbed cells. + + The control subsampling uses NumPy's random generator and therefore does not reproduce the exact cells drawn by the Mixscale R package, but follows the same per-group counting logic. + The R reference gates harmonization on the number of split columns, making it a no-op for a single split column; here it activates whenever the single split column has more than one group. + """ + target_cells = pd.Index(target_cells) + control_cells = pd.Index(control_cells) + split = adata.obs[split_by] + groups = list(split.unique()) + if len(groups) <= 1: + return control_cells + + rng = np.random.default_rng(random_state) + target_split = split.loc[target_cells] + control_split = split.loc[control_cells] + + active_groups = list(groups) + while True: + n_target = np.array([(target_split == g).sum() for g in active_groups], dtype=float) + n_control = np.array([(control_split == g).sum() for g in active_groups], dtype=float) + if n_target.sum() == 0 or (n_target == 0).any() or (n_control == 0).any(): + # cannot harmonize cleanly; fall back to all control cells + return control_cells + prop_target = n_target / n_target.sum() + total_desired = np.floor((n_control / prop_target).min()) + if total_desired >= min_proportion * n_control.sum(): + break + del active_groups[int(np.argmin(n_control))] + if len(active_groups) <= 1: + return control_cells + + desired = np.floor(total_desired * prop_target).astype(int) + sampled: list[str] = [] + for group, n_desired in zip(active_groups, desired, strict=True): + pool = control_cells[(control_split == group).to_numpy()] + n_draw = min(int(n_desired), len(pool)) + sampled.extend(rng.choice(np.asarray(pool), size=n_draw, replace=False).tolist()) + return pd.Index(sampled) diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_perturbation_efficacy/_mixscape.py similarity index 81% rename from pertpy/tools/_mixscape.py rename to pertpy/tools/_perturbation_efficacy/_mixscape.py index 5eb6acb7..e554c3ee 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_perturbation_efficacy/_mixscape.py @@ -10,16 +10,15 @@ import pandas as pd import scanpy as sc import seaborn as sns -from fast_array_utils.stats import mean, mean_var from pandas.errors import PerformanceWarning from scanpy import get from scanpy._utils import check_use_raw, sanitize_anndata from scanpy.plotting import _utils -from scanpy.tools._utils import _choose_representation -from scipy.sparse import csr_matrix, issparse, spmatrix +from scipy.sparse import spmatrix from sklearn.mixture import GaussianMixture from pertpy._doc import _doc_params, doc_common_plot_args +from pertpy.tools._perturbation_efficacy._base import PerturbationEfficacyAnalyzer if TYPE_CHECKING: from collections.abc import Sequence @@ -30,149 +29,8 @@ from matplotlib.pyplot import Figure -class Mixscape: - """identify perturbation effects in CRISPR screens by separating cells into perturbation groups.""" - - def __init__(self): - pass - - def perturbation_signature( - self, - adata: AnnData, - pert_key: str, - control: str, - *, - ref_selection_mode: Literal["nn", "split_by"] = "nn", - split_by: str | None = None, - n_neighbors: int = 20, - use_rep: str | None = None, - n_dims: int | None = 15, - n_pcs: int | None = None, - batch_size: int | None = None, - copy: bool = False, - **kwargs, - ): - """Calculate perturbation signature. - - The perturbation signature is calculated by subtracting the mRNA expression profile of each cell from the averaged - mRNA expression profile of the control cells (selected according to `ref_selection_mode`). - The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the - perturbation signature is calculated on unscaled data by default, and we therefore recommend to do the same. - - Args: - adata: The annotated data object. - pert_key: The column of `.obs` with perturbation categories, should also contain `control`. - control: Name of the control category from the `pert_key` column. - ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `nn`, - the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `split_by`, - the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature. - split_by: Provide the column `.obs` if multiple biological replicates exist to calculate - the perturbation signature for every replicate separately. - n_neighbors: Number of neighbors from the control to use for the perturbation signature. - use_rep: Use the indicated representation. `'X'` or any key for `.obsm` is valid. - If `None`, the representation is chosen automatically: - For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used. - If 'X_pca' is not present, it's computed with default parameters. - n_dims: Number of dimensions to use from the representation to calculate the perturbation signature. - If `None`, use all dimensions. - n_pcs: If PCA representation is used, the number of principal components to compute. - If `n_pcs==0` use `.X` if `use_rep is None`. - batch_size: Size of batch to calculate the perturbation signature. - If 'None', the perturbation signature is calcuated in the full mode, requiring more memory. - The batched mode is very inefficient for sparse data. - copy: Determines whether a copy of the `adata` is returned. - **kwargs: Additional arguments for the `NNDescent` class from `pynndescent`. - - Returns: - If `copy=True`, returns the copy of `adata` with the perturbation signature in `.layers["X_pert"]`. - Otherwise, writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`. - - Examples: - Calcutate perturbation signature for each cell in the dataset: - - >>> import pertpy as pt - >>> mdata = pt.dt.papalexi_2021() - >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") - """ - if ref_selection_mode not in ["nn", "split_by"]: - raise ValueError("ref_selection_mode must be either 'nn' or 'split_by'.") - if ref_selection_mode == "split_by" and split_by is None: - raise ValueError("split_by must be provided if ref_selection_mode is 'split_by'.") - - if copy: - adata = adata.copy() - - adata.layers["X_pert"] = adata.X.copy() - - # Work with LIL for efficient indexing but don't store it in AnnData as LIL is not supported anymore - X_pert_lil = adata.layers["X_pert"].tolil() if issparse(adata.layers["X_pert"]) else adata.layers["X_pert"] - - control_mask = adata.obs[pert_key] == control - - if ref_selection_mode == "split_by": - for split in adata.obs[split_by].unique(): - split_mask = adata.obs[split_by] == split - control_mask_group = control_mask & split_mask - control_mean_expr = mean(adata.X[control_mask_group], axis=0) - X_pert_lil[split_mask] = ( - np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0) - X_pert_lil[split_mask] - ) - else: - if split_by is None: - split_masks = [np.full(adata.n_obs, True, dtype=bool)] - else: - split_obs = adata.obs[split_by] - split_masks = [split_obs == cat for cat in split_obs.unique()] - - representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) - if n_dims is not None and n_dims < representation.shape[1]: - representation = representation[:, :n_dims] - - from pynndescent import NNDescent - - for split_mask in split_masks: - control_mask_split = control_mask & split_mask - R_split = representation[split_mask] - R_control = representation[np.asarray(control_mask_split)] - eps = kwargs.pop("epsilon", 0.1) - nn_index = NNDescent(R_control, **kwargs) - indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps) - X_control = np.expm1(adata.X[np.asarray(control_mask_split)]) - n_split = split_mask.sum() - n_control = X_control.shape[0] - - if batch_size is None: - col_indices = np.ravel(indices) - row_indices = np.repeat(np.arange(n_split), n_neighbors) - neigh_matrix = csr_matrix( - (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)), - shape=(n_split, n_control), - ) - neigh_matrix /= n_neighbors - X_pert_lil[np.asarray(split_mask)] = ( - sc.pp.log1p(neigh_matrix @ X_control) - X_pert_lil[np.asarray(split_mask)] - ) - else: - split_indices = np.where(split_mask)[0] - for i in range(0, n_split, batch_size): - size = min(i + batch_size, n_split) - select = slice(i, size) - batch = np.ravel(indices[select]) - split_batch = split_indices[select] - size = size - i - means_batch = X_control[batch] - batch_reshaped = means_batch.reshape(size, n_neighbors, -1) - means_batch, _ = mean_var(batch_reshaped, axis=1) - X_pert_lil[split_batch] = np.log1p(means_batch) - X_pert_lil[split_batch] - - if issparse(X_pert_lil): - adata.layers["X_pert"] = X_pert_lil.tocsr() - else: - adata.layers["X_pert"] = X_pert_lil - - if copy: - return adata +class Mixscape(PerturbationEfficacyAnalyzer): + """Identify perturbation effects in CRISPR screens by separating cells into perturbation groups.""" def mixscape( self, @@ -498,70 +356,6 @@ def lda( if copy: return adata - def _get_perturbation_markers( - self, - adata: AnnData, - *, - split_masks: list[np.ndarray], - categories: list[str], - pert_key: str, - control: str, - layer: str, - pval_cutoff: float, - min_de_genes: float, - logfc_threshold: float, - test_method: str, - ) -> dict[tuple, np.ndarray]: - """Determine gene sets across all splits/groups through differential gene expression. - - Args: - adata: :class:`~anndata.AnnData` object - split_masks: List of boolean masks for each split/group. - categories: List of split/group names. - pert_key: The column of `.obs` with target gene labels. - control: Control category from the `labels` column. - layer: Key from adata.layers whose value will be used to compare gene expression. - pval_cutoff: P-value cut-off for selection of significantly DE genes. - min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells. - logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells. - test_method: Method to use for differential expression testing. - - Returns: - Set of column indices. - """ - perturbation_markers: dict[tuple, np.ndarray] = {} # type: ignore - for split, split_mask in enumerate(split_masks): - category = categories[split] - # get gene sets for each split - gene_targets = list(set(adata[split_mask].obs[pert_key]).difference([control])) - adata_split = adata[split_mask].copy() - # find top DE genes between cells with targeting and non-targeting gRNAs - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RuntimeWarning) - warnings.simplefilter("ignore", PerformanceWarning) - sc.tl.rank_genes_groups( - adata_split, - layer=layer, - groupby=pert_key, - groups=gene_targets, - reference=control, - method=test_method, - use_raw=False, - ) - # get DE genes for each target gene - for gene in gene_targets: - logfc_threshold_mask = ( - np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold - ) - de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask] - pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask] - de_genes = de_genes[pvals_adj < pval_cutoff] - if len(de_genes) < min_de_genes: - de_genes = np.array([]) - perturbation_markers[(category, gene)] = de_genes - - return perturbation_markers - @_doc_params(common_plot_args=doc_common_plot_args) def plot_barplot( # pragma: no cover # noqa: D417 self, diff --git a/tests/tools/_perturbation_efficacy/__init__.py b/tests/tools/_perturbation_efficacy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tools/_perturbation_efficacy/conftest.py b/tests/tools/_perturbation_efficacy/conftest.py new file mode 100644 index 00000000..bf54cf2e --- /dev/null +++ b/tests/tools/_perturbation_efficacy/conftest.py @@ -0,0 +1,41 @@ +import anndata as ad +import numpy as np +import pandas as pd +import pytest +from scipy import sparse + +NUM_CELLS_PER_GROUP = 10 +NUM_NOT_DE = 10 +NUM_DE = 10 + + +@pytest.fixture +def adata(): + """Synthetic screen with NT controls and a target gene split into non-perturbed and knocked-out cells.""" + rng = np.random.default_rng(seed=1) + columns = [] + # genes that are not differentially expressed between any group + for _ in range(NUM_NOT_DE): + nt = np.clip(rng.normal(0, 1, NUM_CELLS_PER_GROUP), 0, None) + non_perturbed = np.clip(rng.normal(0, 1, NUM_CELLS_PER_GROUP), 0, None) + knockout = np.clip(rng.normal(0, 1, NUM_CELLS_PER_GROUP), 0, None) + columns.append(np.concatenate((nt, non_perturbed, knockout))[:, None]) + + # genes that are differentially expressed only in the knocked-out cells + for i in range(NUM_DE): + nt = np.clip(rng.normal(i + 2, 0.5 + 0.05 * i, NUM_CELLS_PER_GROUP), 0, None) + non_perturbed = np.clip(rng.normal(i + 2, 0.5 + 0.05 * i, NUM_CELLS_PER_GROUP), 0, None) + knockout = np.clip(rng.normal(i + 4, 0.5 + 0.1 * i, NUM_CELLS_PER_GROUP), 0, None) + columns.append(np.concatenate((nt, non_perturbed, knockout))[:, None]) + + X = np.concatenate(columns, axis=1) + + obs = pd.DataFrame( + { + "gene_target": ["NT"] * NUM_CELLS_PER_GROUP + ["target_gene_a"] * NUM_CELLS_PER_GROUP * 2, + "label": ["control"] * NUM_CELLS_PER_GROUP + ["treatment"] * NUM_CELLS_PER_GROUP * 2, + }, + index=np.arange(NUM_CELLS_PER_GROUP * 3).astype(str), + ) + var = pd.DataFrame(index=[f"gene{i}" for i in range(1, NUM_NOT_DE + NUM_DE + 1)]) + return ad.AnnData(X=sparse.csr_matrix(X), obs=obs, var=var) diff --git a/tests/tools/_perturbation_efficacy/test_mixscale.py b/tests/tools/_perturbation_efficacy/test_mixscale.py new file mode 100644 index 00000000..40aaa711 --- /dev/null +++ b/tests/tools/_perturbation_efficacy/test_mixscale.py @@ -0,0 +1,229 @@ +"""Tests for Mixscale continuous perturbation scoring. + +The `R_GOLDEN` values were produced by `RunMixscale` from https://github.com/satijalab/Mixscale on the dataset built by the `parity_adata` fixture. +The same DE genes are supplied via its `DE.gene` argument so the comparison isolates the scoring algorithm from the DE test, with matching `scale` (`slot`) and `split.by` settings. +The pertpy and R scores agree to within floating-point round-off (~1e-13); the values below are rounded. +""" + +import anndata as ad +import numpy as np +import pandas as pd +import pytest +from scipy import sparse + +import pertpy as pt + +NUM_CELLS_PER_GROUP = 10 + +DE_GENES_BY_TARGET = { + "GeneA": [f"Gene{i}" for i in range(8)], + "GeneB": [f"Gene{i}" for i in range(8, 13)], +} + +# Per-cell scores from R's RunMixscale; cells are Cell0..Cell23 in order (the first 8 are NT controls). +R_GOLDEN = { + ("nosplit", False): [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 8.514691542, + 7.547581559, + 8.308927689, + 7.027790116, + 8.856137701, + 10.27242998, + 7.472028438, + 10.44919624, + 8.222717541, + 6.790026264, + 3.645112649, + 3.338441385, + 4.270389556, + 4.392104461, + 3.291505505, + 2.442656334, + ], + ("replicate", False): [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 8.269530405, + 8.259218669, + 8.242553632, + 7.988511579, + 8.608806982, + 11.32851433, + 7.475598635, + 11.64795266, + 8.041098786, + 7.6848051, + 4.474797496, + 7.19368398, + 4.838105457, + 7.512732638, + 4.26093846, + 4.971528379, + ], + ("nosplit", True): [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 8.571392513, + 7.646342068, + 8.572557144, + 7.028868448, + 8.726570745, + 9.937827644, + 7.709038712, + 10.97100301, + 7.509618135, + 6.579441777, + 3.681654505, + 3.205923984, + 4.255066398, + 4.093707024, + 3.681669612, + 2.618569012, + ], +} + + +@pytest.fixture(params=[np.asarray, sparse.csr_array], ids=["dense", "csr_array"]) +def parity_adata(request): + """The exact deterministic dataset the R golden scores were computed on, over each supported array type.""" + rng = np.random.default_rng(0) + n_genes = 16 + labels = ["NT"] * 8 + ["GeneA"] * 10 + ["GeneB"] * 6 + n_cells = len(labels) + x_pert = rng.normal(0, 1, (n_cells, n_genes)) + x_pert[8:18, :8] += 2.5 + x_pert[18:24, 8:13] -= 1.7 + + adata = ad.AnnData( + X=np.zeros((n_cells, n_genes)), + obs=pd.DataFrame( + {"gene_target": labels, "replicate": ["rep1" if i % 2 == 0 else "rep2" for i in range(n_cells)]}, + index=[f"Cell{i}" for i in range(n_cells)], + ), + var=pd.DataFrame(index=[f"Gene{i}" for i in range(n_genes)]), + ) + adata.layers["X_pert"] = request.param(x_pert) + return adata + + +@pytest.fixture +def scored_adata(adata): + """The shared mixscape fixture scored with Mixscale (dense signature, t-test DE).""" + adata.layers["X_pert"] = adata.X.toarray() + pt.tl.Mixscale().mixscale(adata, pert_key="gene_target", control="NT", test_method="t-test", min_de_genes=3) + return adata + + +def test_control_cells_score_zero(scored_adata): + nt = scored_adata.obs["gene_target"] == "NT" + assert (scored_adata.obs.loc[nt, "mixscale_score"] == 0).all() + + +def test_perturbed_cells_score_nonzero(scored_adata): + perturbed = scored_adata.obs["gene_target"] != "NT" + assert scored_adata.obs.loc[perturbed, "mixscale_score"].abs().sum() > 0 + + +def test_knockout_scores_exceed_non_perturbed(scored_adata): + scores = scored_adata.obs["mixscale_score"].to_numpy() + non_perturbed = np.abs(scores[NUM_CELLS_PER_GROUP : NUM_CELLS_PER_GROUP * 2]).mean() + knockout = np.abs(scores[NUM_CELLS_PER_GROUP * 2 :]).mean() + assert knockout > non_perturbed + + +def test_score_dtype_and_stored_results(scored_adata): + assert scored_adata.obs["mixscale_score"].dtype == float + assert "mixscale" in scored_adata.uns + assert "mixscale_de_genes" in scored_adata.uns + + +def test_requires_perturbation_signature(adata): + with pytest.raises(KeyError, match="X_pert"): + pt.tl.Mixscale().mixscale(adata, pert_key="gene_target", control="NT") + + +def test_custom_score_column(adata): + adata.layers["X_pert"] = adata.X.toarray() + pt.tl.Mixscale().mixscale( + adata, pert_key="gene_target", control="NT", test_method="t-test", min_de_genes=3, new_class_name="efficacy" + ) + assert "efficacy" in adata.obs + assert "mixscale_score" not in adata.obs + + +def test_copy_does_not_mutate_input(adata): + adata.layers["X_pert"] = adata.X.toarray() + result = pt.tl.Mixscale().mixscale( + adata, pert_key="gene_target", control="NT", test_method="t-test", min_de_genes=3, copy=True + ) + assert result is not adata + assert "mixscale_score" in result.obs + assert "mixscale_score" not in adata.obs + + +def test_de_genes_by_target_override(adata): + adata.layers["X_pert"] = adata.X.toarray() + de_genes = {"target_gene_a": [f"gene{i}" for i in range(11, 19)]} + pt.tl.Mixscale().mixscale(adata, pert_key="gene_target", control="NT", de_genes_by_target=de_genes, min_de_genes=1) + assert list(adata.uns["mixscale_de_genes"]["target_gene_a"]) == de_genes["target_gene_a"] + + +def test_too_few_de_genes_fall_back_to_one(adata): + adata.layers["X_pert"] = adata.X.toarray() + pt.tl.Mixscale().mixscale(adata, pert_key="gene_target", control="NT", test_method="t-test", min_de_genes=1000) + perturbed = adata.obs["gene_target"] != "NT" + assert (adata.obs.loc[perturbed, "mixscale_score"] == 1.0).all() + assert (adata.obs.loc[~perturbed, "mixscale_score"] == 0.0).all() + + +@pytest.mark.parametrize("array_type", [np.asarray, sparse.csr_array], ids=["dense", "csr_array"]) +def test_signature_and_score_end_to_end(adata, array_type): + adata.X = array_type(adata.X.toarray()) + ms = pt.tl.Mixscale() + ms.perturbation_signature(adata, pert_key="gene_target", control="NT") + ms.mixscale(adata, pert_key="gene_target", control="NT", test_method="t-test", min_de_genes=3) + + scores = adata.obs["mixscale_score"] + assert not scores.isna().any() + assert (scores[adata.obs["gene_target"] == "NT"] == 0).all() + assert scores[adata.obs["gene_target"] != "NT"].abs().sum() > 0 + + +@pytest.mark.parametrize(("split_by", "scale"), list(R_GOLDEN)) +def test_matches_r_reference(parity_adata, split_by, scale): + pt.tl.Mixscale().mixscale( + parity_adata, + pert_key="gene_target", + control="NT", + layer="X_pert", + de_genes_by_target=DE_GENES_BY_TARGET, + min_de_genes=1, + max_de_genes=1000, + split_by=None if split_by == "nosplit" else split_by, + scale=scale, + ) + np.testing.assert_allclose( + parity_adata.obs["mixscale_score"].to_numpy(), + np.array(R_GOLDEN[(split_by, scale)]), + atol=1e-6, + ) diff --git a/tests/tools/_perturbation_efficacy/test_mixscape.py b/tests/tools/_perturbation_efficacy/test_mixscape.py new file mode 100644 index 00000000..5844fa3a --- /dev/null +++ b/tests/tools/_perturbation_efficacy/test_mixscape.py @@ -0,0 +1,100 @@ +import anndata as ad +import numpy as np +import pandas as pd + +import pertpy as pt +from pertpy.tools._perturbation_efficacy._mixscape import MixscapeGaussianMixture + +NUM_CELLS_PER_GROUP = 10 +ACCURACY_THRESHOLD = 0.8 + + +def test_mixscape(adata): + adata.layers["X_pert"] = adata.X + mixscape_identifier = pt.tl.Mixscape() + mixscape_identifier.mixscape(adata=adata, pert_key="gene_target", control="NT", test_method="t-test") + + np_result = adata.obs["mixscape_class_global"] == "NP" + np_result_correct = np_result[NUM_CELLS_PER_GROUP : NUM_CELLS_PER_GROUP * 2] + ko_result = adata.obs["mixscape_class_global"] == "KO" + ko_result_correct = ko_result[NUM_CELLS_PER_GROUP * 2 : NUM_CELLS_PER_GROUP * 3] + + assert "mixscape_class" in adata.obs + assert "mixscape_class_global" in adata.obs + assert "mixscape_class_p_ko" in adata.obs + assert sum(np_result_correct) > ACCURACY_THRESHOLD * NUM_CELLS_PER_GROUP + assert sum(ko_result_correct) > ACCURACY_THRESHOLD * NUM_CELLS_PER_GROUP + + +def test_perturbation_signature(adata): + mixscape_identifier = pt.tl.Mixscape() + mixscape_identifier.perturbation_signature(adata, pert_key="label", control="control") + + assert "X_pert" in adata.layers + + +def test_lda(adata): + adata.layers["X_pert"] = adata.X + mixscape_identifier = pt.tl.Mixscape() + mixscape_identifier.mixscape(adata=adata, pert_key="gene_target", control="NT", test_method="t-test") + mixscape_identifier.lda(adata=adata, pert_key="gene_target", control="NT", test_method="t-test") + + assert "mixscape_lda" in adata.uns + + +def test_deterministic_perturbation_signature(): + n_genes = 5 + n_cells_per_class = 50 + cell_classes = ["NT", "KO", "NP"] + groups = ["Group1", "Group2"] + + cell_classes_array = np.repeat(cell_classes, n_cells_per_class) + groups_array = np.tile(np.repeat(groups, n_cells_per_class // 2), len(cell_classes)) + obs = pd.DataFrame( + { + "cell_class": cell_classes_array, + "group": groups_array, + "perturbation": ["control" if cell_class == "NT" else "pert1" for cell_class in cell_classes_array], + } + ) + + data = np.zeros((len(obs), n_genes)) + pert_effect = np.random.default_rng().uniform(-1, 1, size=(n_cells_per_class // len(groups), n_genes)) + for group in groups: + baseline_expr = 2 if group == "Group1" else 10 + group_mask = obs["group"] == group + data[(obs["cell_class"] == "NT") & group_mask] = baseline_expr + data[(obs["cell_class"] == "KO") & group_mask] = baseline_expr + pert_effect + data[(obs["cell_class"] == "NP") & group_mask] = baseline_expr + + var = pd.DataFrame(index=[f"Gene{i + 1}" for i in range(n_genes)]) + adata = ad.AnnData(X=data, obs=obs, var=var) + + for ref_selection_mode in ("nn", "split_by"): + adata.layers.pop("X_pert", None) + pt.tl.Mixscape().perturbation_signature( + adata, + pert_key="perturbation", + control="control", + ref_selection_mode=ref_selection_mode, + n_neighbors=5, + split_by="group", + ) + assert "X_pert" in adata.layers + assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NT"], 0) + assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NP"], 0) + assert np.allclose( + adata.layers["X_pert"][obs["cell_class"] == "KO"], -np.concatenate([pert_effect] * len(groups), axis=0) + ) + + +def test_mixscape_gaussian_mixture(): + X = np.random.default_rng().random(100) + fixed_means = [0.2, None] + fixed_covariances = [None, 0.1] + + model = MixscapeGaussianMixture(n_components=2, fixed_means=fixed_means, fixed_covariances=fixed_covariances) + model.fit(X.reshape(-1, 1)) + + assert np.allclose(model.means_[0], fixed_means[0]) + assert np.allclose(model.covariances_[1], fixed_covariances[1]) diff --git a/tests/tools/test_mixscape.py b/tests/tools/test_mixscape.py deleted file mode 100644 index 342c6538..00000000 --- a/tests/tools/test_mixscape.py +++ /dev/null @@ -1,174 +0,0 @@ -from pathlib import Path - -import anndata -import numpy as np -import pandas as pd -import pytest -from scipy import sparse - -import pertpy as pt -from pertpy.tools._mixscape import MixscapeGaussianMixture - -# Random generate data settings -NUM_CELLS_PER_GROUP = 10 -NUM_NOT_DE = 10 -NUM_DE = 10 -ACCURACY_THRESHOLD = 0.8 - - -@pytest.fixture -def adata(): - rng = np.random.default_rng(seed=1) - # generate not differentially expressed genes - for i in range(NUM_NOT_DE): - NT = rng.normal(0, 1, NUM_CELLS_PER_GROUP) - NT = np.where(NT < 0, 0, NT) - NP = rng.normal(0, 1, NUM_CELLS_PER_GROUP) - NP = np.where(NP < 0, 0, NP) - KO = rng.normal(0, 1, NUM_CELLS_PER_GROUP) - KO = np.where(KO < 0, 0, KO) - gene_i = np.concatenate((NT, NP, KO)) - gene_i = np.expand_dims(gene_i, axis=1) - if i == 0: # noqa: SIM108 - X = gene_i - else: - X = np.concatenate((X, gene_i), axis=1) - - # generate differentially expressed genes - for i in range(NUM_DE): - NT = rng.normal(i + 2, 0.5 + 0.05 * i, NUM_CELLS_PER_GROUP) - NT = np.where(NT < 0, 0, NT) - NP = rng.normal(i + 2, 0.5 + 0.05 * i, NUM_CELLS_PER_GROUP) - NP = np.where(NP < 0, 0, NP) - KO = rng.normal(i + 4, 0.5 + 0.1 * i, NUM_CELLS_PER_GROUP) - KO = np.where(KO < 0, 0, KO) - gene_i = np.concatenate((NT, NP, KO)) - gene_i = np.expand_dims(gene_i, axis=1) - X = np.concatenate((X, gene_i), axis=1) - - # obs for random AnnData - gene_target = {"gene_target": ["NT"] * NUM_CELLS_PER_GROUP + ["target_gene_a"] * NUM_CELLS_PER_GROUP * 2} - gene_target = pd.DataFrame(gene_target) - label = {"label": ["control"] * NUM_CELLS_PER_GROUP + ["treatment"] * NUM_CELLS_PER_GROUP * 2} - label = pd.DataFrame(label) - obs = pd.concat([gene_target, label], axis=1) - obs = obs.set_index(np.arange(NUM_CELLS_PER_GROUP * 3)) - obs.index.rename("index", inplace=True) - - # var for random AnnData - var_data = {"name": ["gene" + str(i) for i in range(1, NUM_NOT_DE + NUM_DE + 1)]} - var = pd.DataFrame(var_data) - var = var.set_index("name", drop=False) - var.index.rename("index", inplace=True) - - X = sparse.csr_matrix(X) - adata = anndata.AnnData(X=X, obs=obs, var=var) - - return adata - - -def test_mixscape(adata): - adata.layers["X_pert"] = adata.X - mixscape_identifier = pt.tl.Mixscape() - mixscape_identifier.mixscape(adata=adata, pert_key="gene_target", control="NT", test_method="t-test") - np_result = adata.obs["mixscape_class_global"] == "NP" - np_result_correct = np_result[NUM_CELLS_PER_GROUP : NUM_CELLS_PER_GROUP * 2] - - ko_result = adata.obs["mixscape_class_global"] == "KO" - ko_result_correct = ko_result[NUM_CELLS_PER_GROUP * 2 : NUM_CELLS_PER_GROUP * 3] - - assert "mixscape_class" in adata.obs - assert "mixscape_class_global" in adata.obs - assert "mixscape_class_p_ko" in adata.obs - assert sum(np_result_correct) > ACCURACY_THRESHOLD * NUM_CELLS_PER_GROUP - assert sum(ko_result_correct) > ACCURACY_THRESHOLD * NUM_CELLS_PER_GROUP - - -def test_perturbation_signature(adata): - mixscape_identifier = pt.tl.Mixscape() - mixscape_identifier.perturbation_signature(adata, pert_key="label", control="control") - - assert "X_pert" in adata.layers - - -def test_lda(adata): - adata.layers["X_pert"] = adata.X - mixscape_identifier = pt.tl.Mixscape() - mixscape_identifier.mixscape(adata=adata, pert_key="gene_target", control="NT", test_method="t-test") - mixscape_identifier.lda(adata=adata, pert_key="gene_target", control="NT", test_method="t-test") - - assert "mixscape_lda" in adata.uns - - -def test_deterministic_perturbation_signature(): - n_genes = 5 - n_cells_per_class = 50 - cell_classes = ["NT", "KO", "NP"] - groups = ["Group1", "Group2"] - - cell_classes_array = np.repeat(cell_classes, n_cells_per_class) - groups_array = np.tile(np.repeat(groups, n_cells_per_class // 2), len(cell_classes)) - obs = pd.DataFrame( - { - "cell_class": cell_classes_array, - "group": groups_array, - "perturbation": ["control" if cell_class == "NT" else "pert1" for cell_class in cell_classes_array], - } - ) - - data = np.zeros((len(obs), n_genes)) - pert_effect = np.random.default_rng().uniform(-1, 1, size=(n_cells_per_class // len(groups), n_genes)) - for _, group in enumerate(groups): - baseline_expr = 2 if group == "Group1" else 10 - group_mask = obs["group"] == group - - nt_mask = (obs["cell_class"] == "NT") & group_mask - data[nt_mask] = baseline_expr - - ko_mask = (obs["cell_class"] == "KO") & group_mask - data[ko_mask] = baseline_expr + pert_effect - - np_mask = (obs["cell_class"] == "NP") & group_mask - data[np_mask] = baseline_expr - - var = pd.DataFrame(index=[f"Gene{i + 1}" for i in range(n_genes)]) - adata = anndata.AnnData(X=data, obs=obs, var=var) - - mixscape_identifier = pt.tl.Mixscape() - mixscape_identifier.perturbation_signature( - adata, pert_key="perturbation", control="control", n_neighbors=5, split_by="group" - ) - - assert "X_pert" in adata.layers - assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NT"], 0) - assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NP"], 0) - assert np.allclose( - adata.layers["X_pert"][obs["cell_class"] == "KO"], -np.concatenate([pert_effect] * len(groups), axis=0) - ) - - del adata.layers["X_pert"] - - mixscape_identifier = pt.tl.Mixscape() - mixscape_identifier.perturbation_signature( - adata, pert_key="perturbation", control="control", ref_selection_mode="split_by", split_by="group" - ) - - assert "X_pert" in adata.layers - assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NT"], 0) - assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NP"], 0) - assert np.allclose( - adata.layers["X_pert"][obs["cell_class"] == "KO"], -np.concatenate([pert_effect] * len(groups), axis=0) - ) - - -def test_mixscape_gaussian_mixture(): - X = np.random.default_rng().random(100) - - fixed_means = [0.2, None] - fixed_covariances = [None, 0.1] - - model = MixscapeGaussianMixture(n_components=2, fixed_means=fixed_means, fixed_covariances=fixed_covariances) - model.fit(X.reshape(-1, 1)) - - assert np.allclose(model.means_[0], fixed_means[0]) - assert np.allclose(model.covariances_[1], fixed_covariances[1])