Model

SpatialCellularProgramsEstimator is the core per-gene fitting engine used internally by SpaceTravLR. It fits a spatially-aware regulatory model for a single target gene using three classes of input features:

Input feature groups

| Group | Description | Column format |
| --- | --- | --- |
| Regulators (TFs) | Transcription-factor expression from a GRN lookup | GeneSymbol |
| Ligand-receptor pairs | Gaussian-diffused ligand signal × receptor expression | LigandGene$ReceptorGene |
| Ligand-TF pairs (NicheNet) | Diffused ligand signal × TF expression (NicheNet-filtered) | LigandGene#TFGene |
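Each group can be identified from its column format, so a list of modulator names can be split back into groups and genes using only the separator. Below is a minimal sketch. It assumes gene symbols never contain "$" or "#". modulator_group and split_pair are hypothetical helpers, not SpaceTravLR functions, and the gene names are illustrative:

```python
def modulator_group(name: str) -> str:
    """Classify a modulator column name by its separator (hypothetical helper)."""
    if "$" in name:          # LigandGene$ReceptorGene
        return "ligand-receptor"
    if "#" in name:          # LigandGene#TFGene
        return "ligand-TF"
    return "regulator"       # plain GeneSymbol


def split_pair(name: str) -> tuple[str, ...]:
    """Return the gene symbols encoded in a modulator column name."""
    for sep in ("$", "#"):
        if sep in name:
            return tuple(name.split(sep))
    return (name,)


modulators = ["Stat3", "Wnt5a$Fzd1", "Tgfb1#Smad3"]
groups = [modulator_group(m) for m in modulators]
genes = sorted({g for m in modulators for g in split_pair(m)})
print(groups)  # ['regulator', 'ligand-receptor', 'ligand-TF']
print(genes)   # ['Fzd1', 'Smad3', 'Stat3', 'Tgfb1', 'Wnt5a']
```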

The model produces per-cell spatial beta coefficients (one column per modulator plus an intercept) describing how regulatory influence varies across the tissue.

Training pipeline overview

For every cell-type cluster the estimator:

  1. Fits a seed linear model (Group Lasso / Bayesian Ridge / ARD) on the cluster's cells to obtain coefficient anchors.

  2. Constructs a 2-D spatial neighbourhood image (side length spatial_dim pixels) capturing local cell-type densities.

  3. If the seed R² exceeds score_threshold, trains a CellularNicheNetwork (CNN) or CellularViT (ViT) conditioned on those anchors and neighbourhood images.

  4. If the final R² falls below score_threshold, zeroes the anchors so the gene's betas collapse to the global baseline.
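Steps 1 and 4 together form a per-cluster fit-and-gate rule. The sketch below shows only that gating logic and makes several assumptions. Ordinary least squares (numpy.linalg.lstsq) stands in for the Group Lasso / Bayesian Ridge / ARD seed, and the neural training of step 3 is omitted. Zeroing keeps the intercept, so the prediction falls back to a constant baseline. seed_anchors and r2 are hypothetical names:

```python
import numpy as np


def r2(y: np.ndarray, y_pred: np.ndarray) -> float:
    """Coefficient of determination (R²)."""
    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot) if ss_tot > 0 else 0.0


def seed_anchors(X: np.ndarray, y: np.ndarray, score_threshold: float = 0.2):
    """Fit a seed linear model on one cluster and gate it on R² (sketch).

    Returns (anchors, score) with anchors = [intercept, coef_1, ..., coef_p].
    OLS is a stand-in for the seed estimators named in the text.
    """
    design = np.column_stack([np.ones(len(X)), X])   # prepend an intercept column
    anchors, *_ = np.linalg.lstsq(design, y, rcond=None)
    score = r2(y, design @ anchors)
    if score < score_threshold:
        anchors = anchors.copy()
        anchors[1:] = 0.0                             # keep only the baseline term
    return anchors, score


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = 1.5 + X @ np.array([2.0, 0.0, -1.0]) + 0.1 * rng.normal(size=200)
anchors, score = seed_anchors(X, y)                   # informative cluster: kept

y_noise = rng.normal(size=200)
null_anchors, null_score = seed_anchors(X, y_noise)  # uninformative cluster: zeroed
```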

import anndata as ad

from SpaceTravLR.models.parallel_estimators import SpatialCellularProgramsEstimator
from SpaceTravLR.tools.network import RegulatoryFactory

# Load a spatial dataset that provides obsm["spatial"], the "imputed_count"
# layer and an integer "cell_type_int" column in obs (file name is illustrative)
adata = ad.read_h5ad("spatial_data.h5ad")

grn = RegulatoryFactory(colinks_path="colinks.csv", annot="cell_type_int")

estimator = SpatialCellularProgramsEstimator(
    adata,
    target_gene="Myc",
    grn=grn,
    radius=150,
    spatial_dim=64,
)

# Train one model per cell-type cluster
estimator.fit(num_epochs=80, estimator="lasso", score_threshold=0.2)

# Retrieve per-cell spatial betas -- shape (n_cells, 1 + n_modulators)
betas = estimator.get_betas()

# Visualise modulator genes as a colour-coded word cloud
estimator.plot_modulators()

# Persist to ./output/models/Myc_estimator.pkl, then reload into the same object
estimator.export("./output/models")
estimator.load("./output/models/Myc_estimator.pkl")

class SpaceTravLR.models.parallel_estimators.SpatialCellularProgramsEstimator(adata, target_gene, spatial_dim=64, cluster_annot='cell_type_int', layer='imputed_count', radius=100, contact_distance=30, use_ligands=True, tf_ligand_cutoff=0.01, receptor_thresh=0.1, regulators=None, grn=None, colinks_path=None, scale_factor=1, extra_regulators=None)

Bases: object

Per-gene spatial regression estimator with ligand-receptor and GRN features.

SpatialCellularProgramsEstimator fits a spatially-aware gene-regulatory model for a single target gene. For every annotated cell-type cluster it trains a CellularNicheNetwork (CNN) or CellularViT (ViT) that takes three inputs:

  1. Spatial neighbourhood map – a 2-D density image of each cell type derived from adata.obsm["spatial"].

  2. Regulatory features – expression of transcription factors inferred from a supplied GRN, along with ligand-receptor interaction scores (pairs from the CellChat database) and ligand-TF interaction scores (NicheNet pairs).

  3. Spatial context features – per-cell neighbour-count vectors (one column per cell type, counting neighbours within radius, expressed in the same units as adata.obsm["spatial"], typically microns).
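One way to build the spatial neighbourhood map in input 1 is to bin each cell's neighbours onto a spatial_dim × spatial_dim grid with one channel per cell type. This yields the (n_cells, n_clusters, spatial_dim, spatial_dim) tensor described under spatial_maps. The following is a hypothetical sketch. Its details are assumptions, not SpaceTravLR's documented construction: a square window of half-width radius, inclusion of the focal cell, and normalising each image to sum to 1. neighbourhood_images is an illustrative name:

```python
import numpy as np


def neighbourhood_images(xy: np.ndarray, labels: np.ndarray, n_clusters: int,
                         spatial_dim: int = 64, radius: float = 100.0) -> np.ndarray:
    """Cell-type density images of shape (n_cells, n_clusters, spatial_dim, spatial_dim)."""
    n_cells = xy.shape[0]
    images = np.zeros((n_cells, n_clusters, spatial_dim, spatial_dim))
    bin_width = 2 * radius / spatial_dim
    for i in range(n_cells):
        offsets = xy - xy[i]                                   # positions relative to cell i
        inside = np.all(np.abs(offsets) < radius, axis=1)      # square window; includes cell i
        # Map offsets in (-radius, radius) to pixel indices 0..spatial_dim-1
        cols = np.clip(((offsets[inside, 0] + radius) / bin_width).astype(int), 0, spatial_dim - 1)
        rows = np.clip(((offsets[inside, 1] + radius) / bin_width).astype(int), 0, spatial_dim - 1)
        # Accumulate one count per neighbour in its cell-type channel
        np.add.at(images[i], (labels[inside], rows, cols), 1.0)
    # Turn counts into densities: each image sums to 1 over channels and pixels
    totals = images.sum(axis=(1, 2, 3), keepdims=True)
    return images / np.maximum(totals, 1.0)


xy = np.array([[0.0, 0.0], [10.0, 0.0], [500.0, 500.0]])
labels = np.array([0, 1, 0])
maps = neighbourhood_images(xy, labels, n_clusters=2, spatial_dim=8, radius=100.0)
```

np.add.at is used instead of fancy-index assignment so that several neighbours landing in the same pixel are all counted.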

The model outputs per-cell spatial beta coefficients (one per modulator + intercept) describing how strongly each regulator, L-R pair or TF-ligand pair influences the target gene at that spatial location.

Pipeline overview:

estimator = SpatialCellularProgramsEstimator(adata, "Myc", grn=grn)
estimator.fit(num_epochs=50)
betas = estimator.get_betas()   # DataFrame (n_cells × (1 + n_modulators))
estimator.export("./models")

Attributes:

adata (AnnData): The spatial single-cell dataset.

target_gene (str): The gene being modelled.

regulators (list[str]): Transcription-factor regulators of the target.

ligands (list[str]): Ligand genes from active L-R pairs.

receptors (list[str]): Receptor genes paired with ligands.

tfl_ligands (list[str]): Ligands from TF-ligand (NicheNet) pairs.

tfl_regulators (list[str]): TFs paired with tfl_ligands.

lr_pairs (pd.Series): Active ligand-receptor pair identifiers (format "LigandGene$ReceptorGene").

tfl_pairs (list[str]): Active TF-ligand pair identifiers (format "LigandGene#TFGene").

modulators (list[str]): Ordered concatenation of regulators, L-R pairs, and TF-ligand pairs; defines the column order of the beta matrix.

modulators_genes (list[str]): Unique gene names spanning all modulator categories.

models (dict[int, CellularNicheNetwork]): Trained per-cluster neural networks (available after fit()).

scores (dict[int, float]): Per-cluster R² scores on the training set (available after fit()).

loss_dict (dict[int, list[float]]): Per-batch MSE losses collected during training (available after fit()).

spatial_maps (np.ndarray): Cached neighbourhood image tensors of shape (n_cells, n_clusters, spatial_dim, spatial_dim).

spatial_features (pd.DataFrame): Min-max scaled neighbour-count features of shape (n_cells, n_clusters).

xy (np.ndarray): Spatial coordinates array of shape (n_cells, 2).

device (torch.device): PyTorch compute device (CUDA / MPS / CPU).

species (str): "mouse" or "human", inferred from gene names.
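spatial_features holds min-max scaled neighbour counts per cell type. Below is a brute-force sketch. It assumes neighbours are counted within radius of each cell and that the focal cell itself is excluded; both are assumptions. neighbour_count_features and the cluster_<k> column names are hypothetical. The column-wise scaling gives the same result as scikit-learn's MinMaxScaler, which also maps constant columns to 0:

```python
import numpy as np
import pandas as pd


def neighbour_count_features(xy: np.ndarray, labels: np.ndarray, n_clusters: int,
                             radius: float = 100.0) -> pd.DataFrame:
    """Min-max scaled counts of each cell type within `radius` of every cell."""
    # Brute-force pairwise distances: O(n_cells^2) memory, fine for a sketch
    dists = np.linalg.norm(xy[:, None, :] - xy[None, :, :], axis=-1)
    within = (dists <= radius) & ~np.eye(len(xy), dtype=bool)   # exclude the cell itself
    one_hot = np.eye(n_clusters)[labels]                         # (n_cells, n_clusters)
    counts = within.astype(float) @ one_hot                      # neighbours of each type
    # Column-wise min-max scaling to [0, 1]; constant columns map to 0
    lo, hi = counts.min(axis=0), counts.max(axis=0)
    scaled = (counts - lo) / np.where(hi > lo, hi - lo, 1.0)
    return pd.DataFrame(scaled, columns=[f"cluster_{k}" for k in range(n_clusters)])


xy = np.array([[0.0, 0.0], [50.0, 0.0], [80.0, 0.0], [1000.0, 0.0]])
labels = np.array([0, 1, 1, 0])
features = neighbour_count_features(xy, labels, n_clusters=2)
```

For large datasets, a KD-tree radius query would replace the dense distance matrix.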

__init__(adata, target_gene, spatial_dim=64, cluster_annot='cell_type_int', layer='imputed_count', radius=100, contact_distance=30, use_ligands=True, tf_ligand_cutoff=0.01, receptor_thresh=0.1, regulators=None, grn=None, colinks_path=None, scale_factor=1, extra_regulators=None)

Initialise the estimator and resolve regulatory inputs.

The constructor performs three main steps:

  1. GRN lookup – fetch the list of transcription-factor regulators for target_gene either from an explicit regulators list or via grn.get_regulators.

  2. Ligand-receptor discovery – query the CellChat database to enumerate L-R pairs and NicheNet TF-ligand pairs that are relevant to target_gene (skipped when use_ligands=False).

  3. Modulator assembly – build the ordered modulators list that defines the column ordering of the beta coefficient matrix.

Args:
adata (AnnData): Spatial dataset. Must contain adata.obsm["spatial"], the requested layer, and cluster_annot in adata.obs.

target_gene (str): Name of the gene to model. Must be present in adata.var_names.

spatial_dim (int): Side length (pixels) of the square spatial neighbourhood image used as CNN input. Default 64.

cluster_annot (str): adata.obs column holding integer-encoded cell-type labels. Default "cell_type_int".

layer (str): adata.layers key for gene-expression values. Default "imputed_count".

radius (float): Search radius (same spatial units as adata.obsm["spatial"]) for secreted ligand diffusion. Default 100.

contact_distance (float): Search radius for contact signalling (juxtacrine / cell-surface). Default 30.

use_ligands (bool): When False, skip all ligand-receptor and TF-ligand feature computation. Default True.

tf_ligand_cutoff (float): Minimum NicheNet ligand-target score required to include a TF-ligand pair. Default 0.01.

receptor_thresh (float): Reserved threshold parameter (currently unused in filtering). Default 0.1.

regulators (list[str] | None): Explicit list of TF regulators. When supplied, grn and colinks_path are ignored.

grn (RegulatoryFactory | None): Pre-built GRN object exposing a get_regulators(adata, gene) method. Mutually exclusive with colinks_path.

colinks_path (str | None): Path to a co-link CSV used to construct a RegulatoryFactory internally. Required when both grn and regulators are None.

scale_factor (float): Multiplicative weight applied to Gaussian kernel values during ligand diffusion. Default 1.

extra_regulators (list[str] | None): Additional gene names to append to the regulator list after GRN lookup.
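The radius and scale_factor arguments control how ligand expression spreads to neighbouring cells. Below is a minimal sketch of a Gaussian-weighted neighbourhood average. Its choices are illustrative assumptions, not SpaceTravLR's implementation: a bandwidth of sigma = radius / 2, a hard cutoff at radius, inclusion of the cell's own expression, and normalisation by the number of cells in range. gaussian_received_ligands is a hypothetical name. Dividing by the kernel sum instead would cancel scale_factor, which is why this sketch divides by the neighbour count:

```python
import numpy as np
import pandas as pd


def gaussian_received_ligands(xy: np.ndarray, ligand_expr: pd.DataFrame,
                              radius: float = 100.0, scale_factor: float = 1.0,
                              sigma=None) -> pd.DataFrame:
    """Gaussian-weighted average of ligand expression received by each cell."""
    sigma = radius / 2 if sigma is None else sigma
    d2 = ((xy[:, None, :] - xy[None, :, :]) ** 2).sum(axis=-1)   # squared distances
    in_range = d2 <= radius ** 2                                 # includes the cell itself
    weights = scale_factor * np.exp(-d2 / (2 * sigma ** 2)) * in_range
    n_in_range = in_range.sum(axis=1, keepdims=True)             # always >= 1
    received = weights @ ligand_expr.to_numpy() / n_in_range
    return pd.DataFrame(received, index=ligand_expr.index, columns=ligand_expr.columns)


xy = np.array([[0.0, 0.0], [30.0, 0.0], [500.0, 0.0]])
expr = pd.DataFrame({"Wnt5a": [1.0, 0.0, 4.0]}, index=["c1", "c2", "c3"])
received = gaussian_received_ligands(xy, expr, radius=100.0)
```

Here cell c2 expresses no Wnt5a itself but receives signal from c1, 30 units away. The isolated cell c3 receives only its own expression.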

plot_modulators(use_expression=True)

Visualise all modulator genes as a word cloud.

Renders a word cloud where each word is a gene that modulates the target (regulators, ligands, receptors, TF-ligand partners). Word size is proportional to mean expression when use_expression=True. Colour encodes modulator category:

  • viridis – ligands (L-R and TF-ligand).

  • magma – receptors.

  • rainbow – transcription factor regulators.

Args:
use_expression (bool): Scale word size by mean expression across cells. When False, all words are the same size. Default True.

Returns:

None: Displays the plot via matplotlib.pyplot.show().

static ligands_receptors_interactions(received_ligands_df, receptor_gex_df)

Compute element-wise ligand × receptor interaction scores.

For each paired column j (ligand_j, receptor_j), the interaction score is the Hadamard (element-wise) product of the received ligand signal (Gaussian-weighted neighbourhood average) and the cell's own receptor expression, where i indexes cells and j indexes pairs:

score_{i,j} = received_ligand_{i,j} × receptor_expr_{i,j}

The resulting columns are named "LigandGene$ReceptorGene", matching the format used throughout the modulator pipeline.

Args:
received_ligands_df (pd.DataFrame): DataFrame of shape (n_cells, n_pairs) containing diffused ligand abundances (output of received_ligands()).

receptor_gex_df (pd.DataFrame): DataFrame of the same shape containing receptor expression values for each paired receptor.

Returns:

pd.DataFrame: Interaction score matrix of shape (n_cells, n_pairs) with columns in "Lig$Rec" format.
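The element-wise product above can be written in a few lines of pandas. This is a hypothetical re-implementation (lr_interaction_scores is not a SpaceTravLR function). It assumes column j of both inputs describes the same pair and that the rows are aligned cell-for-cell:

```python
import pandas as pd


def lr_interaction_scores(received_ligands_df: pd.DataFrame,
                          receptor_gex_df: pd.DataFrame) -> pd.DataFrame:
    """Element-wise ligand x receptor scores with "Lig$Rec" column names."""
    # Column j of both frames is assumed to describe the same L-R pair j, and the
    # rows (cells) are assumed to be aligned; working on raw arrays avoids pandas
    # aligning on column names, which can repeat when a ligand has several receptors.
    if received_ligands_df.shape != receptor_gex_df.shape:
        raise ValueError("inputs must have the same shape")
    values = received_ligands_df.to_numpy() * receptor_gex_df.to_numpy()
    columns = [f"{lig}${rec}" for lig, rec in
               zip(received_ligands_df.columns, receptor_gex_df.columns)]
    return pd.DataFrame(values, index=received_ligands_df.index, columns=columns)


cells = ["c1", "c2"]
received = pd.DataFrame({"Wnt5a": [0.5, 2.0], "Tgfb1": [1.0, 0.0]}, index=cells)
receptors = pd.DataFrame({"Fzd1": [2.0, 1.0], "Tgfbr1": [3.0, 4.0]}, index=cells)
scores = lr_interaction_scores(received, receptors)
```

Replacing receptor expression with TF expression and "$" with "#" gives the ligand_regulators_interactions() variant described next.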

static ligand_regulators_interactions(received_ligands_df, regulator_gex_df)

Compute element-wise ligand × TF (NicheNet) interaction scores.

Analogous to ligands_receptors_interactions() but for the NicheNet TF-ligand axis. For each paired column j (ligand_j, tf_j), the score is the product of the received ligand signal and the TF's expression in the same cell, where i indexes cells:

score_{i,j} = received_ligand_{i,j} × tf_expr_{i,j}

Columns are named "LigandGene#TFGene".

Args:
received_ligands_df (pd.DataFrame): Received ligand signals, shape (n_cells, n_pairs).

regulator_gex_df (pd.DataFrame): TF expression values of the same shape.

Returns:

pd.DataFrame: Interaction score matrix of shape (n_cells, n_pairs) with columns in "Lig#TF" format.

static check_LR_properties(adata, layer)

Retrieve expression counts and optional cell-type thresholds.

A lightweight helper used by init_data() to unpack the two artefacts needed for L-R filtering: the full expression DataFrame and the per-cell-type expression threshold mask stored in adata.uns["cell_thresholds"].

Args:
adata (AnnData): The dataset to query.

layer (str): Layer key for expression values.

Returns:
tuple[pd.DataFrame, pd.DataFrame | None]: (counts_df, cell_thresholds). cell_thresholds is None when the key is absent from adata.uns.

init_data()

Build all training matrices and cache them on self.

This method orchestrates the full feature-engineering pipeline and must be called (implicitly via fit()) before get_betas() can be used. It performs the following steps in order:

  1. Compute (or reuse cached) received ligand signals and store them in adata.uns["received_ligands"] / "received_ligands_tfl".

  2. Compute L-R interaction scores and store them in adata.uns["ligand_receptor"].

  3. Compute TF-ligand interaction scores and store them in adata.uns["ligand_regulator"].

  4. Build (or reuse cached) spatial neighbourhood image tensors adata.obsm["spatial_maps"].

  5. Build (or reuse cached) spatial context features adata.obsm["spatial_features"] and normalise them with MinMaxScaler.

  6. Construct the training DataFrame self.train_df combining all feature groups.

  7. Filter out any L-R / TF-ligand pairs that produced zero-variance features.

  8. Update self.modulators and self.modulators_genes to reflect surviving pairs.

Side effects:

Populates self.spatial_maps, self.spatial_features, self.train_df, self.xy, self.xy_df, self.lr_pairs, self.tfl_pairs, self.ligands, self.receptors, self.tfl_ligands, self.tfl_regulators.

Returns:
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: (sp_maps, X, y, cluster_labels), ready for the training loop in fit().

  • sp_maps – shape (n_cells, n_clusters, D, D), where D = spatial_dim.

  • X – design matrix of shape (n_cells, n_modulators).

  • y – target expression vector of shape (n_cells,).

  • cluster_labels – integer cluster label per cell.
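Steps 7 and 8 can be sketched as two operations: drop constant interaction columns, then rebuild the modulator order from the survivors. The sketch makes two assumptions. A column counts as zero-variance when it holds a single distinct value. TF regulator columns are never dropped, since the text mentions only L-R and TF-ligand pairs. drop_constant_pairs is a hypothetical name:

```python
import pandas as pd


def drop_constant_pairs(train_df: pd.DataFrame, pair_columns: list[str]):
    """Drop interaction columns holding a single value; return (df, surviving pairs)."""
    constant = [c for c in pair_columns if train_df[c].nunique(dropna=False) <= 1]
    surviving = [c for c in pair_columns if c not in constant]   # order preserved
    return train_df.drop(columns=constant), surviving


regulators = ["Stat3"]
lr_pairs = ["Wnt5a$Fzd1", "Tgfb1$Tgfbr1"]
tfl_pairs = ["Tgfb1#Smad3"]
train_df = pd.DataFrame({
    "Stat3": [0.2, 0.9, 0.4],
    "Wnt5a$Fzd1": [0.0, 0.0, 0.0],      # ligand never received: zero variance
    "Tgfb1$Tgfbr1": [1.0, 0.5, 0.0],
    "Tgfb1#Smad3": [0.3, 0.1, 0.6],
})
train_df, kept = drop_constant_pairs(train_df, lr_pairs + tfl_pairs)
# Rebuild the modulator order: regulators, then L-R pairs, then TF-ligand pairs
modulators = regulators + kept
```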

get_betas()

Extract per-cell spatial beta coefficients from trained models.

Iterates over each unique cluster label, forwards the corresponding spatial neighbourhood images and spatial context features through the cluster’s CellularNicheNetwork, and collects the returned beta vectors. Cells whose cluster has no trained model (e.g. clusters with R² < threshold) receive an all-zeros beta row.

The beta matrix has columns:

  • "beta0" – the learned intercept / baseline expression.

  • "beta_<modulator>" – one column per entry in self.modulators following the order: regulators → L-R pairs → TF-ligand pairs.

This method is decorated with @torch.no_grad() so no gradients are computed during inference.

Returns:

pd.DataFrame: Shape (n_cells, 1 + n_modulators) indexed by adata.obs.index. Rows are sorted to match the original cell order regardless of the per-cluster iteration order.
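The assembly step described above can be expressed as a per-cluster concatenation followed by a reindex. In this sketch the per-cluster beta arrays are supplied directly; it does not forward anything through a network. assemble_betas is a hypothetical helper, and the arrays and gene names are illustrative:

```python
import numpy as np
import pandas as pd


def assemble_betas(obs_index: pd.Index, cluster_labels, per_cluster_betas: dict,
                   modulators: list) -> pd.DataFrame:
    """Stitch per-cluster beta arrays into one (n_cells, 1 + n_modulators) frame."""
    columns = ["beta0"] + [f"beta_{m}" for m in modulators]
    cluster_labels = np.asarray(cluster_labels)
    blocks = []
    for label in np.unique(cluster_labels):
        cells = obs_index[cluster_labels == label]     # cells of this cluster, in obs order
        values = per_cluster_betas.get(label)
        if values is None:                             # no trained model: all-zero rows
            values = np.zeros((len(cells), len(columns)))
        blocks.append(pd.DataFrame(values, index=cells, columns=columns))
    # Blocks follow cluster order; reindex restores the original cell order
    return pd.concat(blocks).reindex(obs_index)


obs_index = pd.Index(["c1", "c2", "c3", "c4"])
labels = [1, 0, 1, 2]
per_cluster = {
    0: np.array([[0.5, 0.1, 0.2]]),
    1: np.array([[1.0, 0.3, 0.0],
                 [2.0, 0.4, 0.1]]),
}                                                      # cluster 2 has no model
betas = assemble_betas(obs_index, labels, per_cluster, ["Stat3", "Wnt5a$Fzd1"])
```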

property betadata

Alias for get_betas() kept for backward compatibility.

Returns:

pd.DataFrame: Per-cell spatial beta coefficients (see get_betas() for full description).

fit(num_epochs=100, threshold_lambda=1e-06, learning_rate=0.005, batch_size=512, pbar=None, estimator='lasso', vision_model='cnn', score_threshold=0.2, l1_reg=1e-09, skip_clusters=None)

Train per-cluster spatial cellular programme models.

For each unique cluster label the method:

  1. Fits a seed estimator (Group Lasso, ARD, or Bayesian Ridge) on the cluster’s cells to obtain initial _betas anchors.

  2. If the seed R² exceeds score_threshold, trains a CellularNicheNetwork or CellularViT with those anchors using Adam + MSE loss.

  3. If the neural-network R² after training remains below score_threshold, the model anchors are zeroed out so the gene’s spatial betas collapse to the baseline (intercept only).

The seed estimator types are:

  • "lasso" (default) – Group Lasso with feature groups (regulators, L-R pairs, TF-ligand pairs) and threshold_lambda as the group regularisation strength.

  • "bayesian" – BayesianRidge.

  • "ard" – ARDRegression. Scales as O(n²) in samples; avoid for large datasets.
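The "lasso" seed penalises the coefficients of each modulator group as a unit and adds a small element-wise L1 term (l1_reg). Below is a sketch of that combined penalty. It assumes the standard sqrt(group size) weighting, which SpaceTravLR may not use. group_lasso_penalty is a hypothetical name, and the intercept is excluded from beta:

```python
import numpy as np


def group_lasso_penalty(beta, groups, threshold_lambda=1e-6, l1_reg=1e-9):
    """Group penalty over modulator groups plus an element-wise L1 term (sketch)."""
    beta = np.asarray(beta, dtype=float)
    groups = np.asarray(groups)
    group_term = 0.0
    for g in np.unique(groups):
        members = groups == g
        # Each group's L2 norm, weighted by sqrt(group size) as in the standard form
        group_term += np.sqrt(members.sum()) * np.linalg.norm(beta[members])
    return threshold_lambda * group_term + l1_reg * np.abs(beta).sum()


modulators = ["Stat3", "Myb", "Wnt5a$Fzd1", "Tgfb1#Smad3"]
# Group ids: 0 = regulators, 1 = L-R pairs, 2 = TF-ligand pairs
groups = [2 if "#" in m else 1 if "$" in m else 0 for m in modulators]
beta = [3.0, 4.0, 0.0, 1.0]
penalty = group_lasso_penalty(beta, groups, threshold_lambda=1.0, l1_reg=0.5)
```

The penalty is on the L2 norm of a whole group, so an entire group such as the L-R pairs can be driven to zero together (group 1 above contributes nothing). The l1_reg term additionally encourages sparsity within the surviving groups.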

Args:
num_epochs (int): Number of neural-network training epochs per cluster. Set to 0 to skip neural training (seed betas only). Default 100.

threshold_lambda (float): Group-Lasso regularisation strength (and ARD threshold_lambda). Default 1e-6.

learning_rate (float): Adam learning rate. Default 5e-3.

batch_size (int): Mini-batch size for the DataLoader. Default 512.

pbar: An enlighten progress-bar counter. When None, a new manager is created automatically.

estimator (str): Seed estimator type; one of "lasso", "bayesian", "ard". Default "lasso".

vision_model (str): Neural architecture; one of "cnn" (CellularNicheNetwork) or "transformer" (CellularViT). Default "cnn".

score_threshold (float): Minimum seed R² required to proceed with neural training. Clusters below this threshold receive zero-anchor models. Default 0.2.

l1_reg (float): Element-wise L1 penalty applied on top of the Group Lasso group penalty. Default 1e-9.

skip_clusters (list[int] | None): Cluster integer labels to skip entirely (the progress bar is still advanced). Default None.

Returns:

None: Populates self.models, self.scores, and self.loss_dict in-place.

export(save_dir='./models')

Serialise the trained estimator to disk.

PyTorch Module objects are not directly picklable in all configurations. This method works around that by converting each cluster model to a plain dict of {state_dict, anchors} before pickling the estimator, then saves the result as <save_dir>/<target_gene>_estimator.pkl.

Usage:

estimator.fit(num_epochs=50)
estimator.export("./output/models")
# → ./output/models/Myc_estimator.pkl

Args:
save_dir (str): Directory path to save the .pkl file. Created automatically if it does not exist. Default "./models".

Returns:

None

load(path)

Restore a previously exported estimator from disk.

Reads the pickle file written by export(), copies all non-model attributes back onto self, then reconstructs each per-cluster CellularNicheNetwork from its saved state dict and anchors.

Usage:

estimator = SpatialCellularProgramsEstimator(adata, "Myc", grn=grn)
estimator.load("./output/models/Myc_estimator.pkl")
betas = estimator.get_betas()

Args:
path (str): Absolute or relative path to the .pkl file produced by export().

Returns:

None: Modifies self in-place.