Source code for SpaceTravLR.models.parallel_estimators

import copy
from anndata import AnnData
import enlighten
from sklearn.metrics import r2_score
import torch
import os
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import DataLoader, Dataset
from sklearn.linear_model import ARDRegression, BayesianRidge
from group_lasso import GroupLasso
from SpaceTravLR.models.spatial_map import xyc2spatial_fast
from SpaceTravLR.tools.network import RegulatoryFactory, expand_paired_interactions
from .pixel_attention import CellularNicheNetwork, CellularViT
from ..tools.utils import gaussian_kernel_2d, is_mouse_data, set_seed
from ..tools.network import get_cellchat_db

# import commot as ct
from scipy.spatial.distance import cdist
import numba
from wordcloud import WordCloud
import matplotlib.pyplot as plt
import pickle
from easydict import EasyDict as edict
from matplotlib.colors import rgb2hex

# Short alias for ``torch.tensor``.
tt = torch.tensor
# Runs once at import time with a fixed seed. NOTE(review): presumably seeds the
# random number generators for reproducibility -- confirm in tools.utils.set_seed.
set_seed(42)


import warnings

# Installs a blanket "ignore" filter at import time, so warnings raised after
# this module is imported are suppressed for the whole process.
warnings.filterwarnings("ignore")


@numba.njit(parallel=True)
def calculate_weighted_ligands(gauss_weights, lig_df_values, u_ligands):
    """Average kernel-weighted ligand expression for every (ligand, cell) pair.

    Entry ``[i, j]`` of the result is the mean over all cells of
    ``gauss_weights[j] * lig_df_values[:, i]``.

    Args:
        gauss_weights: Sequence with one weight vector per receiving cell.
        lig_df_values: 2-D expression array; column ``i`` holds ligand ``i``.
        u_ligands: Sequence of ligand names; only its length is used.

    Returns:
        Array of shape ``(len(u_ligands), len(gauss_weights))``.
    """
    ligand_count = len(u_ligands)
    cell_count = len(gauss_weights)
    received = np.zeros((ligand_count, cell_count))

    # The outer loop is the parallel one; each ligand row is written by
    # exactly one iteration.
    for lig_idx in numba.prange(ligand_count):
        expression = lig_df_values[:, lig_idx]
        for cell_idx in range(cell_count):
            received[lig_idx, cell_idx] = np.mean(gauss_weights[cell_idx] * expression)

    return received


def compute_radius_weights(xy, lig_df, radius, scale_factor):
    """Compute received-ligand values with one Gaussian kernel per cell.

    Args:
        xy: Cell coordinates, indexable by row position.
        lig_df: Ligand expression frame (cells as rows, ligands as columns).
        radius: Radius forwarded to ``gaussian_kernel_2d``.
        scale_factor: Multiplier applied to every kernel.

    Returns:
        pd.DataFrame indexed like ``lig_df`` with one column per unique
        ligand name (sorted, as returned by ``np.unique``).
    """
    kernels = []
    for cell_pos in range(len(lig_df)):
        kernel = gaussian_kernel_2d(xy[cell_pos], xy, radius=radius)
        kernels.append(scale_factor * kernel)

    unique_ligands = list(np.unique(lig_df.columns))
    expression = lig_df[unique_ligands].values

    received = calculate_weighted_ligands(kernels, expression, unique_ligands)

    # The kernel returns (ligands x cells); callers expect (cells x ligands).
    per_ligand = pd.DataFrame(received, index=unique_ligands, columns=lig_df.index)
    return per_ligand.T

@numba.njit(parallel=True)
def _gaussian_kernel_2d_batch(xy: np.ndarray, radius: float) -> np.ndarray:
    """Build the dense pairwise Gaussian weight matrix for 2-D points.

    ``weights[i, j] = exp(-((xi - xj)**2 + (yi - yj)**2) / (2 * radius**2))``.

    Args:
        xy: Array whose first two columns are the x and y coordinates.
        radius: Kernel width used in the exponent.

    Returns:
        ``float64`` array of shape ``(n, n)`` with ``n = xy.shape[0]``.
    """
    n_points = xy.shape[0]
    weights = np.empty((n_points, n_points), dtype=np.float64)
    # Precompute the constant factor of the exponent once.
    neg_inv_two_r_sq = -1.0 / (2.0 * radius * radius)
    for row in numba.prange(n_points):
        x_row = xy[row, 0]
        y_row = xy[row, 1]
        for col in range(n_points):
            delta_x = x_row - xy[col, 0]
            delta_y = y_row - xy[col, 1]
            sq_dist = delta_x * delta_x + delta_y * delta_y
            weights[row, col] = np.exp(sq_dist * neg_inv_two_r_sq)
    return weights


@numba.njit(parallel=True)
def _weighted_mean(W: np.ndarray, lig_values: np.ndarray) -> np.ndarray:
    """Return ``(W @ lig_values) / n`` computed with explicit loops.

    The sum for each receiving cell is divided by the number of rows of ``W``
    (not by the sum of the weights).

    Args:
        W: Weight matrix; ``W[i, k]`` weights sender ``k`` for receiver ``i``.
            NOTE(review): the sender loop runs over ``W.shape[0]``, so ``W`` is
            assumed square and ``lig_values`` to have as many rows -- confirm.
        lig_values: Expression array with one column per ligand.

    Returns:
        ``float64`` array of shape ``(W.shape[0], lig_values.shape[1])``.
    """
    n_cells = W.shape[0]
    n_ligands = lig_values.shape[1]
    averaged = np.zeros((n_cells, n_ligands), dtype=np.float64)
    for receiver in numba.prange(n_cells):
        for sender in range(n_cells):
            weight = W[receiver, sender]
            for lig in range(n_ligands):
                averaged[receiver, lig] += weight * lig_values[sender, lig]
        # Turn the weighted sums of this receiver into means.
        for lig in range(n_ligands):
            averaged[receiver, lig] /= n_cells
    return averaged


def compute_radius_weights_fast(xy, lig_df, radius, scale_factor):
    """Compute received-ligand values using the batched numba kernels.

    Args:
        xy: Cell coordinates (converted to a contiguous ``float64`` array).
        lig_df: Ligand expression frame (cells as rows, ligands as columns).
        radius: Gaussian kernel radius.
        scale_factor: Multiplier applied to the whole weight matrix.

    Returns:
        pd.DataFrame with the same index and columns as ``lig_df``.
    """
    coords = np.ascontiguousarray(xy, dtype=np.float64)
    kernel = _gaussian_kernel_2d_batch(coords, radius)
    scaled_kernel = scale_factor * kernel

    expression = np.ascontiguousarray(lig_df.values, dtype=np.float64)
    received = _weighted_mean(scaled_kernel, expression)

    return pd.DataFrame(received, index=lig_df.index, columns=lig_df.columns)




