# Source code for SpaceTravLR.gene_factory


from collections import defaultdict
from functools import partial
import glob
import statistics
import numpy as np
import pandas as pd
from tqdm import tqdm
# import commot as ct
import gc
from SpaceTravLR.tools.network import expand_paired_interactions, get_cellchat_db
from SpaceTravLR.models.parallel_estimators import get_filtered_df, received_ligands
from SpaceTravLR.oracles import OracleQueue, BaseTravLR
from SpaceTravLR.beta import BetaFrame, Betabase
from SpaceTravLR.tools.utils import is_mouse_data
import enlighten
from pqdm.threads import pqdm
import datetime
import os
import warnings
warnings.filterwarnings('ignore')


            
class GeneFactory(BaseTravLR):
    """
    GeneFactory handles the loading of trained models (betas) and facilitates
    in silico perturbations.

    It effectively acts as a factory for generating simulated gene expression
    profiles under various perturbation conditions.

    Parameters
    ----------
    adata : ad.AnnData
        AnnData object containing the data.
    models_dir : str
        Directory where the trained models (betadata) are stored.
    annot : str, optional
        Annotation key in adata.obs, by default 'cell_type_int'.
    radius : int, optional
        Spatial radius for signaling, by default 200.
    contact_distance : int, optional
        Contact distance for signaling, by default 30.
    scale_factor : int, optional
        Scaling factor for spatial coordinates, by default 1.
    beta_scale_factor : int, optional
        Scaling factor for beta values, by default 1.
    beta_cap : float, optional
        Cap for beta values to prevent explosions in simulation,
        by default None.
    co_grn : object, optional
        CellOracle GRN object, by default None.
    """

    def __init__(
        self,
        adata,
        models_dir,
        annot='cell_type_int',
        radius=200,
        contact_distance=30,
        scale_factor=1,
        beta_scale_factor=1,
        beta_cap=None,
        co_grn=None
    ):
        """
        Set up the factory state, status bar and ligand-receptor table.

        See the class docstring for the meaning of each parameter. No betas
        are read from disk here; ``self.beta_dict`` starts as None and is
        filled by ``load_betas`` (or lazily by ``perturb``).
        """
        super().__init__(adata, fields_to_keep=[annot])

        self.adata = adata.copy()
        self.annot = annot
        self.save_dir = models_dir
        self.radius = radius
        self.contact_distance = contact_distance
        self.species = 'mouse' if is_mouse_data(adata) else 'human'
        self.scale_factor = scale_factor

        self.queue = OracleQueue(models_dir, all_genes=self.adata.var_names)
        self.ligands = []
        self.genes = list(self.adata.var_names)
        self.trained_genes = []
        self.beta_dict = None
        self._name = 'GeneFactory'
        self.beta_scale_factor = beta_scale_factor
        self.beta_cap = beta_cap

        if co_grn is not None:
            # Flatten the per-key link tables into one frame and keep, for
            # every (source, target) pair, the row with the largest coef_mean.
            flat_links = pd.concat(
                [co_grn.links[ct] for ct in co_grn.links.keys()], axis=0)
            flat_links.sort_values(
                by='coef_mean', ascending=False, inplace=True)
            flat_links.drop_duplicates(
                subset=['source', 'target'], inplace=True, keep='first')
            self.co_grn_links = flat_links
        else:
            self.co_grn_links = None

        self.manager = enlighten.get_manager()
        self._logo = '🐭️🧬️️' if self.species == 'mouse' else '🙅‍♂️🧬️️'
        self._logo = f'{self._logo} {self._name}'
        self.status = self.manager.status_bar(
            f'{self._logo}: [Ready] | {adata.shape[0]} cells / {len(self.genes)} genes',
            color='black_on_green',
            justify=enlighten.Justify.CENTER,
            auto_refresh=True,
            width=30
        )

        self.xy = pd.DataFrame(
            self.adata.obsm['spatial'],
            index=self.adata.obs_names,
            columns=['x', 'y']
        )

        # Ligand-receptor pairs come from the project's CellChat helper; only
        # pairs whose ligand AND receptor are present in adata are kept.
        df_ligrec = get_cellchat_db(self.species)
        lr = expand_paired_interactions(df_ligrec)
        present = (
            lr.ligand.isin(self.adata.var_names)
            & lr.receptor.isin(self.adata.var_names)
        )
        # Explicit copy so the column assignment below never targets a view.
        lr = lr[present].copy()
        # Secreted signaling uses the long-range radius, everything else the
        # contact distance.
        lr['radius'] = np.where(
            lr['signaling'] == 'Secreted Signaling',
            self.radius,
            self.contact_distance
        )
        self.lr = lr

    @classmethod
    def from_json(cls, adata, json_path, override_params=None,
                  beta_scale_factor=1, beta_cap=None, co_grn=None):
        """
        Creates a GeneFactory instance from a parameters JSON file.

        The JSON must provide ``save_dir``, ``annot``, ``radius`` and
        ``contact_distance``; ``scale_factor`` is optional and defaults to 1.

        Parameters
        ----------
        adata : ad.AnnData
            AnnData object.
        json_path : str
            Path to the JSON file containing run parameters.
        override_params : dict, optional
            Dictionary to override parameters from JSON, by default None.
        beta_scale_factor : int, optional
            Scaling factor for beta values, by default 1.
        beta_cap : float, optional
            Cap for beta values, by default None.
        co_grn : object, optional
            CellOracle GRN object, by default None.

        Returns
        -------
        GeneFactory
            Initialized GeneFactory instance.
        """
        import json

        with open(json_path, 'r', encoding='utf-8') as f:
            params = json.load(f)

        if override_params is not None:
            params.update(override_params)

        return cls(
            adata,
            models_dir=params['save_dir'],
            annot=params['annot'],
            radius=params['radius'],
            contact_distance=params['contact_distance'],
            scale_factor=params.get('scale_factor', 1),
            beta_scale_factor=beta_scale_factor,
            beta_cap=beta_cap,
            co_grn=co_grn
        )

    ## backwards compatibility
    def compute_betas(self, **kwargs):
        """Deprecated alias: forwards all keyword arguments to ``load_betas``."""
        self.load_betas(**kwargs)

    def load_betas(self, subsample=None, float16=False, obs_names=None):
        """
        Loads the spatial gene regulatory coefficients (betas) from disk.

        On success ``self.beta_dict`` holds the loaded betas and
        ``self.obs_names`` the cells they were loaded for.

        Parameters
        ----------
        subsample : int, optional
            Number of cells to subsample, by default None.
        float16 : bool, optional
            Use float16 precision to save memory, by default False.
        obs_names : list, optional
            List of cell names to load betas for, by default None
            (all cells in ``self.adata``).
        """
        # Drop the previous betas before loading new ones. The attribute is
        # reset rather than deleted so that a failed load still leaves
        # `self.beta_dict is None` checks working.
        self.beta_dict = None

        obs_names = obs_names if obs_names is not None else self.adata.obs_names

        self.status.update(
            '💾️ Loading betas from disk' + f' {len(obs_names)} cells')
        self.status.color = 'black_on_salmon'
        self.status.refresh()

        self.beta_dict = self._get_spatial_betas_dict(
            subsample=subsample,
            float16=float16,
            obs_names=obs_names
        )
        self.obs_names = obs_names

        self.status.update('Loading betas - Done')
        self.status.color = 'black_on_green'
        self.status.refresh()

    @staticmethod
    def load_betadata(gene, save_dir, obs_names=None):
        """
        Read the betadata of a single gene.

        Parameters
        ----------
        gene : str
            Gene whose ``{gene}_betadata.parquet`` file is read.
        save_dir : str
            Directory containing the betadata parquet files.
        obs_names : list, optional
            Forwarded to ``BetaFrame.from_path``, by default None.

        Returns
        -------
        BetaFrame
            Whatever ``BetaFrame.from_path`` returns for that file.
        """
        return BetaFrame.from_path(
            f'{save_dir}/{gene}_betadata.parquet', obs_names=obs_names)

    def _compute_weighted_ligands(self, gene_mtx, cell_thresholds, genes):
        """
        Compute received-ligand values for ``genes`` from an expression matrix.

        ``gene_mtx`` is interpreted with rows ``self.obs_names`` and columns
        ``self.adata.var_names``. When ``genes`` is empty, a column-less frame
        indexed by ``self.obs_names`` is returned instead of calling
        ``received_ligands``.
        """
        self.update_status(
            f'{self.current_target} >> Computing received ligands',
            color='black_on_cyan')

        gex_df = pd.DataFrame(
            gene_mtx,
            index=self.obs_names,
            columns=self.adata.var_names
        )

        if len(genes) > 0:
            weighted_ligands = received_ligands(
                xy=self.adata[self.obs_names].obsm['spatial'],
                ligands_df=get_filtered_df(gex_df, cell_thresholds, genes),
                lr_info=self.lr,
                scale_factor=self.scale_factor
            )
        else:
            weighted_ligands = pd.DataFrame(index=self.obs_names)

        return weighted_ligands

    def update_status(self, msg='', color='black_on_green'):
        """Set the status-bar text and color, then redraw it."""
        self.status.update(msg)
        self.status.color = color
        self.status.refresh()

    def _get_wbetas_dict(
            self, betas_dict, weighted_ligands, weighted_ligands_tfl,
            gene_mtx, cell_thresholds):
        """
        Splash every gene's betadata against the current simulation state.

        Returns
        -------
        dict
            Maps each gene in ``betas_dict.data`` to the result of
            ``_combine_gene_wbetas`` for that gene.
        """
        # mask out receptors too
        gex_df = get_filtered_df(
            counts_df=pd.DataFrame(
                gene_mtx,
                index=self.obs_names,
                columns=self.adata.var_names
            ),
            cell_thresholds=cell_thresholds,
            genes=self.adata.var_names
        )[self.adata.var_names]

        self.update_status(
            f'[{self.iter}/{self.max_iter}] | Computing Ligand interactions',
            color='black_on_salmon')

        out_dict = {}
        for i, (gene, betadata) in enumerate(betas_dict.data.items()):
            if self.co_grn_links is not None:
                # NOTE(review): this selects the 'target' values of rows whose
                # 'source' equals `gene`. If 'source' holds the regulator, a
                # variable named grn_tfs would be expected to use the opposite
                # direction - confirm against BetaFrame.splash.
                grn_tfs = self.co_grn_links.loc[
                    self.co_grn_links['source'] == gene, 'target'].values
            else:
                grn_tfs = None

            out_dict[gene] = self._combine_gene_wbetas(
                weighted_ligands, weighted_ligands_tfl, gex_df, betadata,
                grn_tfs=grn_tfs)

            # Refresh the status bar only every 250 genes.
            if i % 250 == 0:
                self.update_status(
                    f'{self.current_target} | {i}/{len(betas_dict.data)} | [{self.iter}/{self.max_iter}] | Computing Ligand interactions',
                    color='black_on_salmon')

        self.update_status(f'Ligand interactions - Done')

        return out_dict

    def _combine_gene_wbetas(self, rw_ligands, rw_ligands_tfl, filtered_df,
                             betadata, grn_tfs=None):
        """
        Call ``betadata.splash`` with this factory's scale factor and cap.

        Returns whatever ``splash`` returns for the given received ligands,
        filtered expression frame and optional ``grn_tfs``.
        """
        betas_df = betadata.splash(
            rw_ligands,
            rw_ligands_tfl,
            filtered_df,
            scale_factor=self.beta_scale_factor,
            beta_cap=self.beta_cap,
            grn_tfs=grn_tfs
        )
        return betas_df

    def _get_spatial_betas_dict(self, subsample=None, float16=False,
                                obs_names=None, randomize=False):
        """
        Build a ``Betabase`` from ``self.save_dir`` and record its ligand sets.

        Side effect: ``self.ligands`` and ``self.tfl_ligands`` are replaced by
        lists made from the Betabase's ``ligands_set`` / ``tfl_ligands_set``.
        """
        bdb = Betabase(
            self.adata,
            self.save_dir,
            subsample=subsample,
            float16=float16,
            obs_names=obs_names,
            randomize=randomize
        )
        self.ligands = list(bdb.ligands_set)
        self.tfl_ligands = list(bdb.tfl_ligands_set)
        return bdb

    def splash_betas(self, gene, obs_names=None):
        """
        Computes the derivatives by splitting up ligand terms into individual
        gene components. This essentially converts betadata of
        cell x modulators into cell x genes.

        Reads ``received_ligands``, ``received_ligands_tfl`` and
        ``cell_thresholds`` from ``adata.uns``; ``perturb`` stores the two
        received-ligand entries there when they are missing.

        Parameters
        ----------
        gene : str
            The gene to compute derivatives for.
        obs_names : list, optional
            List of cell names to compute derivatives for, by default all.

        Returns
        -------
        pd.DataFrame
            DataFrame with derivatives for each cell at each location
        """
        assert gene in self.adata.var_names

        if obs_names is None:
            obs_names = self.adata.obs_names

        rw_ligands = self.adata.uns.get('received_ligands').loc[obs_names]
        rw_tfligands = self.adata.uns.get('received_ligands_tfl').loc[obs_names]
        gene_mtx = self.adata.to_df(layer='imputed_count').loc[obs_names].values
        cell_thresholds = self.adata.uns.get('cell_thresholds').loc[obs_names]

        filtered_df = get_filtered_df(
            counts_df=pd.DataFrame(
                gene_mtx,
                index=obs_names,
                columns=self.adata.var_names
            ),
            cell_thresholds=cell_thresholds,
            genes=self.adata.var_names
        )[self.adata.var_names].loc[obs_names]

        betadata = self.load_betadata(gene, self.save_dir, obs_names=obs_names)

        # NOTE(review): unlike the perturb path, no grn_tfs is passed here
        # even when a co_grn was supplied - confirm this is intended.
        return self._combine_gene_wbetas(
            rw_ligands, rw_tfligands, filtered_df, betadata)

    def _perturb_all_cells(self, gex_delta, betas_dict):
        """
        Propagate expression deltas through the splashed betas for one step.

        For every gene that has an entry in ``betas_dict``, its output column
        is the row-wise sum of (splashed betas * deltas of that gene's
        modulator columns). Genes without an entry stay at zero.

        Parameters
        ----------
        gex_delta : np.ndarray
            2-D array of expression deltas; columns are indexed with the
            ``modulator_gene_indices`` stored in ``self.beta_dict``.
        betas_dict : dict
            Gene -> splashed betas, as produced by ``_get_wbetas_dict``.

        Returns
        -------
        np.ndarray
            Array with the same shape as ``gex_delta``.
        """
        n_obs, n_genes = gex_delta.shape
        result = np.zeros((n_obs, n_genes))
        n_vars = len(self.adata.var_names)

        for i, gene in enumerate(self.adata.var_names):
            if i % 250 == 0:
                self.update_status(
                    f'[{self.iter}/{self.max_iter}] | Perturbing 🧬️🐝️ {i+1}/{n_vars} ',
                    color='black_on_cyan'
                )

            _beta_out = betas_dict.get(gene, None)
            if _beta_out is not None:
                mod_idx = self.beta_dict.data[gene].modulator_gene_indices
                result[:, i] = np.sum(
                    _beta_out.values * gex_delta[:, mod_idx], axis=1)

        assert not np.isnan(result).any(), "NaN values found in delta_simulated"
        return result

    def perturb(
        self,
        target,
        n_propagation=4,
        gene_expr=0,
        cells=None,
        save_layer=False,
        delta_dir=None,
    ):
        """
        Simulates perturbation of a target gene and propagates the effect.

        Parameters
        ----------
        target : str or list
            Target gene(s) to perturb. When a list is given, ``gene_expr``
            must be a list of the same length.
        n_propagation : int, optional
            Number of propagation steps, by default 4.
        gene_expr : float or list, optional
            Expression level of the target gene (0 for knockout),
            by default 0. NumPy scalars are accepted as well.
        cells : list, optional
            List of cell indices to apply perturbation to, by default None
            (all cells).
        save_layer : bool, optional
            Whether to save the result as a layer in adata, by default False.
        delta_dir : str, optional
            Directory to save the intermediate expression matrix of every
            propagation step, by default None.

        Returns
        -------
        pd.DataFrame
            DataFrame containing the simulated gene expression.

        Raises
        ------
        ValueError
            If ``target`` / ``gene_expr`` are neither (str, scalar) nor
            (list, list).
        AssertionError
            If a target is not in ``adata.var_names``, an expression value is
            negative, or the two lists differ in length.
        """
        # np.float32 etc. are not subclasses of float, so list them explicitly.
        scalar_types = (int, float, np.integer, np.floating)

        if isinstance(target, str):
            assert isinstance(gene_expr, scalar_types)
            assert target in self.adata.var_names
            payload_dict = {target: gene_expr}
            output_name = f'{target}_{n_propagation}n_{round(gene_expr, 2)}x'
        elif isinstance(target, list) and isinstance(gene_expr, list):
            assert len(target) == len(gene_expr)
            payload_dict = dict(zip(target, gene_expr))
            output_name = '_'.join(
                f'{t}_{n_propagation}n_{round(g, 2)}x'
                for t, g in zip(target, gene_expr))
        else:
            raise ValueError(f'Invalid target info')

        # Label used in status messages. It is built from the whole payload
        # so that multi-gene perturbations are not reported as their last
        # gene only; for a single target it reads '<target> -> <gene_expr>'.
        status_label = ', '.join(
            f'{name} -> {expr}' for name, expr in payload_dict.items())

        self.current_target = output_name
        obs = self.obs_names
        base_df = self.adata.to_df(layer='imputed_count')
        gene_mtx = base_df.loc[obs].values
        self.payload_dict = payload_dict

        # Write the requested expression values into a copy of the baseline.
        simulation_input = gene_mtx.copy()
        for target_name, target_expr in payload_dict.items():
            assert target_expr >= 0
            assert target_name in self.adata.var_names
            target_index = self.gene2index[target_name]
            if cells is None:
                simulation_input[:, target_index] = target_expr
            else:
                # cells is a list of cell indices
                simulation_input[cells, target_index] = target_expr

        delta_input = simulation_input - gene_mtx
        delta_simulated = delta_input.copy()

        if self.beta_dict is None:
            self.beta_dict = self._get_spatial_betas_dict(
                obs_names=self.obs_names)

        # get LR specific filtered gex contributions
        cell_thresholds = self.adata.uns.get('cell_thresholds')
        if cell_thresholds is not None:
            cell_thresholds = cell_thresholds.loc[obs].reindex(
                index=obs, columns=self.adata.var_names, fill_value=1)
            self.adata.uns['cell_thresholds'] = cell_thresholds

        # Baseline received ligands are cached in adata.uns and only computed
        # when either entry is missing.
        rw_ligands_0 = self.adata.uns.get('received_ligands')
        rw_tfligands_0 = self.adata.uns.get('received_ligands_tfl')
        if rw_ligands_0 is None or rw_tfligands_0 is None:
            rw_ligands_0 = self._compute_weighted_ligands(
                gene_mtx, cell_thresholds, genes=self.ligands)
            rw_tfligands_0 = self._compute_weighted_ligands(
                gene_mtx, cell_thresholds=None, genes=self.tfl_ligands)
            self.adata.uns['received_ligands'] = rw_ligands_0
            self.adata.uns['received_ligands_tfl'] = rw_tfligands_0

        all_ligands = list(set(self.ligands) | set(self.tfl_ligands))
        # Baseline ligand expression laid out on the full gene axis
        # (non-ligand columns are zero).
        ligands_0 = base_df[all_ligands].reindex(
            index=obs,
            columns=self.adata.var_names,
            fill_value=0
        )
        del base_df

        # copy the original values
        gene_mtx_1 = gene_mtx.copy()
        rw_ligands_1 = rw_ligands_0.copy()
        rw_tfligands_1 = rw_tfligands_0.copy()

        # get the max weighted ligand expression
        # (could be zeroed out in rw_ligands_0)
        w0 = rw_ligands_0.reindex(
            columns=self.adata.var_names, fill_value=0).values
        w0_tfl = rw_tfligands_0.reindex(
            columns=self.adata.var_names, fill_value=0).values
        rw_ligands_0 = pd.DataFrame(
            np.maximum(w0, w0_tfl),
            index=obs,
            columns=self.adata.var_names
        )

        self.iter = 0
        self.max_iter = n_propagation

        # Simulated expression is clipped to [0, observed per-gene maximum].
        min_ = 0.0
        max_ = gene_mtx.max(axis=0)

        ## refer: src/celloracle/trajectory/oracle_GRN.py
        for n in range(n_propagation):
            self.iter += 1
            self.update_status(
                f'{status_label} - {n+1}/{n_propagation}',
                color='black_on_salmon')

            # weight betas by the gene expression from the previous iteration
            # NOTE(review): rw_tfligands_1 is never refreshed inside this
            # loop, so splash always receives the baseline tfl ligands -
            # confirm this is intended.
            splashed_beta_dict = self._get_wbetas_dict(
                self.beta_dict, rw_ligands_1, rw_tfligands_1,
                gene_mtx_1, cell_thresholds)

            # get updated gene expressions
            gene_mtx_1 = gene_mtx + delta_simulated
            w_ligands_1 = self._compute_weighted_ligands(
                gene_mtx_1, cell_thresholds, genes=self.ligands)
            w_tfligands_1 = self._compute_weighted_ligands(
                gene_mtx_1, cell_thresholds=None, genes=self.tfl_ligands)

            # update deltas to reflect change in received ligands
            # we consider dy/dwL: we replace delta l with delta wL in
            # delta_simulated
            w1 = w_ligands_1.reindex(
                columns=self.adata.var_names, fill_value=0).values
            w1_tfl = w_tfligands_1.reindex(
                columns=self.adata.var_names, fill_value=0).values
            rw_ligands_1 = pd.DataFrame(
                np.maximum(w1, w1_tfl),
                index=self.obs_names,
                columns=self.adata.var_names
            )

            delta_rw_ligands = rw_ligands_1.values - rw_ligands_0.values

            # get the change in ligand expression within the gene_df that
            # should be replaced with rw_ligand
            gene_df_1 = pd.DataFrame(
                gene_mtx_1,
                columns=self.adata.var_names,
                index=obs
            )
            ligands_1 = gene_df_1[all_ligands].reindex(
                index=obs,
                columns=self.adata.var_names,
                fill_value=0
            )
            delta_ligands = ligands_1.values - ligands_0.values

            # the model sees delta wL, not delta L
            # delta_simulated contains delta L, so remove and replace with wL
            delta_simulated = delta_simulated + delta_rw_ligands - delta_ligands

            _simulated = self._perturb_all_cells(
                delta_simulated, splashed_beta_dict)
            delta_simulated = np.array(_simulated)

            # ensure values in delta_simulated match our desired KO / input
            delta_simulated = np.where(
                delta_input != 0, delta_input, delta_simulated)

            # Don't allow simulated to exceed observed values
            gem_tmp = gene_mtx + delta_simulated
            gem_tmp = np.clip(gem_tmp, a_min=min_, a_max=max_)

            # update delta_simulated in case of negative values
            delta_simulated = gem_tmp - gene_mtx

            if delta_dir:
                os.makedirs(delta_dir, exist_ok=True)
                # One '<gene>_<n>n_<expr>x' token per payload entry; for a
                # single target this is '<target>_<n>n_<gene_expr>x'.
                delta_stem = '_'.join(
                    f'{name}_{n}n_{expr}x'
                    for name, expr in payload_dict.items())
                np.save(
                    f'{delta_dir}/{delta_stem}.npy',
                    gene_mtx + delta_simulated
                )

            # Release the splashed betas before the next iteration rebuilds
            # them.
            del splashed_beta_dict

        gem_simulated = gene_mtx + delta_simulated
        assert gem_simulated.shape == gene_mtx.shape

        # Re-impose the requested expression on the final matrix.
        for target_name, target_gene_expr in payload_dict.items():
            target_index = self.gene2index[target_name]
            if cells is None:
                gem_simulated[:, target_index] = target_gene_expr
            else:
                gem_simulated[cells, target_index] = target_gene_expr

            # The "Done" message is only drawn for every fifth gene index.
            if target_index % 5 == 0:
                self.update_status(
                    f'{target_name} -> {target_gene_expr} - {n_propagation}/{n_propagation} - Done')

        if save_layer:
            self.adata.layers[output_name] = gem_simulated

        gex_out = pd.DataFrame(
            gem_simulated, index=obs, columns=self.adata.var_names)
        gex_out.index.name = output_name
        return gex_out

    @staticmethod
    def get_ko_data(perturb_dir, adata, n_propagation=4, suffix='0x'):
        """
        Summarize saved perturbation results per cell type.

        For every ``{gene}_{n_propagation}n_{suffix}.parquet`` file in
        ``perturb_dir`` the difference to ``adata``'s ``imputed_count`` layer
        is averaged within each ``adata.obs.cell_type`` group, the absolute
        value is taken and the result is averaged over genes.

        Parameters
        ----------
        perturb_dir : str
            Directory containing the perturbation parquet files.
        adata : ad.AnnData
            AnnData providing the baseline ``imputed_count`` layer and the
            ``cell_type`` annotation.
        n_propagation : int, optional
            Propagation count encoded in the file names, by default 4.
        suffix : str, optional
            Expression suffix encoded in the file names, by default '0x'.

        Returns
        -------
        pd.DataFrame
            One column per perturbed gene, indexed by cell type. Empty when
            no matching file exists.
        """
        tail = f'_{n_propagation}n_{suffix}.parquet'
        # Strip the exact tail instead of splitting on '_', so gene names
        # containing underscores are kept intact and each gene appears once.
        kotargets = [
            os.path.basename(path)[:-len(tail)]
            for path in glob.glob(f'{perturb_dir}/*{tail}')
        ]
        if not kotargets:
            return pd.DataFrame()

        # The baseline does not change between files; build it once.
        baseline = adata.to_df(layer='imputed_count')

        ko_data = []
        pbar = enlighten.get_manager().counter(
            total=len(kotargets),
            desc='Getting KO data',
            unit='genes',
            color='orange',
            autorefresh=True,
        )

        for kotarget in kotargets:
            pbar.desc = f'Getting KO data - {kotarget}'
            pbar.refresh()

            data = pd.read_parquet(f'{perturb_dir}/{kotarget}{tail}')
            data = data.loc[adata.obs_names] - baseline
            data = (
                data.join(adata.obs.cell_type)
                .groupby('cell_type')
                .mean()
                .abs()
                .mean(axis=1)
            )

            ranked = data.sort_values(ascending=False).to_dict()
            data = pd.DataFrame.from_dict(ranked, orient='index')
            data.columns = [kotarget]
            ko_data.append(data)
            pbar.update()

        pbar.close()
        return pd.concat(ko_data, axis=1)

    def perturb_batch(
        self,
        target_genes,
        save_to=None,
        n_propagation=4,
        gene_expr=0,
        cells=None
    ):
        """
        Runs perturbations for a batch of target genes.

        Parameters
        ----------
        target_genes : list
            List of genes to perturb.
        save_to : str, optional
            Directory to save results, by default None. When None, the
            perturbations still run but nothing is written.
        n_propagation : int, optional
            Number of propagation steps, by default 4.
        gene_expr : float, optional
            Target expression level, by default 0.
        cells : list, optional
            List of cells to apply perturbation to, by default None.
        """
        self.update_status(
            f'Batch Perturbation mode: {len(target_genes)} genes')

        progress_bar = self.manager.counter(
            total=len(target_genes),
            desc=f'Batch Perturbations',
            unit='genes',
            color='orange',
            autorefresh=True,
        )

        # save_to is optional: os.makedirs(None) would raise TypeError.
        if save_to is not None:
            os.makedirs(save_to, exist_ok=True)

        for target in target_genes:
            progress_bar.desc = f'Batch Perturbation - {target}'
            progress_bar.refresh()

            gex_out = self.perturb(
                target=target,
                n_propagation=n_propagation,
                save_layer=False,
                gene_expr=gene_expr,
                cells=cells,
            )
            progress_bar.update()

            if save_to is not None:
                file_name = f'{target}_{n_propagation}n_{gene_expr}x'
                gex_out.to_parquet(f'{save_to}/{file_name}.parquet')

        self.update_status('Batch Perturbation: Done')
        progress_bar.close()

    @property
    def possible_targets(self):
        """
        list: union of the receptor, ligand and TF sets of the loaded betas.

        Requires ``self.beta_dict`` to be loaded.
        """
        return list(set.union(
            self.beta_dict.receptors_set,
            self.beta_dict.ligands_set,
            self.beta_dict.tfs_set
        ))

    def genome_screen(
        self,
        save_to,
        n_propagation=4,
        priority_genes=None,
        mode='knockout',
        cells=None
    ):
        """
        Perform a genome-wide perturbation screen (knockout or overexpression).

        Iterates through all possible targets (TFs, ligands, receptors) and
        performs the specified perturbation, saving the results to disk as
        ``{target}_{n_propagation}n_0x.parquet`` (knockout) or
        ``{target}_{n_propagation}n_maxx.parquet`` (overexpress). Betas are
        loaded first if they have not been loaded yet.

        Parameters
        ----------
        save_to : str
            Directory to save the results.
        n_propagation : int, optional
            Number of propagation steps, by default 4.
        priority_genes : list, optional
            List of genes to prioritize in the screen, by default None.
        mode : str, optional
            'knockout' or 'overexpress', by default 'knockout'. Overexpression
            sets the target to its maximum observed ``imputed_count`` value.
        cells : list, optional
            List of cell indices to restrict perturbation to, by default None.
        """
        assert mode in ['knockout', 'overexpress']

        # possible_targets reads from the loaded betas.
        if self.beta_dict is None:
            self.load_betas()

        # Compute the union once instead of on every property access.
        targets = self.possible_targets

        if priority_genes is not None:
            priority_genes = list(np.intersect1d(priority_genes, targets))

        screen_queue = OracleQueue(
            save_to,
            all_genes=targets,
            priority_genes=priority_genes,
            lock_timeout=3600
        )

        _manager = enlighten.get_manager()
        gene_bar = _manager.counter(
            total=len(screen_queue.all_genes),
            desc='... initializing ...',
            unit='genes',
            color='orange',
            autorefresh=True,
        )

        screen_queue.kill_old_locks()
        max_expr = self.adata.to_df(layer='imputed_count').max().to_dict()
        suffix = '0x' if mode == 'knockout' else 'maxx'

        while not screen_queue.is_empty:
            target = next(screen_queue)

            gene_bar.count = (
                len(screen_queue.all_genes) - len(screen_queue.remaining_genes))
            gene_bar.desc = f'🕵️️ {screen_queue.agents+1} agents'
            gene_bar.refresh()

            if os.path.exists(f'{screen_queue.model_dir}/{target}.lock'):
                print(f'Found duplicate lock for {target} - skipping')
                continue

            screen_queue.create_lock(target)
            try:
                gex_out = self.perturb(
                    target=target,
                    n_propagation=n_propagation,
                    gene_expr=0 if mode == 'knockout' else max_expr[target],
                    cells=cells,
                    delta_dir=None
                )
                # Write the result while the lock is still held, so the gene
                # is never simultaneously unlocked and without an output file.
                file_name = f'{target}_{n_propagation}n_{suffix}'
                gex_out.to_parquet(f'{save_to}/{file_name}.parquet')
            finally:
                # Always release the lock, even if the perturbation failed.
                screen_queue.delete_lock(target)

            if screen_queue.last_refresh_age() > screen_queue.lock_timeout:
                screen_queue.kill_old_locks()
                screen_queue.last_refresh_on = datetime.datetime.now()

            gene_bar.update()