# _..._
# .' '. _
# / .-""-\ _/ \
# .-| /:. | | |
# | \ |:. /.-'-./
# | .-'-;:__.' =/
# .'= *=| _.='
# / _. | ;
# ;-.-'| \ |
# / | \ _\ _\
# \__/'._;. ==' ==\
# \ \ |
# / / /
# /-._/-._/
# \ `\ \
# `-._/._/
import os
import sys
import pickle
import functools
import time
import jscatter
import jscatter
import scanpy as sc
import numpy as np
import pandas as pd
import anndata as ad
import enlighten
from datetime import timedelta
from tqdm import tqdm
from collections import defaultdict
from simple_slurm import Slurm # pyright: ignore[reportMissingImports]
from SpaceTravLR.tools.network import expand_paired_interactions, get_cellchat_db
from enum import Enum
import warnings
warnings.filterwarnings('ignore', category=UserWarning)
class Status(Enum):
BORN = "Newly born"
BORED = "Ready but not doing anything"
RUNNING = "Running"
SUCCESS = "Completed everything gracefully"
FUBAR = "F- Up Beyond Repair"
"""
default output directory is 'output'
'output/input_data' stores all the inputs
'output/logs' stores logs
'output/betadata' stores the spatial gene-gene networks
methods with trailing underscores have side-effects but return Nothing
code philosophy is to fail early and loudly
"""
def catch_and_retry(retry=1):
def wrapper(f):
@functools.wraps(f)
def inner(*args, **kwargs):
for i in range(0, retry):
try:
return f(*args, **kwargs)
except Exception as e:
print(Status.FUBAR)
raise e
time.sleep(i+1)
return inner
return wrapper
catch_errors = catch_and_retry(retry=1) #alias
[docs]
class SpaceShip:
"""
SpaceShip is the main entry point for the SpaceTravLR analysis pipeline.
It manages the data, directory structure, and execution of the various
steps involved in spatial gene regulatory network inference and perturbation.
Parameters
----------
name : str, optional
Name of the project/analysis, by default 'AlienTissue'.
outdir : str, optional
Path to the output directory where results will be saved, by default './output'.
"""
[docs]
def __init__(self, name: str = 'AlienTissue', outdir: str = './output'):
self.name = name
self.outdir = outdir.rstrip("/\\")
self.manager = None
self.status_bar = None
self.status = Status.BORN
[docs]
@catch_errors
def process_adata_(self, adata: ad.AnnData, annot: str = 'cell_type'):
"""
Preprocesses the AnnData object for SpaceTravLR analysis.
This method checks for required fields, normalizes data if necessary,
computes PCA/Neighbors/UMAP if missing, and imputes gene expression
if needed. It saves the processed AnnData to the output directory.
Parameters
----------
adata : ad.AnnData
The AnnData object containing the spatial transcriptomics data.
annot : str, optional
The column name in `adata.obs` containing cell type annotations,
by default 'cell_type'.
"""
from .oracles import BaseTravLR
from .tools.utils import scale_adata, is_mouse_data
from .tools.network import encode_labels
if self.status_bar:
self.status_bar.update('📊 Processing AnnData: Validating input...')
assert isinstance(adata, ad.AnnData)
assert annot in adata.obs.columns
assert 'spatial' in adata.obsm
adata = adata.copy()
self.species = 'mouse' if is_mouse_data(adata) else 'human'
adata = scale_adata(adata)
adata.obs['cell_type_int'] = adata.obs[annot].apply(
lambda x: encode_labels(adata.obs[annot], reverse_dict=True)[x])
if 'X_umap' not in adata.obsm:
if self.status_bar:
self.status_bar.update('📊 Processing AnnData: Computing PCA, neighbors, and UMAP...')
sc.pp.pca(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)
if 'imputed_count' not in adata.layers:
if 'normalized_count' not in adata.layers:
if adata.X.max() > 100:
sc.pp.log1p(adata)
adata.layers['normalized_count'] = adata.X.copy()
BaseTravLR.impute_clusterwise(
adata,
annot=annot,
layer='normalized_count',
layer_added='imputed_count'
)
del adata.layers['normalized_count']
self.annot = annot
if self.status_bar:
self.status_bar.update('📊 Processing AnnData: Saving processed data...')
adata.write_h5ad(f'{self.outdir}/input_data/_adata.h5ad')
self.adata = adata
if self.status_bar:
self.status_bar.update('✅ Processing AnnData: Complete')
[docs]
def interactive_select(self, adata, size=10, annot='cell_type', mode='spatial'):
"""
Launches an interactive scatter plot for selecting cells.
Parameters
----------
adata : ad.AnnData
AnnData object.
size : int, optional
Point size, by default 10.
annot : str, optional
Color by annotation, by default 'cell_type'.
mode : str, optional
'spatial' or 'umap', by default 'spatial'.
Returns
-------
jscatter.Scatter
Interactive scatter plot widget.
"""
datadf_with_umap = adata.to_df().join(adata.obs).join(
pd.DataFrame(adata.obsm['spatial'], columns=['x', 'y'], index=adata.obs_names)
)
datadf_with_umap['umapX'] = adata.obsm['X_umap'][:,0]
datadf_with_umap['umapY'] = adata.obsm['X_umap'][:,1]
config = {'height': 800, 'width': 800, 'size': size}
if mode == 'spatial':
scatter = jscatter.Scatter(data=datadf_with_umap, x='x', y='y', **config).color(by=annot).legend(True)
elif mode == 'umap':
scatter = jscatter.Scatter(data=datadf_with_umap, x='umapX', y='umapY', **config).color(by=annot).legend(True)
else:
raise ValueError(f"Invalid mode: {mode}")
return scatter
[docs]
def load_base_cell_thresholds(self) -> pd.DataFrame:
df_ligrec = get_cellchat_db(self.species)
df_ligrec['name'] = df_ligrec['ligand'] + '-' + df_ligrec['receptor']
expanded = expand_paired_interactions(df_ligrec)
genes = set(expanded.ligand) | set(expanded.receptor)
genes = list(genes)
return pd.DataFrame(
columns=genes,
index=self.adata.obs_names
).fillna(1).astype(int)
[docs]
@staticmethod
def load_base_GRN(species) -> pd.DataFrame:
assert species in ['human', 'mouse']
data_path = os.path.join(
os.path.dirname(__file__), '..', 'SpaceTravLR_data', f'{species}_base_grn.parquet')
df = pd.read_parquet(data_path)
# tf_columns = [col for col in df.columns if col not in ['peak_id', 'gene_short_name']]
# df = df.melt(
# id_vars=['gene_short_name'],
# value_vars=tf_columns,
# var_name='source',
# value_name='link').query(
# 'link == 1')[['source', 'gene_short_name']].rename(
# columns={'gene_short_name': 'target'})
# df['coef_mean'] = 1
# df['coef_abs'] = 1
# df['p'] = 1e-5
# df['-logp'] = 5
return df
[docs]
@catch_errors
def run_celloracle_(self, alpha=5):
"""
Runs CellOracle to infer the base Gene Regulatory Network (GRN).
It constructs a cluster-specific GRN based on the base network structure
and the expression data in the AnnData object.
Parameters
----------
alpha : int, optional
Regularization parameter for the model, by default 5.
"""
if self.status_bar:
self.status_bar.update('Building base GRN...')
import celloracle_tmp as co
adata = self.adata
oracle = co.Oracle()
adata.X = adata.layers["raw_count"].copy()
oracle.import_anndata_as_raw_count(
adata=adata,
cluster_column_name=self.annot,
embedding_name="X_umap"
)
oracle.pcs = [True]
oracle.k_knn_imputation = 1
oracle.knn = 1
base_GRN = self.load_base_GRN(self.species)
oracle.import_TF_data(TF_info_matrix=base_GRN)
if self.status_bar:
self.status_bar.update('Computing & filtering TF links...')
links = oracle.get_links(
cluster_name_for_GRN_unit=self.annot,
alpha=alpha,
verbose_level=0
)
links.filter_links()
oracle.get_cluster_specific_TFdict_from_Links(links_object=links)
self.links = links.links_dict
with open(f'{self.outdir}/input_data/celloracle_links.pkl', 'wb') as f:
pickle.dump(links.links_dict, f)
[docs]
@catch_errors
def run_commot_(self, radius=350):
"""
Runs COMMOT to infer spatial cell-cell communication.
This method identifies ligand-receptor interactions and computes their
spatial communication scores. It also computes received ligand signals
for each cell.
Parameters
----------
radius : int, optional
Spatial radius for communication in microns (or coordinate units),
by default 350.
"""
from .tools.network import expand_paired_interactions
from .tools.network import get_cellchat_db
from .models.parallel_estimators import init_received_ligands
import commot as ct
adata = self.adata
if 'cell_thresholds' not in adata.uns:
if self.status_bar:
self.status_bar.update('Loading ligand-receptor database...')
df_ligrec = get_cellchat_db(self.species)
df_ligrec['name'] = df_ligrec['ligand'] + '-' + df_ligrec['receptor']
if self.status_bar:
self.status_bar.update('🔬 Commot: Expanding paired interactions...')
expanded = expand_paired_interactions(df_ligrec)
genes = set(expanded.ligand) | set(expanded.receptor)
genes = list(genes)
expanded = expanded[
expanded.ligand.isin(adata.var_names) & expanded.receptor.isin(adata.var_names)]
adata.X = adata.layers['normalized_count']
if self.status_bar:
self.status_bar.update('🔬 COMMOT: Computing spatial communication...')
ct.tl.spatial_communication(adata,
database_name='user_database',
df_ligrec=expanded,
dis_thr=radius,
heteromeric=False
)
expanded['rename'] = expanded['ligand'] + '-' + expanded['receptor']
if self.status_bar:
self.status_bar.update(f'Computing cluster communication for {len(expanded["rename"].unique())} pathways...')
unique_pathways = expanded['rename'].unique()
for idx, name in enumerate(unique_pathways):
if self.status_bar:
self.status_bar.update(f'🔬 Commot: Cluster communication {idx+1}/{len(unique_pathways)}: {name[:30]}...')
ct.tl.cluster_communication(
adata,
database_name='user_database',
pathway_name=name,
clustering='cell_type',
random_seed=42,
n_permutations=100
)
data_dict = defaultdict(dict)
for name in expanded['rename']:
data_dict[name]['communication_matrix'] = adata.uns[
f'commot_cluster-cell_type-user_database-{name}']['communication_matrix']
data_dict[name]['communication_pvalue'] = adata.uns[
f'commot_cluster-cell_type-user_database-{name}']['communication_pvalue']
with open(f'{self.outdir}/input_data/communication.pkl', 'wb') as f:
pickle.dump(data_dict, f)
info = data_dict
def get_sig_interactions(value_matrix, p_matrix, pval=0.3):
p_matrix = np.where(p_matrix < pval, 1, 0)
return value_matrix * p_matrix
if self.status_bar:
self.status_bar.update('Processing significant interactions...')
interactions = {}
for lig, rec in tqdm(zip(expanded['ligand'], expanded['receptor'])):
name = lig + '-' + rec
if name in info.keys():
value_matrix = info[name]['communication_matrix']
p_matrix = info[name]['communication_pvalue']
sig_matrix = get_sig_interactions(value_matrix, p_matrix)
if sig_matrix.sum().sum() > 0:
interactions[name] = sig_matrix
if self.status_bar:
self.status_bar.update('Computing ligand-receptor thresholds...')
ct_masks = {cell_type: adata.obs[self.annot] == cell_type for cell_type in adata.obs[self.annot].unique()}
df = pd.DataFrame(index=adata.obs_names, columns=genes)
df = df.fillna(0)
for name in tqdm(interactions.keys(), total=len(interactions)):
lig, rec = name.rsplit('-', 1)
tmp = interactions[name].sum(axis=1)
for cell_type, val in zip(interactions[name].index, tmp):
df.loc[ct_masks[cell_type], lig] += tmp[cell_type]
tmp = interactions[name].sum(axis=0)
for cell_type, val in zip(interactions[name].columns, tmp):
df.loc[ct_masks[cell_type], rec] += tmp[cell_type]
perc_filtered = np.where(df > 0, 1, 0).sum().sum() / (df.shape[0] * df.shape[1])
df.to_parquet(f'{self.outdir}/input_data/LRs.parquet')
adata.uns['cell_thresholds'] = df.copy()
else:
print('Cell thresholds already computed, skipping COMMOT...')
df = adata.uns['cell_thresholds']
if self.status_bar:
self.status_bar.update('Caching received ligands...')
adata = init_received_ligands(
adata,
radius=radius,
cell_threshes=df
)
keys = list(adata.obsm.keys())
for key in keys:
if 'commot' in key:
del adata.obsm[key]
keys = list(adata.uns.keys())
for key in keys:
if 'commot' in key:
del adata.uns[key]
keys = list(adata.obsp.keys())
for key in keys:
if 'commot' in key:
del adata.obsp[key]
self.adata = adata.copy()
adata.write_h5ad(f'{self.outdir}/input_data/_adata.h5ad')
self.status = Status.BORED
[docs]
def setup_(self, adata: ad.AnnData, overwrite=False, run_commot=False):
"""
Sets up the SpaceShip environment and runs the preprocessing pipeline.
This includes creating directories, processing AnnData, running CellOracle,
and running COMMOT.
Parameters
----------
adata : ad.AnnData
Input AnnData object.
overwrite : bool, optional
If True, overwrites existing output directory, by default False.
Returns
-------
self
Returns self for method chaining.
"""
if os.path.exists(self.outdir) and not overwrite:
print("Warning: output directory already exists. Will not overwrite.")
self.status = Status.FUBAR
return
self.manager = enlighten.get_manager()
self.status_bar = self.manager.status_bar(
f'🚀 SpaceShip {self.name}: Initializing...',
color='black_on_cyan',
justify=enlighten.Justify.CENTER,
auto_refresh=True
)
if self.status_bar:
self.status_bar.update('🚀 SpaceShip: Creating output directories...')
os.makedirs(self.outdir, exist_ok=True)
os.makedirs(f'{self.outdir}/betadata', exist_ok=True)
os.makedirs(f'{self.outdir}/input_data', exist_ok=True)
os.makedirs(f'{self.outdir}/logs', exist_ok=True)
self.status = Status.RUNNING
self.process_adata_(adata)
self.run_celloracle_()
if run_commot:
self.run_commot_()
self.get_nichenet_links_()
if self.status_bar:
self.status_bar.update('✅ SpaceShip: Setup complete!')
self.status = Status.BORED
return self
[docs]
def get_nichenet_links_(self):
if self.status_bar:
self.status_bar.update('🔗 NicheNet: Downloading ligand-target links...')
data_path = f'https://zenodo.org/records/17594271/files/ligand_target_{self.species}.parquet'
nichenet_lt = pd.read_parquet(data_path)
if self.status_bar:
self.status_bar.update('🔗 NicheNet: Saving links...')
nichenet_lt.to_parquet(f'{self.outdir}/input_data/tflinks.parquet')
if self.status_bar:
self.status_bar.update('✅ NicheNet: Complete')
return nichenet_lt
[docs]
def spawn_worker(
self,
partition='preempt',
clusters='gpu',
gres='gpu:1',
job_name='SpaceTravLR',
lifespan=3, # hours
python_path='python',
):
"""
Submits a SLURM job to run the analysis.
Parameters
----------
partition : str, optional
SLURM partition, by default 'preempt'.
clusters : str, optional
SLURM cluster, by default 'gpu'.
gres : str, optional
Generic Resource Scheduling (e.g. gpu:1), by default 'gpu:1'.
job_name : str, optional
Name of the job, by default 'SpaceTravLR'.
lifespan : int, optional
Wall-time in hours, by default 3.
python_path : str, optional
Path to python executable, by default 'python'.
"""
outlog = f'{self.outdir}/logs/training_{str(time.strftime("%Y%m%d_%H%M%S"))}.log'
slurm = Slurm(
cpus_per_task=1,
partition=partition,
clusters=clusters,
gres=gres,
ignore_pbs=True,
job_name=job_name+'_'+self.name,
output=outlog,
time=timedelta(hours=lifespan),
)
slurm.sbatch(python_path + ' launch.py')
[docs]
@catch_errors
def run_spacetravlr(
self,
max_epochs: int = 150,
learning_rate: float = 5e-3,
spatial_dim: int = 64,
batch_size: int = 512,
radius: int = 300,
contact_distance: int = 50,
):
"""
Trains the SpaceTravLR model to learn spatial gene regulation.
This method initializes and trains the SpaceTravLR neural network
model to predict gene expression based on TF activity and spatial
ligand-receptor interactions.
Parameters
----------
max_epochs : int, optional
Maximum number of training epochs, by default 150.
learning_rate : float, optional
Learning rate for the optimizer, by default 5e-3.
spatial_dim : int, optional
Dimension of the spatial embedding, by default 64.
batch_size : int, optional
Batch size for training, by default 512.
radius : int, optional
Radius for secreted signaling, by default 300.
contact_distance : int, optional
Distance for contact-dependent signaling, by default 50.
"""
from .oracles import SpaceTravLR
from .tools.network import RegulatoryFactory
base_dir = f'{self.outdir}/betadata/'
adata = sc.read_h5ad(f'{self.outdir}/input_data/_adata.h5ad')
tflinks = pd.read_parquet(f'{self.outdir}/input_data/tflinks.parquet')
links = pickle.load(open(f'{self.outdir}/input_data/celloracle_links.pkl', 'rb'))
co_grn = RegulatoryFactory(links=links)
space_travlr = SpaceTravLR(
adata=adata,
max_epochs=max_epochs,
learning_rate=learning_rate,
spatial_dim=spatial_dim,
batch_size=batch_size,
grn=co_grn,
radius=radius,
contact_distance=contact_distance,
save_dir=base_dir,
tflinks=tflinks
)
space_travlr.run()
#@alias
[docs]
def fit(self, **kwargs): return self.run_spacetravlr(**kwargs)
[docs]
def setup_perturbations(self, adata, override_params=None, subsample=None, use_float16=False):
"""
Initializes the GeneFactory for running perturbations.
Parameters
----------
adata : ad.AnnData
AnnData object used for perturbation simulations.
override_params : dict, optional
Dictionary to override run parameters, by default None.
subsample : int, optional
Number of cells to subsample for faster loading, by default None.
use_float16 : bool, optional
Use float16 for lower memory usage, by default False.
"""
from .gene_factory import GeneFactory
json_path = f'{self.outdir}/betadata/run_params.json'
assert os.path.exists(json_path), f"run_params.json not found"
self.factory = GeneFactory.from_json(
adata=adata,
json_path=json_path,
override_params=override_params
)
self.factory.load_betas(subsample=subsample, float16=use_float16)
[docs]
def perturb(self, target, propagation=4, gene_expr=0, cells=None):
"""
Performs in silico perturbation of a target gene.
Simulates the effect of changing a gene's expression (knockout or
overexpression) on the entire transcriptome, considering spatial
signaling propagation.
Parameters
----------
target : str or list
Target gene(s) to perturb.
propagation : int, optional
Number of propagation steps (hops) in the network, by default 4.
gene_expr : float or list, optional
Target expression level (0 for knockout), by default 0.
cells : list, optional
List of cell indices to apply perturbation to (None for all cells),
by default None.
Returns
-------
pd.DataFrame
Simulated gene expression matrix after perturbation.
"""
return self.factory.perturb(
target=target,
n_propagation=propagation,
gene_expr=gene_expr,
cells=cells
)
[docs]
def is_everything_ok(self) -> bool:
"""
Checks if all necessary output files and directories exist.
Returns
-------
bool
True if all checks pass.
"""
assert os.path.isfile(f'{self.outdir}/input_data/_adata.h5ad'), "AnnData file not found"
_adata = sc.read_h5ad(f'{self.outdir}/input_data/_adata.h5ad')
_links = pickle.load(open(f'{self.outdir}/input_data/celloracle_links.pkl', 'rb'))
assert 'imputed_count' in _adata.layers, "Imputed count layer not found"
assert 'X_umap' in _adata.obsm, "UMAP embedding not found"
assert 'cell_type_int' in _adata.obs.columns, "Cell type integer column not found"
assert 'spatial' in _adata.obsm, "Spatial coordinates not found"
assert os.path.isdir(self.outdir), "Output directory not found"
assert os.path.isdir(f'{self.outdir}/betadata'), "Betadata directory not found"
assert os.path.isdir(f'{self.outdir}/input_data'), "Input data directory not found"
assert os.path.isfile(f'{self.outdir}/input_data/celloracle_links.pkl'), "Base links file not found"
assert os.path.isdir(f'{self.outdir}/logs'), "Logs directory not found"
assert os.path.isfile('launch.py'), "Launch script not found"
print("We're going on a trip in our favorite rocket ship 🚀️")
return True