def received_ligands(xy, ligands_df, lr_info, scale_factor=1):
    """Compute the amount of each ligand received by every cell.

    Ligands are grouped by the ``"radius"`` assigned to them in ``lr_info`` and
    each group is passed through :func:`compute_radius_weights_fast`.

    Args:
        xy (np.ndarray): Array of spatial coordinates (x, y), one row per cell.
        ligands_df (pd.DataFrame): Ligand gene expression values (cells as
            rows, ligands as columns).
        lr_info (pd.DataFrame): Table with at least ``"ligand"`` and
            ``"radius"`` columns. Only the first row per ligand is used. The
            caller's frame is not modified.
        scale_factor (float): Multiplier forwarded to the kernel computation.

    Returns:
        pd.DataFrame: Received ligands with the same index and columns as
        ``ligands_df``. Ligands without an entry in ``lr_info`` are filled
        with 0; if no ligand has an entry at all, the whole frame is 0.
    """
    # Keep one row per ligand that is actually present in ``ligands_df``.
    present = lr_info["ligand"].isin(np.unique(ligands_df.columns))
    lr_info = lr_info[present].drop_duplicates(subset="ligand", keep="first")

    per_radius = []
    for radius in lr_info["radius"].unique():
        radius_ligands = lr_info.loc[lr_info["radius"] == radius, "ligand"].values
        weighted = compute_radius_weights_fast(
            xy, ligands_df[radius_ligands], radius, scale_factor
        )
        if not weighted.empty:
            per_radius.append(weighted)

    # ``pd.concat`` raises on an empty list, so handle "nothing to compute"
    # explicitly instead of crashing.
    if not per_radius:
        return pd.DataFrame(
            0.0, index=ligands_df.index, columns=ligands_df.columns
        )

    full_df = pd.concat(per_radius, axis=1)
    return (
        full_df.reindex(ligands_df.index)
        .reindex(ligands_df.columns, axis=1)
        .fillna(0)
    )
def get_filtered_df(counts_df, cell_thresholds=None, genes=None, min_expression=1e-9):
    """Get filtered expression of ligands/receptors based on cell-type thresholds.

    Args:
        counts_df (pd.DataFrame): Expression values (cells as rows, genes as
            columns).
        cell_thresholds (pd.DataFrame | None): Per-cell, per-gene mask source.
            Values ``> 0`` keep the expression value, everything else
            (including genes missing from this frame) zeroes it.
        genes: Gene names to select. ``None`` selects every column of
            ``counts_df``.
        min_expression (float): When ``> 0``, values not strictly greater than
            this are set to 0.

    Returns:
        pd.DataFrame: Filtered expression. Columns are the *sorted unique*
        gene names (``np.unique`` order), not the order passed in ``genes``.

    Raises:
        AssertionError: If ``cell_thresholds`` cannot be aligned to
            ``counts_df.index`` (for example duplicated obs names).
    """
    if genes is None:
        # The declared default used to crash in the column lookup below.
        genes = counts_df.columns
    ligand_counts = counts_df[np.unique(genes)]

    if min_expression > 0:
        expressed = np.where(ligand_counts > min_expression, 1, 0)
        ligand_counts = ligand_counts * expressed

    if cell_thresholds is not None:
        cell_thresholds = cell_thresholds.loc[counts_df.index]
        assert cell_thresholds.index.equals(counts_df.index), (
            "error aligning cell_thresholds and counts_df, check if obs_names has duplicates"
        )
        allowed = (
            cell_thresholds.reindex(ligand_counts.columns, axis=1).fillna(0).values
        )
        ligand_counts = ligand_counts * np.where(allowed > 0, 1, 0)

    return ligand_counts
def init_received_ligands(
    adata, radius, cell_threshes, contact_distance=50, scale_factor=1, layer="imputed_count"
):
    """Compute received-ligand tables and cache them in ``adata.uns``.

    Two tables are stored:

    * ``adata.uns["received_ligands"]`` - ligand expression is first filtered
      with ``cell_threshes``.
    * ``adata.uns["received_ligands_tfl"]`` - no threshold filtering.

    Args:
        adata: Dataset providing ``var_names``, ``obsm["spatial"]``, ``uns``
            and ``to_df(layer=...)``.
        radius: Radius assigned to pairs whose ``"signaling"`` value is
            ``"Secreted Signaling"``.
        cell_threshes: Threshold frame forwarded to :func:`get_filtered_df`
            (may be ``None``).
        contact_distance: Radius assigned to every other pair.
        scale_factor: Kernel multiplier used for ``"received_ligands"``.
        layer: Layer key passed to ``adata.to_df``.

    Returns:
        The same ``adata`` object, with the two ``uns`` entries set.
    """
    species = "mouse" if is_mouse_data(adata) else "human"
    lr = expand_paired_interactions(get_cellchat_db(species))

    # Copy after filtering so the column assignment below writes to an
    # independent frame rather than to a slice of the database table.
    in_data = lr.ligand.isin(adata.var_names) & lr.receptor.isin(adata.var_names)
    lr = lr[in_data].copy()
    lr["radius"] = np.where(
        lr["signaling"] == "Secreted Signaling", radius, contact_distance
    )

    counts_df = adata.to_df(layer=layer)
    ligands = np.unique(lr.ligand)
    xy = adata.obsm["spatial"]

    adata.uns["received_ligands"] = received_ligands(
        xy=xy,
        ligands_df=get_filtered_df(
            counts_df, cell_thresholds=cell_threshes, genes=ligands
        ),
        lr_info=lr,
        scale_factor=scale_factor,
    )

    # Only the ligand-receptor table is threshold-filtered; the TF-ligand
    # table uses unfiltered expression.
    # NOTE(review): ``scale_factor`` is not forwarded here, so this table is
    # always computed with the default of 1 -- confirm this is intended.
    adata.uns["received_ligands_tfl"] = received_ligands(
        xy=xy,
        ligands_df=get_filtered_df(counts_df, None, genes=ligands),
        lr_info=lr,
    )

    return adata
def create_spatial_features(x, y, celltypes, obs_index, radius=200):
    """Count, for every cell, the neighbours of each cell type within ``radius``.

    Distances are Euclidean and the cell itself is included in the count of
    its own type (its distance is 0).

    Args:
        x: x coordinates, one per cell.
        y: y coordinates, one per cell.
        celltypes: Cell-type label per cell (array, Series or list).
        obs_index: Index used for the returned frame.
        radius: Inclusive distance cut-off.

    Returns:
        pd.DataFrame: Shape ``(n_cells, n_celltypes)`` with columns named
        ``"<celltype>_within"`` in sorted unique-label order.
    """
    coords = np.column_stack((x, y))
    # Work on a plain array so that list input produces an element-wise mask.
    labels = np.asarray(celltypes)
    unique_celltypes = np.unique(labels)

    # The neighbourhood matrix does not depend on the cell type: build it once.
    within_radius = cdist(coords, coords) <= radius

    counts = np.zeros((len(coords), len(unique_celltypes)))
    for col, celltype in enumerate(unique_celltypes):
        counts[:, col] = within_radius[:, labels == celltype].sum(axis=1)

    columns = [f"{ct}_within" for ct in unique_celltypes]
    return pd.DataFrame(counts, columns=columns, index=obs_index)
# Pick the compute device once at import time: MPS first, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


