Model
SpatialCellularProgramsEstimator
is the core per-gene fitting engine used internally by
SpaceTravLR. It fits a spatially-aware
regulatory model for a single target gene using three classes of input
features:
| Group | Description | Column format |
|---|---|---|
| Regulators (TFs) | Transcription-factor expression from a GRN lookup | |
| Ligand-receptor pairs | Gaussian-diffused ligand signal × receptor expression | LigandGene$ReceptorGene |
| Ligand-TF pairs (NicheNet) | Diffused ligand signal × TF expression (NicheNet-filtered) | LigandGene#TFGene |
The model produces per-cell spatial beta coefficients — one column per modulator plus an intercept — describing how regulatory influence varies across the tissue.
Training pipeline overview
For every cell-type cluster the estimator:
- Fits a seed linear model (Group Lasso / Bayesian Ridge / ARD) on the cluster cells to obtain coefficient anchors.
- Constructs a 2-D spatial neighbourhood image (side spatial_dim) capturing local cell-type densities.
- Trains a CellularNicheNetwork (CNN) or CellularViT (ViT) conditioned on those anchors and neighbourhood images.
- If the final R² falls below score_threshold, zeroes the anchors so the gene betas collapse to the global baseline.
from SpaceTravLR.models.parallel_estimators import SpatialCellularProgramsEstimator
from SpaceTravLR.tools.network import RegulatoryFactory
grn = RegulatoryFactory(colinks_path="colinks.csv", annot="cell_type_int")
estimator = SpatialCellularProgramsEstimator(
    adata,
    target_gene="Myc",
    grn=grn,
    radius=150,
    spatial_dim=64,
)
# Train one model per cell-type cluster
estimator.fit(num_epochs=80, estimator="lasso", score_threshold=0.2)
# Retrieve per-cell spatial betas -- shape (n_cells, 1 + n_modulators)
betas = estimator.get_betas()
# Visualise regulators as a colour-coded word cloud
estimator.plot_modulators()
# Persist and reload
estimator.export("./output/models")
estimator.load("./output/models/Myc_estimator.pkl")
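The training pipeline overview mentions a 2-D neighbourhood image of side spatial_dim per cell. The source does not say how that image is rasterised, so the following is only a minimal sketch under an assumption: neighbour positions are binned into a square window centred on the cell, one channel per cell type. The helper name `neighbourhood_image` and the use of `radius` as the window half-width are hypothetical and not part of SpaceTravLR.

```python
import numpy as np

def neighbourhood_image(xy, labels, center_idx, radius=100.0, spatial_dim=64):
    """Hypothetical helper: bin the neighbours of one cell into a
    (n_clusters, spatial_dim, spatial_dim) count image."""
    clusters = np.unique(labels)
    offsets = xy - xy[center_idx]                      # positions relative to the cell
    edges = np.linspace(-radius, radius, spatial_dim + 1)
    image = np.zeros((len(clusters), spatial_dim, spatial_dim))
    for channel, cluster in enumerate(clusters):
        pts = offsets[labels == cluster]
        # histogram2d ignores points that fall outside the window edges
        image[channel], _, _ = np.histogram2d(pts[:, 0], pts[:, 1], bins=(edges, edges))
    return image

# Synthetic coordinates and integer cell-type labels (illustrative only)
rng = np.random.default_rng(0)
xy = rng.uniform(0, 500, size=(200, 2))
labels = rng.integers(0, 3, size=200)
img = neighbourhood_image(xy, labels, center_idx=0, radius=100.0, spatial_dim=64)
print(img.shape)
```

Stacking one such image per cell would give an array with the documented spatial_maps layout (n_cells, n_clusters, spatial_dim, spatial_dim).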
- class SpaceTravLR.models.parallel_estimators.SpatialCellularProgramsEstimator(adata, target_gene, spatial_dim=64, cluster_annot='cell_type_int', layer='imputed_count', radius=100, contact_distance=30, use_ligands=True, tf_ligand_cutoff=0.01, receptor_thresh=0.1, regulators=None, grn=None, colinks_path=None, scale_factor=1, extra_regulators=None)
Bases: object
Per-gene spatial regression estimator with ligand-receptor and GRN features.
SpatialCellularProgramsEstimator fits a spatially-aware gene-regulatory model for a single target gene. For every annotated cell-type cluster it trains a CellularNicheNetwork (CNN) or CellularViT (ViT) that takes three inputs:
- Spatial neighbourhood map – a 2-D density image of each cell type derived from adata.obsm["spatial"].
- Regulatory features – expression of transcription factors inferred from a supplied GRN, along with ligand-receptor interaction scores (pairs from the CellChat database) and ligand-TF interaction scores (NicheNet-filtered pairs).
- Spatial context features – per-cell neighbour-count vectors (one column per cell type within radius microns).
The model outputs per-cell spatial beta coefficients (one per modulator + intercept) describing how strongly each regulator, L-R pair or TF-ligand pair influences the target gene at that spatial location.
Pipeline overview:
estimator = SpatialCellularProgramsEstimator(adata, "Myc", grn=grn)
estimator.fit(num_epochs=50)
betas = estimator.get_betas()  # DataFrame (cells × modulators)
estimator.export("./models")
- Attributes:
  - adata (AnnData): The spatial single-cell dataset.
  - target_gene (str): The gene being modelled.
  - regulators (list[str]): Transcription-factor regulators of the target.
  - ligands (list[str]): Ligand genes from active L-R pairs.
  - receptors (list[str]): Receptor genes paired with ligands.
  - tfl_ligands (list[str]): Ligands from TF-ligand (NicheNet) pairs.
  - tfl_regulators (list[str]): TFs paired with tfl_ligands.
  - lr_pairs (pd.Series): Active ligand-receptor pair identifiers (format "LigandGene$ReceptorGene").
  - tfl_pairs (list[str]): Active TF-ligand pair identifiers (format "LigandGene#TFGene").
  - modulators (list[str]): Ordered concatenation of regulators, L-R pairs, and TF-ligand pairs – defines the column order of the beta matrix.
  - modulators_genes (list[str]): Unique gene names spanning all modulator categories.
  - models (dict[int, CellularNicheNetwork]): Trained per-cluster neural networks (available after fit()).
  - scores (dict[int, float]): Per-cluster R² scores on the training set (available after fit()).
  - loss_dict (dict[int, list[float]]): Per-batch MSE losses collected during training (available after fit()).
  - spatial_maps (np.ndarray): Cached neighbourhood image tensors of shape (n_cells, n_clusters, spatial_dim, spatial_dim).
  - spatial_features (pd.DataFrame): Min-max scaled neighbour-count features (shape n_cells × n_clusters).
  - xy (np.ndarray): Spatial coordinates array of shape (n_cells, 2).
  - device (torch.device): PyTorch compute device (CUDA / MPS / CPU).
  - species (str): "mouse" or "human" inferred from gene names.
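As a small illustration of how the attributes above relate, the sketch below assembles a modulators list in the documented order (regulators, then L-R pairs, then TF-ligand pairs) and derives the beta column names from it. The gene names are made up for the example, `modulator_kind` is a hypothetical helper, and sorting modulators_genes is a choice made here for determinism, not something the source specifies.

```python
import re

# Illustrative gene names only; real lists come from the GRN and pair databases
regulators = ["Sox2", "Pax6"]
lr_pairs = ["Tgfb1$Tgfbr2"]      # "LigandGene$ReceptorGene"
tfl_pairs = ["Wnt5a#Tcf7"]       # "LigandGene#TFGene"

# Documented order: regulators -> L-R pairs -> TF-ligand pairs
modulators = regulators + lr_pairs + tfl_pairs
beta_columns = ["beta0"] + [f"beta_{m}" for m in modulators]

def modulator_kind(name):
    """Hypothetical helper: classify a modulator by its separator."""
    if "$" in name:
        return "ligand-receptor"
    if "#" in name:
        return "ligand-TF"
    return "regulator"

# Unique genes across all categories (sorted here only to make the output stable)
modulators_genes = sorted({gene for m in modulators for gene in re.split(r"[$#]", m)})

print(beta_columns)
print([modulator_kind(m) for m in modulators])
print(modulators_genes)
```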
- __init__(adata, target_gene, spatial_dim=64, cluster_annot='cell_type_int', layer='imputed_count', radius=100, contact_distance=30, use_ligands=True, tf_ligand_cutoff=0.01, receptor_thresh=0.1, regulators=None, grn=None, colinks_path=None, scale_factor=1, extra_regulators=None)
Initialise the estimator and resolve regulatory inputs.
The constructor performs three main steps:
- GRN lookup – fetch the list of transcription-factor regulators for target_gene either from an explicit regulators list or via grn.get_regulators.
- Ligand-receptor discovery – query the CellChat database to enumerate L-R pairs and NicheNet TF-ligand pairs that are relevant to target_gene (skipped when use_ligands=False).
- Modulator assembly – build the ordered modulators list that defines the column ordering of the beta coefficient matrix.
- Args:
  - adata (AnnData): Spatial dataset. Must contain adata.obsm["spatial"], the requested layer, and cluster_annot in adata.obs.
  - target_gene (str): Name of the gene to model. Must be present in adata.var_names.
  - spatial_dim (int): Side length (pixels) of the square spatial neighbourhood image used as CNN input. Default 64.
  - cluster_annot (str): adata.obs column holding integer-encoded cell-type labels. Default "cell_type_int".
  - layer (str): adata.layers key for gene-expression values. Default "imputed_count".
  - radius (float): Search radius (same spatial units as adata.obsm["spatial"]) for secreted ligand diffusion. Default 100.
  - contact_distance (float): Search radius for contact signalling (juxtacrine / cell-surface). Default 30.
  - use_ligands (bool): When False, skip all ligand-receptor and TF-ligand feature computation. Default True.
  - tf_ligand_cutoff (float): Minimum NicheNet ligand-target score required to include a TF-ligand pair. Default 0.01.
  - receptor_thresh (float): Reserved threshold parameter (currently unused in filtering). Default 0.1.
  - regulators (list[str] | None): Explicit list of TF regulators. When supplied, grn and colinks_path are ignored.
  - grn (RegulatoryFactory | None): Pre-built GRN object exposing a get_regulators(adata, gene) method. Mutually exclusive with colinks_path.
  - colinks_path (str | None): Path to a co-link CSV used to construct a RegulatoryFactory internally. Required when both grn and regulators are None.
  - scale_factor (float): Multiplicative weight applied to Gaussian kernel values during ligand diffusion. Default 1.
  - extra_regulators (list[str] | None): Additional gene names to append to the regulator list after GRN lookup.
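The radius and scale_factor arguments both feed the ligand diffusion step, which the source describes only as a Gaussian-weighted neighbourhood average. The sketch below is one possible reading, not the library's implementation: it assumes a kernel exp(-d² / (2·radius²)) multiplied by scale_factor, a hard cutoff at radius, and a plain mean over the cells inside the cutoff (the receiving cell included). The function name `received_ligand` is hypothetical.

```python
import numpy as np

def received_ligand(xy, ligand_expr, radius=100.0, scale_factor=1.0):
    """Hypothetical sketch: Gaussian-weighted mean of ligand expression
    over all cells within `radius` of each receiving cell."""
    diff = xy[:, None, :] - xy[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))           # pairwise distances (n_cells, n_cells)
    # Assumed kernel form; the source only states that the kernel is Gaussian
    weights = scale_factor * np.exp(-dist ** 2 / (2.0 * radius ** 2))
    weights[dist > radius] = 0.0                       # ignore cells beyond the search radius
    n_neighbours = (dist <= radius).sum(axis=1)        # counts the cell itself, so never zero
    return (weights @ ligand_expr) / n_neighbours

xy = np.array([[0.0, 0.0], [50.0, 0.0], [500.0, 0.0]])
ligand_expr = np.array([2.0, 4.0, 8.0])
signal = received_ligand(xy, ligand_expr, radius=100.0)
print(signal)
```

In this toy layout the third cell has no neighbour inside the radius, so its received signal is just its own expression scaled by scale_factor.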
- plot_modulators(use_expression=True)
Visualise all modulator genes as a word cloud.
Renders a word cloud where each word is a gene that modulates the target (regulators, ligands, receptors, TF-ligand partners). Word size is proportional to mean expression when use_expression=True. Colour encodes modulator category:
- viridis – ligands (L-R and TF-ligand).
- magma – receptors.
- rainbow – transcription factor regulators.
- Args:
  - use_expression (bool): Scale word size by mean expression across cells. When False all words are equal size. Default True.
- Returns:
  None: Displays the plot via matplotlib.pyplot.show().
- static ligands_receptors_interactions(received_ligands_df, receptor_gex_df)
Compute element-wise ligand × receptor interaction scores.
For each paired column (ligand_i, receptor_i) the interaction score is the Hadamard (element-wise) product of the received ligand signal (Gaussian-weighted neighbourhood average) and the cell's own receptor expression:
score_{i,j} = received_ligand_{i,j} × receptor_expr_{i,j}
The resulting columns are named "LigandGene$ReceptorGene", matching the format used throughout the modulator pipeline.
- Args:
  - received_ligands_df (pd.DataFrame): DataFrame of shape (n_cells, n_pairs) containing diffused ligand abundances (output of received_ligands()).
  - receptor_gex_df (pd.DataFrame): DataFrame of the same shape containing receptor expression values for each paired receptor.
- Returns:
  pd.DataFrame: Interaction score matrix of shape (n_cells, n_pairs) with columns in "Lig$Rec" format.
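The element-wise product described above can be reproduced in a few lines of pandas. This is an illustrative re-implementation under the documented contract (two frames of equal shape whose columns are paired by position), not the library source; `lr_interaction_scores` and the gene names are invented for the example.

```python
import pandas as pd

def lr_interaction_scores(received_ligands_df, receptor_gex_df):
    """Illustrative sketch: Hadamard product of two position-paired frames."""
    # Multiply the raw arrays so that differing column names do not trigger alignment
    scores = received_ligands_df.to_numpy() * receptor_gex_df.to_numpy()
    columns = [
        f"{ligand}${receptor}"
        for ligand, receptor in zip(received_ligands_df.columns, receptor_gex_df.columns)
    ]
    return pd.DataFrame(scores, index=received_ligands_df.index, columns=columns)

cells = ["cell_a", "cell_b"]
received = pd.DataFrame({"Tgfb1": [0.5, 2.0], "Wnt5a": [1.0, 0.0]}, index=cells)
receptors = pd.DataFrame({"Tgfbr2": [4.0, 1.0], "Fzd5": [3.0, 7.0]}, index=cells)
scores = lr_interaction_scores(received, receptors)
print(scores)
```

A cell with zero received ligand scores zero for that pair regardless of how strongly it expresses the receptor, which is the intended behaviour of a product feature.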
- static ligand_regulators_interactions(received_ligands_df, regulator_gex_df)
Compute element-wise ligand × TF (NicheNet) interaction scores.
Analogous to ligands_receptors_interactions() but for the NicheNet TF-ligand axis. For each (ligand_i, tf_i) pair the score is the product of the received ligand signal and the TF's expression in the same cell:
score_{i,j} = received_ligand_{i,j} × tf_expr_{i,j}
Columns are named "LigandGene#TFGene".
- Args:
  - received_ligands_df (pd.DataFrame): Received ligand signals, shape (n_cells, n_pairs).
  - regulator_gex_df (pd.DataFrame): TF expression values of the same shape.
- Returns:
  pd.DataFrame: Interaction score matrix of shape (n_cells, n_pairs) with columns in "Lig#TF" format.
- static check_LR_properties(adata, layer)
Retrieve expression counts and optional cell-type thresholds.
A lightweight helper used by init_data() to unpack the two artefacts needed for L-R filtering: the full expression DataFrame and the per-cell-type expression threshold mask stored in adata.uns["cell_thresholds"].
- Args:
  - adata (AnnData): The dataset to query.
  - layer (str): Layer key for expression values.
- Returns:
  tuple[pd.DataFrame, pd.DataFrame | None]: (counts_df, cell_thresholds). cell_thresholds is None when the key is absent from adata.uns.
- init_data()
Build all training matrices and cache them on self.
This method orchestrates the full feature-engineering pipeline and must be called (implicitly via fit()) before get_betas() can be used. It performs the following steps in order:
- Compute (or reuse cached) received ligand signals and store them in adata.uns["received_ligands"] / "received_ligands_tfl".
- Compute L-R interaction scores and store them in adata.uns["ligand_receptor"].
- Compute TF-ligand interaction scores and store them in adata.uns["ligand_regulator"].
- Build (or reuse cached) spatial neighbourhood image tensors adata.obsm["spatial_maps"].
- Build (or reuse cached) spatial context features adata.obsm["spatial_features"] and normalise them with MinMaxScaler.
- Construct the training DataFrame self.train_df combining all feature groups.
- Filter out any L-R / TF-ligand pairs that produced zero-variance features.
- Update self.modulators and self.modulators_genes to reflect surviving pairs.
- Side effects:
  Populates self.spatial_maps, self.spatial_features, self.train_df, self.xy, self.xy_df, self.lr_pairs, self.tfl_pairs, self.ligands, self.receptors, self.tfl_ligands, self.tfl_regulators.
- Returns:
  tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: (sp_maps, X, y, cluster_labels) ready for the training loop in fit().
  - sp_maps – shape (n_cells, n_clusters, D, D).
  - X – design matrix of shape (n_cells, n_modulators).
  - y – target expression vector of shape (n_cells,).
  - cluster_labels – integer cluster label per cell.
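The spatial context features built by init_data() are described as per-cell neighbour counts per cell type, min-max scaled. A minimal NumPy sketch of that idea follows. It is an approximation for illustration: whether the library counts the cell itself is not stated (this sketch excludes it), the scaling is written out by hand instead of calling MinMaxScaler, and `neighbour_count_features` is a hypothetical name.

```python
import numpy as np

def neighbour_count_features(xy, labels, radius=100.0):
    """Hypothetical sketch: count neighbours of each cell type within
    `radius`, then min-max scale each column to [0, 1]."""
    clusters = np.unique(labels)
    diff = xy[:, None, :] - xy[None, :, :]
    within = np.sqrt((diff ** 2).sum(axis=-1)) <= radius
    np.fill_diagonal(within, False)                    # assumption: a cell is not its own neighbour
    counts = np.stack(
        [(within & (labels == c)[None, :]).sum(axis=1) for c in clusters], axis=1
    ).astype(float)
    span = counts.max(axis=0) - counts.min(axis=0)
    span[span == 0] = 1.0                              # constant columns map to 0 instead of NaN
    return (counts - counts.min(axis=0)) / span

xy = np.array([[0.0, 0.0], [10.0, 0.0], [20.0, 0.0], [300.0, 0.0]])
labels = np.array([0, 1, 1, 0])
features = neighbour_count_features(xy, labels, radius=50.0)
print(features)   # one row per cell, one column per cell type
```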
- get_betas()
Extract per-cell spatial beta coefficients from trained models.
Iterates over each unique cluster label, forwards the corresponding spatial neighbourhood images and spatial context features through the cluster's CellularNicheNetwork, and collects the returned beta vectors. Cells whose cluster has no trained model (e.g. clusters with R² < threshold) receive an all-zeros beta row.
The beta matrix has columns:
- "beta0" – the learned intercept / baseline expression.
- "beta_<modulator>" – one column per entry in self.modulators following the order: regulators → L-R pairs → TF-ligand pairs.
This method is decorated with @torch.no_grad() so no gradients are computed during inference.
- Returns:
  pd.DataFrame: Shape (n_cells, 1 + n_modulators) indexed by adata.obs.index. Rows are sorted to match the original cell order regardless of the per-cluster iteration order.
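The bookkeeping described for get_betas() (iterate per cluster, fill zeros where no model exists, then restore the original cell order) can be sketched with pandas alone. The `cluster_betas` dictionary below is a stand-in for trained networks and returns one constant vector per cluster, whereas the real models return a distinct beta vector per cell; all names and values are illustrative.

```python
import numpy as np
import pandas as pd

cell_index = pd.Index(["c0", "c1", "c2", "c3", "c4"])
clusters = pd.Series([1, 0, 1, 2, 0], index=cell_index)
columns = ["beta0", "beta_Sox2", "beta_Tgfb1$Tgfbr2"]

# Stand-in for trained per-cluster models; cluster 2 has no model
cluster_betas = {
    0: np.array([0.1, 0.5, -0.2]),
    1: np.array([0.3, 0.0, 0.7]),
}

frames = []
for cluster in sorted(int(c) for c in clusters.unique()):
    cells = clusters.index[clusters == cluster]
    if cluster in cluster_betas:
        values = np.tile(cluster_betas[cluster], (len(cells), 1))
    else:
        values = np.zeros((len(cells), len(columns)))  # no model: all-zeros rows
    frames.append(pd.DataFrame(values, index=cells, columns=columns))

# Concatenation follows cluster order; reindex restores the original cell order
betas = pd.concat(frames).reindex(cell_index)
print(betas)
```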
- property betadata
Alias for get_betas() kept for backward compatibility.
- Returns:
  pd.DataFrame: Per-cell spatial beta coefficients (see get_betas() for full description).
- fit(num_epochs=100, threshold_lambda=1e-06, learning_rate=0.005, batch_size=512, pbar=None, estimator='lasso', vision_model='cnn', score_threshold=0.2, l1_reg=1e-09, skip_clusters=None)
Train per-cluster spatial cellular programme models.
For each unique cluster label the method:
- Fits a seed estimator (Group Lasso, ARD, or Bayesian Ridge) on the cluster's cells to obtain initial _betas anchors.
- If the seed R² exceeds score_threshold, trains a CellularNicheNetwork or CellularViT with those anchors using Adam + MSE loss.
- If the neural-network R² after training remains below score_threshold, the model anchors are zeroed out so the gene's spatial betas collapse to the baseline (intercept only).
The seed estimator types are:
- "lasso" (default) – Group Lasso with feature groups (regulators, L-R pairs, TF-ligand pairs) and threshold_lambda as the group regularisation strength.
- "bayesian" – BayesianRidge.
- "ard" – ARDRegression. Scales as O(n²) in samples; avoid for large datasets.
- Args:
  - num_epochs (int): Number of neural-network training epochs per cluster. Set to 0 to skip neural training (seed betas only). Default 100.
  - threshold_lambda (float): Group-Lasso regularisation strength (and ARD threshold_lambda). Default 1e-6.
  - learning_rate (float): Adam learning rate. Default 5e-3.
  - batch_size (int): Mini-batch size for the DataLoader. Default 512.
  - pbar: An enlighten progress-bar counter. When None a new manager is created automatically.
  - estimator (str): Seed estimator type; one of "lasso", "bayesian", "ard". Default "lasso".
  - vision_model (str): Neural architecture; one of "cnn" (CellularNicheNetwork) or "transformer" (CellularViT). Default "cnn".
  - score_threshold (float): Minimum seed R² required to proceed with neural training. Clusters below this threshold receive zero-anchor models. Default 0.2.
  - l1_reg (float): Element-wise L1 penalty applied on top of the Group Lasso group penalty. Default 1e-9.
  - skip_clusters (list[int] | None): Cluster integer labels to skip entirely (progress bar is still advanced). Default None.
- Returns:
  None: Populates self.models, self.scores, and self.loss_dict in-place.
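The seed-and-gate logic of fit() can be illustrated without any deep-learning dependency. The sketch below swaps in a closed-form ridge regression for the Group Lasso / BayesianRidge / ARD seed estimators, so it is a stand-in and not the library's method. It also assumes that "zeroing the anchors" keeps the intercept and clears only the modulator coefficients, which is one reading of "collapse to the baseline (intercept only)". `seed_anchors` is a hypothetical name.

```python
import numpy as np

def seed_anchors(X, y, score_threshold=0.2, ridge=1e-6):
    """Hypothetical sketch: ridge fit as a stand-in seed estimator,
    followed by the R² gate described for fit()."""
    X1 = np.column_stack([np.ones(len(X)), X])         # intercept column + modulators
    penalty = ridge * np.eye(X1.shape[1])
    penalty[0, 0] = 0.0                                # leave the intercept unpenalised
    coef = np.linalg.solve(X1.T @ X1 + penalty, X1.T @ y)
    residual = y - X1 @ coef
    r2 = 1.0 - (residual ** 2).sum() / ((y - y.mean()) ** 2).sum()
    if r2 < score_threshold:
        coef[1:] = 0.0                                 # assumption: keep intercept, zero modulator anchors
    return coef, r2

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 3))
y_signal = 1.0 + 2.0 * X[:, 0] - X[:, 1] + 0.1 * rng.normal(size=300)
y_noise = rng.normal(size=300)                         # unrelated to X

coef_good, r2_good = seed_anchors(X, y_signal)
coef_bad, r2_bad = seed_anchors(X, y_noise)
print(round(r2_good, 3), round(r2_bad, 3))
```

A cluster whose target expression is unrelated to the modulators falls below the threshold and ends up with zeroed modulator anchors, while a cluster with real signal keeps its seed coefficients for the neural stage.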
- export(save_dir='./models')
Serialise the trained estimator to disk.
PyTorch Module objects are not directly picklable in all configurations. This method works around that by converting each cluster model to a plain dict of {state_dict, anchors} before pickling the estimator, then saves the result as <save_dir>/<target_gene>_estimator.pkl.
Usage:
estimator.fit(num_epochs=50)
estimator.export("./output/models")
# → ./output/models/Myc_estimator.pkl
- Args:
  - save_dir (str): Directory path to save the .pkl file. Created automatically if it does not exist. Default "./models".
- Returns:
  None
- load(path)
Restore a previously exported estimator from disk.
Reads the pickle file written by export(), copies all non-model attributes back onto self, then reconstructs each per-cluster CellularNicheNetwork from its saved state dict and anchors.
Usage:
estimator = SpatialCellularProgramsEstimator(adata, "Myc", grn=grn)
estimator.load("./output/models/Myc_estimator.pkl")
betas = estimator.get_betas()
- Args:
  - path (str): Absolute or relative path to the .pkl file produced by export().
- Returns:
  None: Modifies self in-place.
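The export()/load() pair relies on a common pattern: replace each model object with a plain dict of {state_dict, anchors} before pickling, then rebuild the objects on load. The sketch below shows that round trip with the standard library only. `TinyModel` is a stand-in that mimics the state_dict() / load_state_dict() interface and is not a torch.nn.Module; `export_models` and `load_models` are hypothetical helpers that handle only the model dictionary, whereas the real methods pickle the whole estimator.

```python
import os
import pickle
import tempfile

class TinyModel:
    """Stand-in for a per-cluster network (NOT a torch.nn.Module)."""
    def __init__(self, anchors):
        self.anchors = list(anchors)
        self.weights = {"w": [0.0] * len(self.anchors)}

    def state_dict(self):
        return {key: list(value) for key, value in self.weights.items()}

    def load_state_dict(self, state):
        self.weights = {key: list(value) for key, value in state.items()}

def export_models(models, target_gene, save_dir):
    """Convert each model to a plain dict, then pickle the result."""
    os.makedirs(save_dir, exist_ok=True)
    payload = {
        cluster: {"state_dict": model.state_dict(), "anchors": model.anchors}
        for cluster, model in models.items()
    }
    path = os.path.join(save_dir, f"{target_gene}_estimator.pkl")
    with open(path, "wb") as handle:
        pickle.dump(payload, handle)
    return path

def load_models(path):
    """Rebuild one model per cluster from its saved anchors and state."""
    with open(path, "rb") as handle:
        payload = pickle.load(handle)
    models = {}
    for cluster, saved in payload.items():
        model = TinyModel(saved["anchors"])
        model.load_state_dict(saved["state_dict"])
        models[cluster] = model
    return models

models = {0: TinyModel([0.1, 0.5]), 1: TinyModel([0.3, 0.0])}
models[0].weights["w"] = [1.5, -2.0]                   # pretend cluster 0 was trained
with tempfile.TemporaryDirectory() as tmp:
    path = export_models(models, "Myc", os.path.join(tmp, "models"))
    restored = load_models(path)
print(os.path.basename(path))
```

As with any pickle file, the exported estimator should only be loaded from a trusted source, because unpickling can execute arbitrary code.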