# Source code for SpaceTravLR.spaceship

#                                _..._
#                            .'     '.      _
#                            /    .-""-\   _/ \
#                        .-|   /:.   |  |   |
#                        |  \  |:.   /.-'-./
#                        | .-'-;:__.'    =/
#                        .'=  *=|     _.='
#                        /   _.  |    ;
#                        ;-.-'|    \   |
#                        /   | \    _\  _\
#                        \__/'._;.  ==' ==\
#                                \    \   |
#                                /    /   /
#                                /-._/-._/
#                                \   `\  \
#                                `-._/._/


import os
import sys 
import pickle
import functools
import time

import jscatter
import jscatter
import scanpy as sc
import numpy as np
import pandas as pd
import anndata as ad
import enlighten

from datetime import timedelta
from tqdm import tqdm
from collections import defaultdict
from simple_slurm import Slurm  # pyright: ignore[reportMissingImports]

from SpaceTravLR.tools.network import expand_paired_interactions, get_cellchat_db

from enum import Enum

import warnings
warnings.filterwarnings('ignore', category=UserWarning)

class Status(Enum):
    """Lifecycle states of a SpaceShip; each member's value is a human-readable label."""

    BORN = "Newly born"  # state assigned by SpaceShip.__init__
    BORED = "Ready but not doing anything"  # idle, ready for the next step
    RUNNING = "Running"
    SUCCESS = "Completed everything gracefully"
    FUBAR = "F- Up Beyond Repair"  # reported when a step fails

""" 
Default output directory is './output':
'output/input_data' stores all the inputs
'output/logs' stores logs
'output/betadata' stores the spatial gene-gene networks

Methods with a trailing underscore have side effects and generally return nothing.
The code philosophy is to fail early and loudly.
"""


def catch_and_retry(retry=1, backoff=1.0):
    """Build a decorator that retries the wrapped callable when it raises.

    The wrapped function is attempted up to ``retry`` times (always at least
    once). After a failed attempt ``i`` (0-based) that is not the last one, the
    decorator sleeps ``backoff * (i + 1)`` seconds before trying again. When the
    final attempt fails, ``Status.FUBAR`` is printed and the original exception
    is re-raised unchanged (fail early and loudly).

    Parameters
    ----------
    retry : int, optional
        Maximum number of attempts, by default 1 (no retries). Values below 1
        are treated as 1 so the wrapped function is always called.
    backoff : float, optional
        Base delay in seconds for the linear back-off between attempts,
        by default 1.0.

    Returns
    -------
    callable
        A decorator applying this retry policy to a function.
    """
    attempts = max(1, int(retry))

    def wrapper(f):
        @functools.wraps(f)
        def inner(*args, **kwargs):
            for i in range(attempts):
                try:
                    return f(*args, **kwargs)
                except Exception:
                    if i == attempts - 1:
                        print(Status.FUBAR)
                        # Bare raise keeps the original traceback intact.
                        raise
                    time.sleep(backoff * (i + 1))
        return inner
    return wrapper


catch_errors = catch_and_retry(retry=1)  # alias: single attempt, report and re-raise