class RotatedTensorDataset(Dataset):
    """Per-cell dataset yielding one spatial-map channel plus cell features.

    Each item is the tuple ``(sp_map, X_cell[idx], y_cell[idx],
    spatial_features[idx])`` converted to float tensors, where ``sp_map`` is
    the single channel ``cluster`` of ``sp_maps[idx]``. When ``rotate_maps``
    is true the map is rotated by a random multiple of 90 degrees.
    """

    def __init__(
        self, sp_maps, X_cell, y_cell, cluster, spatial_features, rotate_maps=True
    ):
        self.sp_maps = sp_maps
        self.X_cell = X_cell
        self.y_cell = y_cell
        self.cluster = cluster
        self.spatial_features = spatial_features
        self.rotate_maps = rotate_maps

    def __len__(self):
        return len(self.X_cell)

    def __getitem__(self, idx):
        # Slice (rather than index) the channel axis so it is kept with size 1.
        sp_map = self.sp_maps[idx, self.cluster : self.cluster + 1, :, :]
        if self.rotate_maps:
            k = np.random.choice([0, 1, 2, 3])
            sp_map = np.rot90(sp_map, k=k, axes=(1, 2))
        return (
            # ``rot90`` returns a view; copy before handing it to torch.
            torch.from_numpy(sp_map.copy()).float(),
            torch.from_numpy(self.X_cell[idx]).float(),
            torch.from_numpy(np.array(self.y_cell[idx])).float(),
            torch.from_numpy(self.spatial_features[idx]).float(),
        )


def init_ligands_and_receptors(
    species,
    adata,
    annot,
    target_gene,
    receptor_thresh,
    radius,
    contact_distance,
    tf_ligand_cutoff,
    regulators,
    grn,
):
    """Collect ligand-receptor and TF-ligand pairs relevant to ``target_gene``.

    Args:
        species: Species key passed to ``get_cellchat_db`` and used in the
            ligand-target file name.
        adata: Dataset; only ``var_names`` is read here (and it is forwarded
            to ``grn.get_regulators``).
        annot: Currently unused.
        target_gene: Gene being modelled; pairs involving it are dropped.
        receptor_thresh: Currently unused.
        radius: Radius assigned to ``"Secreted Signaling"`` pairs.
        contact_distance: Radius assigned to all other pairs.
        tf_ligand_cutoff: Minimum ligand-target score for a TF-ligand pair.
        regulators: TF names used to select rows of the ligand-target table.
        grn: Object with ``get_regulators(adata, gene)`` or ``None``.

    Returns:
        EasyDict with fields ``lr`` (pair table with ``"radius"`` and
        ``"pairs"`` columns), ``ligands``, ``receptors``, ``tfl_pairs``
        (``"Ligand#TF"`` strings), ``tfl_regulators`` and ``tfl_ligands``.
    """
    ligand_mixtures = edict()

    lr = expand_paired_interactions(get_cellchat_db(species))

    # Keep pairs whose genes are both measured and that do not involve the
    # target gene. Copy so the column assignments below write to an
    # independent frame rather than to a slice of the database table.
    keep = (
        lr.ligand.isin(adata.var_names)
        & lr.receptor.isin(adata.var_names)
        & (lr.receptor != target_gene)
        & (lr.ligand != target_gene)
    )
    lr = lr[keep].copy()
    lr["radius"] = np.where(
        lr["signaling"] == "Secreted Signaling", radius, contact_distance
    )
    lr["pairs"] = lr.ligand.values + "$" + lr.receptor.values
    lr = lr.drop_duplicates(subset="pairs", keep="first")
    ligands = list(lr.ligand.values)
    receptors = list(lr.receptor.values)

    # The file used to be read only from a developer-local absolute path. Keep
    # that location for backward compatibility and otherwise fall back to the
    # Zenodo copy of the same file.
    data_path = f'/Users/koush/Projects/SpaceOracle/data/ligand_target_{species}.parquet'
    if not os.path.exists(data_path):
        data_path = (
            f"https://zenodo.org/records/17594271/files/ligand_target_{species}.parquet"
        )
    nichenet_lt = pd.read_parquet(data_path)

    # Rows: candidate TFs; columns: ligands that take part in a kept L-R pair.
    nichenet_lt = nichenet_lt.loc[np.intersect1d(nichenet_lt.index, regulators)][
        np.intersect1d(nichenet_lt.columns, ligands)
    ]

    tfl_pairs = []
    tfl_regulators = []
    tfl_ligands = []

    if grn is not None:
        ligand_regulators = {
            lig: set(grn.get_regulators(adata, lig)) for lig in nichenet_lt.columns
        }
    else:
        from collections import defaultdict

        # Without a GRN no ligand has known regulators, so the two membership
        # checks below always pass.
        ligand_regulators = defaultdict(list)

    for tf_ in nichenet_lt.index:
        row = nichenet_lt.loc[tf_]
        top_5 = row.nlargest(5)
        for lig_, value in top_5.items():
            # Skip ligands that are themselves regulated by the target gene or
            # by this TF, and pairs scoring at or below the cutoff.
            if (
                target_gene not in ligand_regulators[lig_]
                and tf_ not in ligand_regulators[lig_]
                and value > tf_ligand_cutoff
            ):
                tfl_ligands.append(lig_)
                tfl_regulators.append(tf_)
                tfl_pairs.append(f"{lig_}#{tf_}")

    assert len(ligands) == len(receptors)
    assert len(tfl_regulators) == len(tfl_ligands)

    ligand_mixtures.lr = lr
    ligand_mixtures.ligands = ligands
    ligand_mixtures.receptors = receptors
    ligand_mixtures.tfl_pairs = tfl_pairs
    ligand_mixtures.tfl_regulators = tfl_regulators
    ligand_mixtures.tfl_ligands = tfl_ligands

    return ligand_mixtures
[docs] class SpatialCellularProgramsEstimator: """Per-gene spatial regression estimator with ligand-receptor and GRN features. ``SpatialCellularProgramsEstimator`` fits a spatially-aware gene-regulatory model for a single *target gene*. For every annotated cell-type cluster it trains a :class:`~SpaceTravLR.models.pixel_attention.CellularNicheNetwork` (CNN) or :class:`~SpaceTravLR.models.pixel_attention.CellularViT` (ViT) that takes three inputs: 1. **Spatial neighbourhood map** – a 2-D density image of each cell-type derived from ``adata.obsm["spatial"]``. 2. **Regulatory features** – expression of transcription factors inferred from a supplied GRN, along with ligand-receptor interaction scores and ligand-TF interaction scores computed from the CellChat database. 3. **Spatial context features** – per-cell neighbour-count vectors (one column per cell type within ``radius`` microns). The model outputs per-cell *spatial beta* coefficients (one per modulator + intercept) describing how strongly each regulator, L-R pair or TF-ligand pair influences the target gene *at that spatial location*. Pipeline overview:: estimator = SpatialCellularProgramsEstimator(adata, "Myc", grn=grn) estimator.fit(num_epochs=50) betas = estimator.get_betas() # DataFrame (cells × modulators) estimator.export("./models") Attributes: adata (AnnData): The spatial single-cell dataset. target_gene (str): The gene being modelled. regulators (list[str]): Transcription-factor regulators of the target. ligands (list[str]): Ligand genes from active L-R pairs. receptors (list[str]): Receptor genes paired with ``ligands``. tfl_ligands (list[str]): Ligands from TF-ligand (NicheNet) pairs. tfl_regulators (list[str]): TFs paired with ``tfl_ligands``. lr_pairs (pd.Series): Active ligand-receptor pair identifiers (format ``"LigandGene$ReceptorGene"``). tfl_pairs (list[str]): Active TF-ligand pair identifiers (format ``"LigandGene#TFGene"``). 
modulators (list[str]): Ordered concatenation of regulators, L-R pairs, and TF-ligand pairs – defines the column order of the beta matrix. modulators_genes (list[str]): Unique gene names spanning all modulator categories. models (dict[int, CellularNicheNetwork]): Trained per-cluster neural networks (available after :meth:`fit`). scores (dict[int, float]): Per-cluster R² scores on the training set (available after :meth:`fit`). loss_dict (dict[int, list[float]]): Per-batch MSE losses collected during training (available after :meth:`fit`). spatial_maps (np.ndarray): Cached neighbourhood image tensors of shape ``(n_cells, n_clusters, spatial_dim, spatial_dim)``. spatial_features (pd.DataFrame): Min-max scaled neighbour-count features (shape ``n_cells × n_clusters``). xy (np.ndarray): Spatial coordinates array of shape ``(n_cells, 2)``. device (torch.device): PyTorch compute device (CUDA / MPS / CPU). species (str): ``"mouse"`` or ``"human"`` inferred from gene names. """
def __init__(
    self,
    adata,
    target_gene,
    spatial_dim=64,
    cluster_annot="cell_type_int",
    layer="imputed_count",
    radius=100,
    contact_distance=30,
    use_ligands=True,
    tf_ligand_cutoff=0.01,
    receptor_thresh=0.1,
    regulators=None,
    grn=None,
    colinks_path=None,
    scale_factor=1,
    extra_regulators=None,
):
    """Initialise the estimator and resolve regulatory inputs.

    The constructor performs three main steps:

    1. **GRN lookup** - fetch the list of transcription-factor regulators
       for ``target_gene`` either from an explicit ``regulators`` list or
       via ``grn.get_regulators``.
    2. **Ligand-receptor discovery** - query the CellChat database to
       enumerate L-R pairs and NicheNet TF-ligand pairs that are relevant
       to ``target_gene`` (skipped when ``use_ligands=False``).
    3. **Modulator assembly** - build the ordered ``modulators`` list that
       defines the column ordering of the beta coefficient matrix.

    Args:
        adata (AnnData): Spatial dataset. Must contain
            ``adata.obsm["spatial"]``, the requested ``layer``, and
            ``cluster_annot`` in ``adata.obs``.
        target_gene (str): Name of the gene to model. Must be present in
            ``adata.var_names``.
        spatial_dim (int): Side length (pixels) of the square spatial
            neighbourhood image used as CNN input. Default ``64``.
        cluster_annot (str): ``adata.obs`` column holding integer-encoded
            cell-type labels. Default ``"cell_type_int"``.
        layer (str): ``adata.layers`` key for gene-expression values.
            Default ``"imputed_count"``.
        radius (float): Search radius (same spatial units as
            ``adata.obsm["spatial"]``) for *secreted* ligand diffusion.
            Default ``100``.
        contact_distance (float): Search radius for *contact* signalling
            (juxtacrine / cell-surface). Default ``30``.
        use_ligands (bool): When ``False``, skip all ligand-receptor and
            TF-ligand feature computation. Default ``True``.
        tf_ligand_cutoff (float): Minimum NicheNet ligand-target score
            required to include a TF-ligand pair. Default ``0.01``.
        receptor_thresh (float): Reserved threshold parameter (currently
            unused in filtering). Default ``0.1``.
        regulators (list[str] | None): Explicit list of TF regulators.
            When supplied, ``grn`` and ``colinks_path`` are ignored.
        grn (RegulatoryFactory | None): Pre-built GRN object exposing a
            ``get_regulators(adata, gene)`` method. Mutually exclusive
            with ``colinks_path``.
        colinks_path (str | None): Path to a co-link CSV used to construct
            a :class:`~SpaceTravLR.tools.network.RegulatoryFactory`
            internally. Required when both ``grn`` and ``regulators`` are
            ``None``.
        scale_factor (float): Multiplicative weight applied to Gaussian
            kernel values during ligand diffusion. Default ``1``.
        extra_regulators (list[str] | None): Additional gene names to
            append to the regulator list after GRN lookup. Duplicates are
            dropped and the original order is kept.

    Raises:
        AssertionError: If an input check fails (wrong ``adata`` type,
            unknown ``target_gene`` / ``layer`` / ``cluster_annot``, or no
            ``colinks_path`` when neither ``grn`` nor ``regulators`` is
            given).
    """
    assert isinstance(adata, AnnData), "adata must be an AnnData object"
    assert target_gene in adata.var_names, "target_gene must be in adata.var_names"
    assert layer in adata.layers, "layer must be in adata.layers"
    assert cluster_annot in adata.obs.columns, (
        "cluster_annot must be in adata.obs.columns"
    )

    self.adata = adata
    self.scale_factor = scale_factor
    self.use_ligands = use_ligands
    self.target_gene = target_gene
    self.cluster_annot = cluster_annot
    self.layer = layer
    self.device = device
    self.radius = radius
    self.contact_distance = contact_distance
    self.spatial_dim = spatial_dim
    self.tf_ligand_cutoff = tf_ligand_cutoff
    self.receptor_thresh = receptor_thresh
    self.xy = pd.DataFrame(
        adata.obsm["spatial"], index=adata.obs.index, columns=["x", "y"]
    )

    self.species = "mouse" if is_mouse_data(adata) else "human"

    if regulators is None:
        if grn is None:
            assert colinks_path is not None, (
                "colinks_path must be provided if grn is None"
            )
            self.grn = RegulatoryFactory(
                colinks_path=colinks_path, annot=cluster_annot
            )
        else:
            self.grn = grn

        self.regulators = self.grn.get_regulators(self.adata, self.target_gene)
    else:
        self.regulators = list(regulators)
        self.grn = None

    if extra_regulators is not None:
        # Order-preserving union. A plain ``set`` union would make the
        # regulator order (and therefore the beta column order) depend on
        # string hash randomisation and differ between Python processes.
        self.regulators = list(
            dict.fromkeys([*self.regulators, *extra_regulators])
        )

    # Exclude target gene from regulators to avoid data leak
    self.regulators = [
        i for i in self.regulators if i in adata.var_names and i != self.target_gene
    ]

    if self.use_ligands:
        ligand_mixtures = init_ligands_and_receptors(
            species=self.species,
            adata=self.adata,
            annot=self.cluster_annot,
            target_gene=self.target_gene,
            receptor_thresh=self.receptor_thresh,
            radius=self.radius,
            contact_distance=self.contact_distance,
            tf_ligand_cutoff=self.tf_ligand_cutoff,
            regulators=self.regulators,
            grn=self.grn,
        )

        self.lr = ligand_mixtures.lr
        self.ligands = ligand_mixtures.ligands
        self.receptors = ligand_mixtures.receptors
        self.tfl_pairs = ligand_mixtures.tfl_pairs
        self.tfl_regulators = ligand_mixtures.tfl_regulators
        self.tfl_ligands = ligand_mixtures.tfl_ligands

    else:
        # Empty placeholders with the same structure as the ligand branch.
        self.lr = pd.DataFrame(
            columns=["ligand", "receptor", "pathway", "signaling"]
        )
        self.lr["pairs"] = self.lr.ligand.values + "$" + self.lr.receptor.values
        self.ligands = []
        self.receptors = []
        self.tfl_pairs = []
        self.tfl_regulators = []
        self.tfl_ligands = []

    self.lr_pairs = self.lr["pairs"]

    self.n_clusters = len(self.adata.obs[self.cluster_annot].unique())
    # Column order of the beta matrix: regulators, L-R pairs, TF-ligand pairs.
    self.modulators = self.regulators + list(self.lr_pairs) + self.tfl_pairs

    self.modulators_genes = list(
        np.unique(
            self.regulators
            + self.ligands
            + self.receptors
            + self.tfl_regulators
            + self.tfl_ligands
        )
    )

    assert len(self.ligands) == len(self.receptors)
    assert np.isin(self.ligands, self.adata.var_names).all()
    assert np.isin(self.receptors, self.adata.var_names).all()
    assert np.isin(self.regulators, self.adata.var_names).all()