class SpaceShip:
    """
    SpaceShip is the main entry point for the SpaceTravLR analysis pipeline.

    It manages the data, directory structure, and execution of the various
    steps involved in spatial gene regulatory network inference and
    perturbation.

    Parameters
    ----------
    name : str, optional
        Name of the project/analysis, by default 'AlienTissue'.
    outdir : str, optional
        Path to the output directory where results will be saved,
        by default './output'. Trailing slashes/backslashes are stripped.
    """

    def __init__(self, name: str = 'AlienTissue', outdir: str = './output'):
        self.name = name
        self.outdir = outdir.rstrip("/\\")
        self.manager = None
        self.status_bar = None
        self.status = Status.BORN

    def _notify(self, message: str) -> None:
        """Show ``message`` on the enlighten status bar, if one is attached."""
        if self.status_bar:
            self.status_bar.update(message)

    @catch_errors
    def process_adata_(self, adata: ad.AnnData, annot: str = 'cell_type'):
        """
        Preprocesses the AnnData object for SpaceTravLR analysis.

        This method checks for required fields, scales the data, computes
        PCA/Neighbors/UMAP if missing, and imputes gene expression if needed.
        It saves the processed AnnData to
        ``{outdir}/input_data/_adata.h5ad`` and keeps it on ``self.adata``.

        Parameters
        ----------
        adata : ad.AnnData
            The AnnData object containing the spatial transcriptomics data.
            It must have ``obsm['spatial']``. The caller's object is not
            modified; a copy is processed.
        annot : str, optional
            The column name in `adata.obs` containing cell type annotations,
            by default 'cell_type'.

        Raises
        ------
        AssertionError
            If `adata` is not an AnnData, `annot` is missing from `adata.obs`,
            or `obsm['spatial']` is absent.
        """
        from .oracles import BaseTravLR
        from .tools.utils import scale_adata, is_mouse_data
        from .tools.network import encode_labels

        self._notify('📊 Processing AnnData: Validating input...')
        assert isinstance(adata, ad.AnnData)
        assert annot in adata.obs.columns
        assert 'spatial' in adata.obsm

        adata = adata.copy()
        self.species = 'mouse' if is_mouse_data(adata) else 'human'
        adata = scale_adata(adata)

        # Build the label -> integer mapping once instead of once per cell.
        label_to_int = encode_labels(adata.obs[annot], reverse_dict=True)
        adata.obs['cell_type_int'] = adata.obs[annot].apply(
            lambda label: label_to_int[label])

        if 'X_umap' not in adata.obsm:
            self._notify('📊 Processing AnnData: Computing PCA, neighbors, and UMAP...')
            sc.pp.pca(adata)
            sc.pp.neighbors(adata)
            sc.tl.umap(adata)

        if 'imputed_count' not in adata.layers:
            if 'normalized_count' not in adata.layers:
                # Presumably a max above 100 is taken to mean X is not yet
                # log-transformed - TODO confirm the heuristic.
                if adata.X.max() > 100:
                    sc.pp.log1p(adata)
                adata.layers['normalized_count'] = adata.X.copy()
            BaseTravLR.impute_clusterwise(
                adata,
                annot=annot,
                layer='normalized_count',
                layer_added='imputed_count'
            )
            # NOTE(review): this drops 'normalized_count' even when it existed
            # before imputation, yet run_commot_ reads that layer - confirm.
            del adata.layers['normalized_count']

        self.annot = annot

        self._notify('📊 Processing AnnData: Saving processed data...')
        adata.write_h5ad(f'{self.outdir}/input_data/_adata.h5ad')
        self.adata = adata
        self._notify('✅ Processing AnnData: Complete')

    def interactive_select(self, adata, size=10, annot='cell_type', mode='spatial'):
        """
        Launches an interactive scatter plot for selecting cells.

        Parameters
        ----------
        adata : ad.AnnData
            AnnData object with ``obsm['spatial']`` (two columns). It also
            needs ``obsm['X_umap']`` when ``mode='umap'``.
        size : int, optional
            Point size, by default 10.
        annot : str, optional
            Column (obs annotation or gene) to color by, by default 'cell_type'.
        mode : str, optional
            'spatial' or 'umap', by default 'spatial'.

        Returns
        -------
        jscatter.Scatter
            Interactive scatter plot widget.

        Raises
        ------
        ValueError
            If `mode` is neither 'spatial' nor 'umap'.
        """
        axes = {'spatial': ('x', 'y'), 'umap': ('umapX', 'umapY')}
        # Validate before building the (dense) plotting frame.
        if mode not in axes:
            raise ValueError(f"Invalid mode: {mode}")

        coords = pd.DataFrame(
            adata.obsm['spatial'], columns=['x', 'y'], index=adata.obs_names)
        plot_df = adata.to_df().join(adata.obs).join(coords)

        # UMAP coordinates are only strictly required in 'umap' mode.
        if mode == 'umap' or 'X_umap' in adata.obsm:
            plot_df['umapX'] = adata.obsm['X_umap'][:, 0]
            plot_df['umapY'] = adata.obsm['X_umap'][:, 1]

        x_col, y_col = axes[mode]
        config = {'height': 800, 'width': 800, 'size': size}
        return jscatter.Scatter(
            data=plot_df, x=x_col, y=y_col, **config).color(by=annot).legend(True)

    def load_base_cell_thresholds(self) -> pd.DataFrame:
        """
        Return an all-ones (cells x genes) integer threshold table.

        Columns are every ligand and receptor gene from the CellChat database
        for ``self.species``, taken after expanding paired interactions. Rows are
        ``self.adata.obs_names``. This requires ``process_adata_`` to have run.
        """
        df_ligrec = get_cellchat_db(self.species)
        df_ligrec['name'] = df_ligrec['ligand'] + '-' + df_ligrec['receptor']
        expanded = expand_paired_interactions(df_ligrec)
        # Sorted so the column order is deterministic across runs.
        genes = sorted(set(expanded.ligand) | set(expanded.receptor))
        return pd.DataFrame(1, index=self.adata.obs_names, columns=genes, dtype=int)

    @staticmethod
    def load_base_GRN(species) -> pd.DataFrame:
        """
        Load the packaged base GRN for `species` ('human' or 'mouse').

        Reads ``SpaceTravLR_data/{species}_base_grn.parquet`` from the directory
        one level above this module and returns it unchanged.
        """
        assert species in ['human', 'mouse']
        data_path = os.path.join(
            os.path.dirname(__file__), '..', 'SpaceTravLR_data',
            f'{species}_base_grn.parquet')
        return pd.read_parquet(data_path)

    @catch_errors
    def run_celloracle_(self, alpha=5):
        """
        Runs CellOracle to infer the base Gene Regulatory Network (GRN).

        It builds a cluster-specific GRN from the base network structure and
        the expression data in ``self.adata``. The links dict is stored on
        ``self.links`` and pickled to ``input_data/celloracle_links.pkl``.
        Note that ``self.adata.X`` is overwritten in place with the
        'raw_count' layer.

        Parameters
        ----------
        alpha : int, optional
            Regularization parameter for the model, by default 5.
        """
        self._notify('Building base GRN...')
        import celloracle_tmp as co

        adata = self.adata
        oracle = co.Oracle()
        adata.X = adata.layers["raw_count"].copy()
        oracle.import_anndata_as_raw_count(
            adata=adata,
            cluster_column_name=self.annot,
            embedding_name="X_umap"
        )
        # NOTE(review): placeholder PCA/KNN settings - presumably so CellOracle
        # skips its own PCA/imputation; confirm against celloracle_tmp.
        oracle.pcs = [True]
        oracle.k_knn_imputation = 1
        oracle.knn = 1

        base_GRN = self.load_base_GRN(self.species)
        oracle.import_TF_data(TF_info_matrix=base_GRN)

        self._notify('Computing & filtering TF links...')
        links = oracle.get_links(
            cluster_name_for_GRN_unit=self.annot,
            alpha=alpha,
            verbose_level=0
        )
        links.filter_links()
        oracle.get_cluster_specific_TFdict_from_Links(links_object=links)

        self.links = links.links_dict
        with open(f'{self.outdir}/input_data/celloracle_links.pkl', 'wb') as f:
            pickle.dump(links.links_dict, f)

    @staticmethod
    def _significant_interactions(value_matrix, p_matrix, pval=0.3):
        """Zero out entries of `value_matrix` whose p-value is not below `pval`."""
        return value_matrix * np.where(p_matrix < pval, 1, 0)

    @staticmethod
    def _drop_commot_entries(adata) -> None:
        """Delete every obsm/uns/obsp key containing 'commot' from `adata` in place."""
        for store in (adata.obsm, adata.uns, adata.obsp):
            for key in [k for k in store.keys() if 'commot' in k]:
                del store[key]

    def _compute_cell_thresholds(self, adata, radius):
        """
        Run COMMOT and aggregate significant L-R signal into a cells x genes table.

        Side effects: sets ``adata.X`` to the 'normalized_count' layer, writes
        ``communication.pkl`` and ``LRs.parquet`` under ``input_data``, and
        stores a copy of the table in ``adata.uns['cell_thresholds']``.

        Returns
        -------
        pd.DataFrame
            Float table indexed by ``adata.obs_names`` with one column per
            ligand/receptor gene in the CellChat database.
        """
        import commot as ct

        self._notify('Loading ligand-receptor database...')
        df_ligrec = get_cellchat_db(self.species)
        df_ligrec['name'] = df_ligrec['ligand'] + '-' + df_ligrec['receptor']

        self._notify('🔬 Commot: Expanding paired interactions...')
        expanded = expand_paired_interactions(df_ligrec)
        # Threshold columns cover every database gene, even those absent from adata.
        genes = sorted(set(expanded.ligand) | set(expanded.receptor))
        # .copy() so adding the 'rename' column below does not write into a slice.
        expanded = expanded[
            expanded.ligand.isin(adata.var_names) &
            expanded.receptor.isin(adata.var_names)].copy()

        adata.X = adata.layers['normalized_count']
        self._notify('🔬 COMMOT: Computing spatial communication...')
        ct.tl.spatial_communication(
            adata,
            database_name='user_database',
            df_ligrec=expanded,
            dis_thr=radius,
            heteromeric=False
        )

        expanded['rename'] = expanded['ligand'] + '-' + expanded['receptor']
        unique_pathways = expanded['rename'].unique()
        self._notify(f'Computing cluster communication for {len(unique_pathways)} pathways...')
        for idx, name in enumerate(unique_pathways):
            self._notify(
                f'🔬 Commot: Cluster communication {idx+1}/{len(unique_pathways)}: {name[:30]}...')
            # Cluster by the same annotation column used for the masks below.
            ct.tl.cluster_communication(
                adata,
                database_name='user_database',
                pathway_name=name,
                clustering=self.annot,
                random_seed=42,
                n_permutations=100
            )

        data_dict = defaultdict(dict)
        for name in unique_pathways:
            result = adata.uns[f'commot_cluster-{self.annot}-user_database-{name}']
            data_dict[name]['communication_matrix'] = result['communication_matrix']
            data_dict[name]['communication_pvalue'] = result['communication_pvalue']

        with open(f'{self.outdir}/input_data/communication.pkl', 'wb') as f:
            pickle.dump(data_dict, f)

        self._notify('Processing significant interactions...')
        interactions = {}
        for name, entry in tqdm(data_dict.items(), total=len(data_dict)):
            sig_matrix = self._significant_interactions(
                entry['communication_matrix'], entry['communication_pvalue'])
            if sig_matrix.sum().sum() > 0:
                interactions[name] = sig_matrix

        self._notify('Computing ligand-receptor thresholds...')
        labels = adata.obs[self.annot]
        ct_masks = {cell_type: labels == cell_type for cell_type in labels.unique()}
        # Float zeros: the accumulated signal below is floating point.
        df = pd.DataFrame(0.0, index=adata.obs_names, columns=genes)
        for name, sig_matrix in tqdm(interactions.items(), total=len(interactions)):
            lig, rec = name.rsplit('-', 1)
            # Row sums are credited to the ligand and column sums to the receptor
            # (presumably rows = sender and columns = receiver clusters in COMMOT).
            for cell_type, total in sig_matrix.sum(axis=1).items():
                df.loc[ct_masks[cell_type], lig] += total
            for cell_type, total in sig_matrix.sum(axis=0).items():
                df.loc[ct_masks[cell_type], rec] += total

        df.to_parquet(f'{self.outdir}/input_data/LRs.parquet')
        adata.uns['cell_thresholds'] = df.copy()
        return df

    @catch_errors
    def run_commot_(self, radius=350):
        """
        Runs COMMOT to infer spatial cell-cell communication.

        This method identifies ligand-receptor interactions and computes their
        spatial communication scores. These are clustered by ``self.annot``.
        It also computes received ligand signals for each cell. COMMOT is
        skipped when ``self.adata.uns['cell_thresholds']`` already exists.
        COMMOT's intermediate entries are removed, and the result is saved
        to ``input_data/_adata.h5ad``.

        Parameters
        ----------
        radius : int, optional
            Spatial radius for communication in microns (or coordinate units),
            by default 350.
        """
        from .models.parallel_estimators import init_received_ligands

        adata = self.adata
        if 'cell_thresholds' not in adata.uns:
            df = self._compute_cell_thresholds(adata, radius)
        else:
            print('Cell thresholds already computed, skipping COMMOT...')
            df = adata.uns['cell_thresholds']

        self._notify('Caching received ligands...')
        adata = init_received_ligands(
            adata,
            radius=radius,
            cell_threshes=df
        )

        # Remove COMMOT's intermediate obsm/uns/obsp entries before persisting.
        self._drop_commot_entries(adata)

        self.adata = adata.copy()
        adata.write_h5ad(f'{self.outdir}/input_data/_adata.h5ad')
        self.status = Status.BORED

    def setup_(self, adata: ad.AnnData, overwrite=False, run_commot=False):
        """
        Sets up the SpaceShip environment and runs the preprocessing pipeline.

        This creates the output directories, processes the AnnData, runs
        CellOracle, optionally runs COMMOT, and then calls
        ``get_nichenet_links_``.

        Parameters
        ----------
        adata : ad.AnnData
            Input AnnData object.
        overwrite : bool, optional
            If True, proceed even when the output directory already exists.
            Existing files are reused and overwritten as steps write them.
            Nothing is deleted. By default False.
        run_commot : bool, optional
            If True, also run ``run_commot_``, by default False.

        Returns
        -------
        self or None
            Returns self for method chaining. Returns None, after setting the
            status to FUBAR, when the output directory exists and
            `overwrite` is False.
        """
        if os.path.exists(self.outdir) and not overwrite:
            print("Warning: output directory already exists. Will not overwrite.")
            self.status = Status.FUBAR
            return

        self.manager = enlighten.get_manager()
        self.status_bar = self.manager.status_bar(
            f'🚀 SpaceShip {self.name}: Initializing...',
            color='black_on_cyan',
            justify=enlighten.Justify.CENTER,
            auto_refresh=True
        )

        self._notify('🚀 SpaceShip: Creating output directories...')
        os.makedirs(self.outdir, exist_ok=True)
        for subdir in ('betadata', 'input_data', 'logs'):
            os.makedirs(f'{self.outdir}/{subdir}', exist_ok=True)

        self.status = Status.RUNNING
        self.process_adata_(adata)
        self.run_celloracle_()
        if run_commot:
            self.run_commot_()
        # NOTE(review): get_nichenet_links_ is not defined in the visible part of
        # this class; run_spacetravlr reads input_data/tflinks.parquet, which
        # is presumably written by it - confirm.
        self.get_nichenet_links_()

        self._notify('✅ SpaceShip: Setup complete!')
        self.status = Status.BORED
        return self

    def spawn_worker(
        self,
        partition='preempt',
        clusters='gpu',
        gres='gpu:1',
        job_name='SpaceTravLR',
        lifespan=3,  # hours
        python_path='python',
    ):
        """
        Submits a SLURM job to run the analysis.

        The job executes ``{python_path} launch.py``. ``launch.py`` is resolved
        relative to the job's working directory. Output goes to a timestamped
        log under ``{outdir}/logs``.

        Parameters
        ----------
        partition : str, optional
            SLURM partition, by default 'preempt'.
        clusters : str, optional
            SLURM cluster, by default 'gpu'.
        gres : str, optional
            Generic Resource Scheduling (e.g. gpu:1), by default 'gpu:1'.
        job_name : str, optional
            Name of the job (suffixed with ``_{self.name}``),
            by default 'SpaceTravLR'.
        lifespan : int, optional
            Wall-time in hours, by default 3.
        python_path : str, optional
            Path to python executable, by default 'python'.
        """
        stamp = time.strftime("%Y%m%d_%H%M%S")
        outlog = f'{self.outdir}/logs/training_{stamp}.log'

        slurm = Slurm(
            cpus_per_task=1,
            partition=partition,
            clusters=clusters,
            gres=gres,
            ignore_pbs=True,
            job_name=job_name + '_' + self.name,
            output=outlog,
            time=timedelta(hours=lifespan),
        )

        slurm.sbatch(python_path + ' launch.py')

    @catch_errors
    def run_spacetravlr(
        self,
        max_epochs: int = 150,
        learning_rate: float = 5e-3,
        spatial_dim: int = 64,
        batch_size: int = 512,
        radius: int = 300,
        contact_distance: int = 50,
    ):
        """
        Trains the SpaceTravLR model to learn spatial gene regulation.

        Inputs are read from ``{outdir}/input_data``: ``_adata.h5ad``,
        ``tflinks.parquet`` and ``celloracle_links.pkl``. Results are saved
        under ``{outdir}/betadata/``.

        Parameters
        ----------
        max_epochs : int, optional
            Maximum number of training epochs, by default 150.
        learning_rate : float, optional
            Learning rate for the optimizer, by default 5e-3.
        spatial_dim : int, optional
            Dimension of the spatial embedding, by default 64.
        batch_size : int, optional
            Batch size for training, by default 512.
        radius : int, optional
            Radius for secreted signaling, by default 300.
        contact_distance : int, optional
            Distance for contact-dependent signaling, by default 50.
        """
        from .oracles import SpaceTravLR
        from .tools.network import RegulatoryFactory

        base_dir = f'{self.outdir}/betadata/'
        input_dir = f'{self.outdir}/input_data'

        adata = sc.read_h5ad(f'{input_dir}/_adata.h5ad')
        tflinks = pd.read_parquet(f'{input_dir}/tflinks.parquet')
        # pickle can execute arbitrary code: only load files this pipeline wrote.
        with open(f'{input_dir}/celloracle_links.pkl', 'rb') as f:
            links = pickle.load(f)

        co_grn = RegulatoryFactory(links=links)

        space_travlr = SpaceTravLR(
            adata=adata,
            max_epochs=max_epochs,
            learning_rate=learning_rate,
            spatial_dim=spatial_dim,
            batch_size=batch_size,
            grn=co_grn,
            radius=radius,
            contact_distance=contact_distance,
            save_dir=base_dir,
            tflinks=tflinks
        )

        space_travlr.run()

    # alias
    def fit(self, **kwargs):
        """Alias for ``run_spacetravlr``; forwards all keyword arguments."""
        return self.run_spacetravlr(**kwargs)

    def setup_perturbations(self, adata, override_params=None, subsample=None, use_float16=False):
        """
        Initializes the GeneFactory for running perturbations.

        Requires ``{outdir}/betadata/run_params.json``. The factory is stored
        on ``self.factory`` with its betas loaded.

        Parameters
        ----------
        adata : ad.AnnData
            AnnData object used for perturbation simulations.
        override_params : dict, optional
            Dictionary to override run parameters, by default None.
        subsample : int, optional
            Number of cells to subsample for faster loading, by default None.
        use_float16 : bool, optional
            Use float16 for lower memory usage, by default False.

        Raises
        ------
        AssertionError
            If ``run_params.json`` does not exist.
        """
        from .gene_factory import GeneFactory

        json_path = f'{self.outdir}/betadata/run_params.json'
        assert os.path.exists(json_path), f"run_params.json not found at {json_path}"

        self.factory = GeneFactory.from_json(
            adata=adata,
            json_path=json_path,
            override_params=override_params
        )

        self.factory.load_betas(subsample=subsample, float16=use_float16)

    def perturb(self, target, propagation=4, gene_expr=0, cells=None):
        """
        Performs in silico perturbation of a target gene.

        Delegates to ``self.factory.perturb`` and requires
        ``setup_perturbations`` to have been called first.

        Parameters
        ----------
        target : str or list
            Target gene(s) to perturb.
        propagation : int, optional
            Passed as ``n_propagation`` (number of propagation steps),
            by default 4.
        gene_expr : float or list, optional
            Target expression level (0 for knockout), by default 0.
        cells : list, optional
            Cells to apply the perturbation to (None for all cells),
            by default None.

        Returns
        -------
        pd.DataFrame
            Simulated gene expression matrix after perturbation.
        """
        return self.factory.perturb(
            target=target,
            n_propagation=propagation,
            gene_expr=gene_expr,
            cells=cells
        )

    def is_everything_ok(self) -> bool:
        """
        Checks that all necessary output files and directories exist.

        Path checks run first, so a missing file fails with a clear message
        before any loading happens. Then the saved AnnData is checked for the
        expected layers/embeddings and the links pickle is checked for
        readability.

        Returns
        -------
        bool
            True if all checks pass.

        Raises
        ------
        AssertionError
            With a message naming the first missing item.
        """
        adata_path = f'{self.outdir}/input_data/_adata.h5ad'
        links_path = f'{self.outdir}/input_data/celloracle_links.pkl'

        assert os.path.isdir(self.outdir), "Output directory not found"
        assert os.path.isdir(f'{self.outdir}/betadata'), "Betadata directory not found"
        assert os.path.isdir(f'{self.outdir}/input_data'), "Input data directory not found"
        assert os.path.isdir(f'{self.outdir}/logs'), "Logs directory not found"
        assert os.path.isfile(adata_path), "AnnData file not found"
        assert os.path.isfile(links_path), "Base links file not found"
        assert os.path.isfile('launch.py'), "Launch script not found"

        _adata = sc.read_h5ad(adata_path)
        # Loading only verifies the pickle is readable; the object is not used.
        with open(links_path, 'rb') as f:
            pickle.load(f)

        assert 'imputed_count' in _adata.layers, "Imputed count layer not found"
        assert 'X_umap' in _adata.obsm, "UMAP embedding not found"
        assert 'cell_type_int' in _adata.obs.columns, "Cell type integer column not found"
        assert 'spatial' in _adata.obsm, "Spatial coordinates not found"

        print("We're going on a trip in our favorite rocket ship 🚀️")
        return True