def plot_modulators(self, use_expression=True):
    """Draw a word cloud of every gene that modulates the target.

    Words are the regulators, ligands, receptors and TF-ligand partners.
    With ``use_expression=True`` the word weight is the gene's mean
    expression in ``self.layer``; otherwise every word has weight 1.

    Colours encode the modulator category:

    * **viridis** - ligands (L-R and TF-ligand).
    * **magma** - receptors.
    * **rainbow** - transcription factor regulators.

    Args:
        use_expression (bool): Weight words by mean expression across
            cells. Default ``True``.

    Returns:
        None: Displays the plot via :func:`matplotlib.pyplot.show`.
    """
    all_genes = set(
        self.regulators
        + self.ligands
        + self.tfl_ligands
        + self.receptors
        + self.tfl_regulators
    )

    if not use_expression:
        word_freq = {name: 1 for name in all_genes}
    else:
        # Mean expression per gene, in the same order as ``genes``.
        genes = list(all_genes)
        mean_expr = self.adata.to_df(layer=self.layer)[genes].mean(axis=0)
        word_freq = {}
        for gene, expr in zip(genes, mean_expr):
            word_freq[gene] = float(expr)

    ligand_cmap = plt.get_cmap("viridis")
    receptor_cmap = plt.get_cmap("magma")
    regulator_cmap = plt.get_cmap("rainbow")

    ligand_words = set(self.ligands).union(self.tfl_ligands)
    receptor_words = set(self.receptors)
    regulator_words = set(self.regulators).union(self.tfl_regulators)

    def my_color_func(
        word, font_size, position, orientation, font_path, random_state
    ):
        # One random draw per word selects the position inside the colormap.
        rnd = random_state.random()
        if word in ligand_words:
            cmap = ligand_cmap
        elif word in receptor_words:
            cmap = receptor_cmap
        elif word in regulator_words:
            cmap = regulator_cmap
        else:
            return "grey"
        return rgb2hex(cmap(rnd)[:3])

    cloud = WordCloud(
        width=800,
        height=300,
        contour_width=1,
        contour_color="black",
        background_color="white",
        color_func=my_color_func,
    )
    wordcloud = cloud.generate_from_frequencies(word_freq)

    plt.figure(figsize=(16, 8))
    plt.imshow(wordcloud, interpolation="bilinear", aspect="equal")
    plt.axis("off")
    plt.title(f"{self.target_gene} modulators", fontsize=20)
    plt.show()
@staticmethod
def ligands_receptors_interactions(received_ligands_df, receptor_gex_df):
    """Multiply received ligand signal by receptor expression, pair by pair.

    Column ``k`` of the result is the element-wise product of column ``k``
    of both inputs and is named ``"<ligand column>$<receptor column>"``::

        score[cell, k] = received_ligand[cell, k] * receptor_expr[cell, k]

    Args:
        received_ligands_df (pd.DataFrame): Received ligand values, shape
            ``(n_cells, n_pairs)``.
        receptor_gex_df (pd.DataFrame): Receptor expression with the same
            index and the same number of columns.

    Returns:
        pd.DataFrame: Interaction scores of shape ``(n_cells, n_pairs)``,
        indexed like ``receptor_gex_df``.

    Raises:
        AssertionError: If an input is not a DataFrame, the indexes differ
            or the column counts differ.
    """
    assert isinstance(received_ligands_df, pd.DataFrame)
    assert isinstance(receptor_gex_df, pd.DataFrame)
    assert received_ligands_df.index.equals(receptor_gex_df.index)
    assert received_ligands_df.shape[1] == receptor_gex_df.shape[1]

    pair_names = []
    for ligand, receptor in zip(received_ligands_df.columns, receptor_gex_df.columns):
        pair_names.append(ligand + "$" + receptor)

    # Work on raw arrays so the product is positional, not label-aligned.
    scores = received_ligands_df.values * receptor_gex_df.values

    return pd.DataFrame(scores, columns=pair_names, index=receptor_gex_df.index)
@staticmethod
def ligand_regulators_interactions(received_ligands_df, regulator_gex_df):
    """Multiply received ligand signal by TF expression, pair by pair.

    Column ``k`` of the result is the element-wise product of column ``k``
    of both inputs and is named ``"<ligand column>#<TF column>"``::

        score[cell, k] = received_ligand[cell, k] * tf_expr[cell, k]

    Args:
        received_ligands_df (pd.DataFrame): Received ligand values, shape
            ``(n_cells, n_pairs)``.
        regulator_gex_df (pd.DataFrame): TF expression with the same index
            and the same number of columns.

    Returns:
        pd.DataFrame: Interaction scores of shape ``(n_cells, n_pairs)``,
        indexed like ``regulator_gex_df``.

    Raises:
        AssertionError: If an input is not a DataFrame, the indexes differ
            or the column counts differ.
    """
    assert isinstance(received_ligands_df, pd.DataFrame)
    assert isinstance(regulator_gex_df, pd.DataFrame)
    assert received_ligands_df.index.equals(regulator_gex_df.index)
    assert received_ligands_df.shape[1] == regulator_gex_df.shape[1]

    pair_names = []
    for ligand, regulator in zip(received_ligands_df.columns, regulator_gex_df.columns):
        pair_names.append(ligand + "#" + regulator)

    # Work on raw arrays so the product is positional, not label-aligned.
    scores = received_ligands_df.values * regulator_gex_df.values

    return pd.DataFrame(scores, columns=pair_names, index=regulator_gex_df.index)
@staticmethod
def check_LR_properties(adata, layer):
    """Fetch the expression frame and the optional cell-type thresholds.

    Args:
        adata (AnnData): The dataset to query.
        layer (str): Layer key passed to ``adata.to_df``.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame | None]:
        ``(counts_df, cell_thresholds)``. ``cell_thresholds`` is whatever is
        stored under ``adata.uns["cell_thresholds"]``, or ``None`` (with a
        printed warning) when that key is absent.
    """
    expression = adata.to_df(layer=layer)
    thresholds = adata.uns.get("cell_thresholds")

    if thresholds is None:
        print("warning: cell_thresholds not found in adata.uns")

    return expression, thresholds
def init_data(self):
    """Build all training matrices and cache them on ``self``.

    Steps, in order:

    1. Compute the *received ligand* tables via
       :func:`init_received_ligands` unless both
       ``adata.uns["received_ligands"]`` and
       ``adata.uns["received_ligands_tfl"]`` are already present.
    2. Compute L-R interaction scores into ``adata.uns["ligand_receptor"]``.
    3. Compute TF-ligand interaction scores into
       ``adata.uns["ligand_regulator"]``.
    4. Build (or reuse) the spatial neighbourhood images in
       ``adata.obsm["spatial_maps"]``.
    5. Build the training frame ``self.train_df`` (target gene, regulators,
       L-R scores, TF-ligand scores).
    6. Build (or reuse) ``adata.obsm["spatial_features"]`` and min-max
       scale them into ``self.spatial_features``.
    7. Keep only the L-R / TF-ligand pairs that have a column in
       ``self.train_df`` and rebuild the ligand / receptor / TF lists from
       the surviving pair names.
    8. Update ``self.modulators`` and ``self.modulators_genes``.

    Side effects:
        Populates ``self.spatial_maps``, ``self.spatial_features``,
        ``self.train_df``, ``self.xy``, ``self.xy_df``, ``self.lr_pairs``,
        ``self.tfl_pairs``, ``self.ligands``, ``self.receptors``,
        ``self.tfl_ligands``, ``self.tfl_regulators`` and writes the
        ``adata.uns`` / ``adata.obsm`` entries listed above.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        ``(sp_maps, X, y, cluster_labels)``.

        * ``sp_maps`` - the spatial map array.
        * ``X`` - ``train_df`` without the target column.
        * ``y`` - target expression vector.
        * ``cluster_labels`` - cluster label per cell.
    """
    counts_df, cell_thresholds = self.check_LR_properties(self.adata, self.layer)

    # Both tables are read below, so both must already be cached to skip the
    # (expensive) computation. The previous check skipped it when only one
    # of the two keys was present.
    have_cached_ligands = (
        "received_ligands" in self.adata.uns.keys()
        and "received_ligands_tfl" in self.adata.uns.keys()
    )
    if not have_cached_ligands:
        print(f"Initializing received ligands with {self.scale_factor}")
        self.adata = init_received_ligands(
            self.adata,
            radius=self.radius,
            contact_distance=self.contact_distance,
            cell_threshes=cell_thresholds,
            scale_factor=self.scale_factor,
            # Use the same expression layer as the rest of the estimator
            # instead of the helper's default.
            layer=self.layer,
        )

    if len(self.lr["pairs"]) > 0:
        self.adata.uns["ligand_receptor"] = self.ligands_receptors_interactions(
            self.adata.uns["received_ligands"][self.ligands],
            get_filtered_df(counts_df, cell_thresholds, self.receptors)[
                self.receptors
            ],
        )
    else:
        # NOTE(review): this replaces any cached "received_ligands" table
        # with an empty frame, which a later estimator sharing this adata
        # would then pick up as a cache hit -- confirm this is intended.
        self.adata.uns["received_ligands"] = pd.DataFrame(
            index=self.adata.obs.index
        )
        self.adata.uns["ligand_receptor"] = pd.DataFrame(index=self.adata.obs.index)

    if len(self.tfl_pairs) > 0:
        self.adata.uns["ligand_regulator"] = self.ligand_regulators_interactions(
            self.adata.uns["received_ligands_tfl"][self.tfl_ligands],
            counts_df[self.tfl_regulators],
        )
    else:
        self.adata.uns["ligand_regulator"] = pd.DataFrame(
            index=self.adata.obs.index
        )

    self.xy = np.array(self.adata.obsm["spatial"])
    cluster_labels = np.array(self.adata.obs[self.cluster_annot])

    self.xy_df = pd.DataFrame(
        self.xy, columns=["x", "y"], index=self.adata.obs.index
    )

    if "spatial_maps" not in self.adata.obsm.keys():
        self.spatial_maps = xyc2spatial_fast(
            xyc=np.column_stack([self.xy, cluster_labels]),
            m=self.spatial_dim,
            n=self.spatial_dim,
        )
        self.adata.obsm["spatial_maps"] = self.spatial_maps
    else:
        self.spatial_maps = self.adata.obsm["spatial_maps"]

    self.train_df = (
        counts_df[[self.target_gene] + self.regulators]
        .join(self.adata.uns["ligand_receptor"])
        .join(self.adata.uns["ligand_regulator"])
    )

    if "spatial_features" not in self.adata.obsm.keys():
        self.spatial_features = create_spatial_features(
            self.adata.obsm["spatial"][:, 0],
            self.adata.obsm["spatial"][:, 1],
            self.adata.obs[self.cluster_annot],
            self.adata.obs.index,
            radius=self.radius,
        )
        # Cache the unscaled counts; scaling happens on the local copy below.
        self.adata.obsm["spatial_features"] = self.spatial_features.copy()
    else:
        self.spatial_features = self.adata.obsm["spatial_features"]

    self.spatial_features = pd.DataFrame(
        MinMaxScaler().fit_transform(self.spatial_features.values),
        columns=self.spatial_features.columns,
        index=self.spatial_features.index,
    )

    # Keep only the pairs that actually became columns of the training frame.
    self.lr_pairs = self.lr_pairs[self.lr_pairs.isin(self.train_df.columns)]
    self.tfl_pairs = [i for i in self.tfl_pairs if i in self.train_df.columns]

    self.ligands = []
    self.receptors = []
    self.tfl_regulators = []
    self.tfl_ligands = []

    for i in self.lr_pairs:
        lig, rec = i.split("$")
        self.ligands.append(lig)
        self.receptors.append(rec)

    for i in self.tfl_pairs:
        lig, reg = i.split("#")
        self.tfl_ligands.append(lig)
        self.tfl_regulators.append(reg)

    self.modulators = self.regulators + list(self.lr_pairs) + self.tfl_pairs
    self.modulators_genes = list(
        np.unique(
            self.regulators
            + self.ligands
            + self.receptors
            + self.tfl_regulators
            + self.tfl_ligands
        )
    )

    assert len(self.ligands) == len(self.receptors)

    X = self.train_df.drop(columns=[self.target_gene]).values
    y = self.train_df[self.target_gene].values
    sp_maps = self.spatial_maps

    assert sp_maps.shape[0] == X.shape[0] == y.shape[0] == len(cluster_labels)

    return sp_maps, X, y, cluster_labels
@torch.no_grad()
def get_betas(self):
    """Collect per-cell beta coefficients from the per-cluster models.

    For every unique value in ``self.cluster_labels`` the cells of that
    cluster are passed through ``self.models[cluster].get_betas`` using the
    cluster's own spatial-map channel and the spatial features. Clusters
    without an entry in ``self.models`` get all-zero rows and a printed
    message.

    Columns are ``"beta0"`` followed by ``"beta_<modulator>"`` for each
    entry of ``self.modulators``, in that order. Gradients are disabled by
    the ``@torch.no_grad()`` decorator.

    Returns:
        pd.DataFrame: Shape ``(n_cells, 1 + n_modulators)``, reindexed to
        ``self.adata.obs.index`` so rows follow the original cell order.
    """
    row_labels = []
    beta_rows = []

    for cluster_id in np.unique(self.cluster_labels):
        in_cluster = self.cluster_labels == cluster_id
        cluster_cells = self.cell_indices[in_cluster]
        row_labels.extend(cluster_cells)

        if cluster_id in self.models:
            # Keep the channel axis (size 1) by slicing instead of indexing.
            maps = torch.from_numpy(
                self.sp_maps[in_cluster][:, cluster_id : cluster_id + 1, :, :]
            ).float()
            context = torch.from_numpy(
                self.spatial_features.values[in_cluster]
            ).float()
            model = self.models[cluster_id]
            cluster_betas = (
                model.get_betas(maps.to(self.device), context.to(self.device))
                .cpu()
                .numpy()
            )
        else:
            print(f"No model found for {cluster_id}")
            cluster_betas = np.zeros(
                (len(cluster_cells), (len(self.modulators) + 1))
            )

        beta_rows.extend(cluster_betas)

    beta_columns = ["beta0"] + ["beta_" + name for name in self.modulators]
    per_cluster_order = pd.DataFrame(
        beta_rows, index=row_labels, columns=beta_columns
    )
    return per_cluster_order.reindex(self.adata.obs.index)
@property
def betadata(self):
    """Property form of :meth:`get_betas`.

    Every access calls ``self.get_betas()`` again; nothing is cached here.

    Returns:
        pd.DataFrame: Whatever :meth:`get_betas` returns.
    """
    beta_frame = self.get_betas()
    return beta_frame
def fit(
    self,
    num_epochs=100,
    threshold_lambda=1e-6,
    learning_rate=5e-3,
    batch_size=512,
    pbar=None,
    estimator="lasso",
    vision_model="cnn",
    score_threshold=0.2,
    l1_reg=1e-9,
    skip_clusters=None,
    seed_score_threshold=0.15,
):
    """Train per-cluster spatial cellular programme models.

    For each unique cluster label the method:

    1. Fits a *seed estimator* (Group Lasso, ARD, or Bayesian Ridge) on the
       cluster's cells and stacks its intercept and coefficients into the
       initial ``_betas`` anchors. The seed R² is stored in
       ``self.scores[cluster]``.
    2. If the seed R² is below ``seed_score_threshold``, stores a
       :class:`~SpaceTravLR.models.pixel_attention.CellularNicheNetwork`
       with all-zero anchors and moves on without any neural training.
    3. Otherwise trains a
       :class:`~SpaceTravLR.models.pixel_attention.CellularNicheNetwork` or
       :class:`~SpaceTravLR.models.pixel_attention.CellularViT` seeded with
       those anchors, using Adam + MSE loss with gradient-norm clipping.
    4. If the R² of the last training epoch's predictions is below
       ``score_threshold``, the model anchors are multiplied by zero.

    The seed estimator types are:

    * ``"lasso"`` *(default)* – Group Lasso with feature groups
      ``(regulators, L-R pairs, TF-ligand pairs)`` and ``threshold_lambda``
      as the group regularisation strength.
    * ``"bayesian"`` – :class:`~sklearn.linear_model.BayesianRidge`.
    * ``"ard"`` – :class:`~sklearn.linear_model.ARDRegression`. Allocates an
      ``n_samples × n_samples`` matrix, so it scales poorly with cluster
      size.

    Args:
        num_epochs (int): Number of neural-network training epochs per
            cluster. With ``0`` the training loop and the post-training
            score check are skipped, so models keep their seed anchors.
            Default ``100``.
        threshold_lambda (float): Passed as ``group_reg`` to the Group
            Lasso and as ``threshold_lambda`` to ``ARDRegression``.
            Default ``1e-6``.
        learning_rate (float): Adam learning rate. Default ``5e-3``.
        batch_size (int): Mini-batch size for the DataLoader.
            Default ``512``.
        pbar: Progress counter exposing ``update`` and ``desc``. When
            ``None`` a new :mod:`enlighten` counter is created.
        estimator (str): Seed estimator type; one of ``"lasso"``,
            ``"bayesian"``, ``"ard"``. Default ``"lasso"``.
        vision_model (str): Neural architecture; ``"cnn"``
            (``CellularNicheNetwork``) or ``"transformer"``
            (``CellularViT``). Default ``"cnn"``.
        score_threshold (float): Minimum R² of the trained network; below
            it the anchors are zeroed. Default ``0.2``.
        l1_reg (float): ``l1_reg`` argument of the Group Lasso.
            Default ``1e-9``.
        skip_clusters (list[int] | None): Cluster integer labels to skip
            entirely (the progress bar is still advanced).
            Default ``None``.
        seed_score_threshold (float): Minimum seed-estimator R² required
            to run neural training. Default ``0.15``, the value that was
            previously hard-coded.

    Returns:
        None: Populates ``self.models``, ``self.scores``, and
        ``self.loss_dict`` in-place.
    """
    sp_maps, X, y, cluster_labels = self.init_data()
    if skip_clusters is None:
        skip_clusters = []

    assert estimator in ["lasso", "bayesian", "ard"]
    assert vision_model in ["cnn", "transformer"]

    self.estimator = estimator
    self.vision_model = vision_model
    self.models = {}
    self.Xn = X
    self.yn = y
    self.sp_maps = sp_maps
    self.cell_indices = self.adata.obs.index.copy()
    self.cluster_labels = cluster_labels

    if pbar is None:
        manager = enlighten.get_manager()
        pbar = manager.counter(
            total=sp_maps.shape[0] * num_epochs,
            desc="Estimating Spatial Betas",
            unit="cells",
            color="green",
            auto_refresh=True,
        )

    if num_epochs:
        print(f"Fitting {self.target_gene} with {len(self.modulators)} modulators")
        print(f"\t{len(self.regulators)} Transcription Factors")
        print(f"\t{len(self.lr_pairs)} Ligand-Receptor Pairs")
        print(f"\t{len(self.tfl_pairs)} TranscriptionFactor-Ligand Pairs")

    self.scores = {}
    self.loss_dict = {}
    n_modulators = len(self.modulators)

    for cluster in np.unique(cluster_labels):
        mask = cluster_labels == cluster

        if int(cluster) in skip_clusters:
            # Keep the progress bar consistent even though nothing is trained.
            pbar.update(num_epochs * len(self.cell_indices[mask]))
            continue

        X_cell, y_cell = self.Xn[mask], self.yn[mask]

        if self.estimator == "ard":
            # ARD allocates a n_samples * n_samples matrix so isn't very
            # scalable.
            # NOTE(review): sklearn documents ``threshold_lambda`` as a
            # pruning threshold with a large default; confirm that reusing
            # the Group-Lasso strength here is intended.
            m = ARDRegression(threshold_lambda=threshold_lambda)
            m.fit(X_cell, y_cell)
            y_pred = m.predict(X_cell)
            r2 = r2_score(y_cell, y_pred)
            _betas = np.hstack([m.intercept_, m.coef_])

        elif self.estimator == "bayesian":
            m = BayesianRidge()
            m.fit(X_cell, y_cell)
            y_pred = m.predict(X_cell)
            r2 = r2_score(y_cell, y_pred)
            _betas = np.hstack([m.intercept_, m.coef_])

        elif self.estimator == "lasso":
            # One group id per modulator family, in the same order as the
            # feature columns: regulators, L-R pairs, TF-ligand pairs.
            groups = np.array(
                [1] * len(self.regulators)
                + [2] * len(self.lr_pairs)
                + [3] * len(self.tfl_pairs)
            )
            gl = GroupLasso(
                groups=groups,
                group_reg=threshold_lambda,
                l1_reg=l1_reg,
                frobenius_lipschitz=True,
                scale_reg="inverse_group_size",
                warm_start=True,
                random_state=42,
                supress_warning=True,
                n_iter=1500,
                tol=1e-5,
            )
            gl.fit(X_cell, y_cell)
            y_pred = gl.predict(X_cell)
            _betas = np.hstack([gl.intercept_, gl.coef_.flatten()])
            r2 = r2_score(y_cell, y_pred)

        self.scores[cluster] = r2

        if r2 < seed_score_threshold:
            # The seed fit is too poor to anchor a network: store a
            # zero-anchor placeholder and skip training for this cluster.
            _model = CellularNicheNetwork(
                n_modulators=n_modulators,
                anchors=_betas * 0,
                spatial_dim=self.spatial_dim,
                n_clusters=self.n_clusters,
            ).to(self.device)
            self.models[cluster] = _model
            print(f"{cluster}: x.xxx* | {r2:.4f}")
            pbar.update(len(X_cell) * num_epochs)
            continue

        loader = DataLoader(
            RotatedTensorDataset(
                sp_maps[mask],
                X_cell,
                y_cell,
                cluster,
                self.spatial_features.iloc[mask].values,
                rotate_maps=True,
            ),
            batch_size=batch_size,
            shuffle=True,
        )

        assert _betas.shape[0] == n_modulators + 1

        if self.vision_model == "cnn":
            model_cls = CellularNicheNetwork
        else:
            model_cls = CellularViT
        model = model_cls(
            n_modulators=n_modulators,
            anchors=_betas,
            spatial_dim=self.spatial_dim,
            n_clusters=self.n_clusters,
        ).to(self.device)

        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(
            model.parameters(), lr=learning_rate, weight_decay=0
        )

        self.loss_dict[cluster] = []
        for _ in range(num_epochs):
            model.train()
            # Reset every epoch so the score below reflects the last epoch.
            all_y_true = []
            all_y_pred = []

            for batch in loader:
                # Use the same device the model was moved to.
                spatial_maps, inputs, targets, spatial_features = [
                    b.to(self.device) for b in batch
                ]
                optimizer.zero_grad()
                outputs = model(spatial_maps, inputs, spatial_features)
                loss = criterion(outputs, targets)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=2.0, norm_type=2
                )
                optimizer.step()

                all_y_true.extend(targets.cpu().detach().numpy())
                all_y_pred.extend(outputs.cpu().detach().numpy())

                pbar.desc = f"{self.target_gene} | {cluster + 1}/{self.n_clusters}"
                pbar.update(len(targets))

            # Records the loss of the epoch's final mini-batch.
            self.loss_dict[cluster].append(loss.item())

        if num_epochs:
            score = r2_score(all_y_true, all_y_pred)
            if score < score_threshold:
                # no point in predicting betas if we do it poorly
                model.anchors = model.anchors * 0.0
                print(f"{cluster}: x.xxxx | {r2:.4f}")
            else:
                print(f"{cluster}: {score:.4f} | {r2:.4f}")

        self.models[cluster] = model
def export(self, save_dir="./models"):
    """Pickle the estimator to ``<save_dir>/<target_gene>_estimator.pkl``.

    The module objects in ``self.models`` are not pickled as such. In the
    pickled copy each one is replaced by a plain dict holding its
    ``state_dict()`` and its ``anchors`` attribute; an entry that is
    ``None`` stays ``None``. ``self`` is left untouched because the swap is
    done on a shallow copy.

    Usage::

        estimator.fit(num_epochs=50)
        estimator.export("./output/models")
        # → ./output/models/Myc_estimator.pkl

    Args:
        save_dir (str): Output directory; created (including parents) when
            it does not exist. Default ``"./models"``.

    Returns:
        None
    """

    def _snapshot(net):
        # Reduce one cluster model to the pieces needed to rebuild it.
        if net is None:
            return None
        return {"state_dict": net.state_dict(), "anchors": net.anchors}

    serialisable_models = {
        cluster_id: _snapshot(net) for cluster_id, net in self.models.items()
    }

    # Shallow copy so that swapping ``models`` does not affect ``self``.
    shallow_clone = copy.copy(self)
    shallow_clone.models = serialisable_models

    os.makedirs(save_dir, exist_ok=True)
    target_path = os.path.join(save_dir, f"{self.target_gene}_estimator.pkl")
    with open(target_path, "wb") as handle:
        pickle.dump(shallow_clone, handle)
def load(self, path):
    """Restore a previously exported estimator from disk.

    Reads the pickle file written by :meth:`export`, copies every attribute
    except ``models`` onto ``self``, then rebuilds each per-cluster network
    from its saved ``{"state_dict", "anchors"}`` entry.

    Clusters whose saved entry is ``None`` are left out of ``self.models``,
    which :meth:`get_betas` treats as "no model" (all-zero betas).

    When the restored ``vision_model`` attribute is ``"transformer"``,
    :class:`~SpaceTravLR.models.pixel_attention.CellularViT` is tried first
    and
    :class:`~SpaceTravLR.models.pixel_attention.CellularNicheNetwork` is the
    fallback, because :meth:`fit` stores ``CellularNicheNetwork``
    placeholders for clusters with a low seed score. Otherwise only
    ``CellularNicheNetwork`` is used.

    Usage::

        estimator = SpatialCellularProgramsEstimator(adata, "Myc", grn=grn)
        estimator.load("./output/models/Myc_estimator.pkl")
        betas = estimator.get_betas()

    Args:
        path (str): Absolute or relative path to the ``.pkl`` file produced
            by :meth:`export`.

    Returns:
        None: Modifies ``self`` in-place.

    Raises:
        RuntimeError: If a saved state dict matches none of the candidate
            architectures.
    """
    # SECURITY: unpickling can execute arbitrary code. Only load files that
    # come from a trusted source.
    with open(path, "rb") as f:
        loaded = pickle.load(f)

    # Copy all attributes except models; those are rebuilt below.
    for attr, val in loaded.__dict__.items():
        if attr != "models":
            setattr(self, attr, val)

    self.models = {}
    for cluster, state in loaded.models.items():
        if state is None:
            # export() writes None for clusters that had no model.
            continue

        if getattr(self, "vision_model", "cnn") == "transformer":
            model_classes = (CellularViT, CellularNicheNetwork)
        else:
            model_classes = (CellularNicheNetwork,)

        last_error = None
        for model_cls in model_classes:
            model = model_cls(
                n_modulators=len(self.modulators),
                anchors=state["anchors"],
                spatial_dim=self.spatial_dim,
                n_clusters=self.n_clusters,
            ).to(self.device)
            try:
                # The weights live under "state_dict"; the wrapper dict
                # itself is not a valid state dict.
                model.load_state_dict(state["state_dict"])
            except RuntimeError as err:
                # Architecture mismatch: try the next candidate class.
                last_error = err
                continue
            self.models[cluster] = model
            break
        else:
            raise last_error