Source code for SpaceTravLR.oracles

from abc import ABC
from easydict import EasyDict
import numpy as np
import enlighten
import time
import pandas as pd
import pickle
import torch
from tqdm import tqdm
import os
import datetime
import re
import glob
import pickle
import io
import json
import warnings
from sklearn.decomposition import PCA
import warnings
from sklearn.linear_model import Ridge
from sklearn.neighbors import NearestNeighbors

from .tools.network import DayThreeRegulatoryNetwork
from .tools.knn_smooth import knn_smoothing
from .tools.utils import deprecated
from .models.spatial_map import xyc2spatial, xyc2spatial_fast
from .models.parallel_estimators import SpatialCellularProgramsEstimator

from .tools.utils import (
    clean_up_adata,
    knn_distance_matrix,
    _adata_to_matrix,
    connectivity_to_weights,
    convolve_by_sparse_weights,
    prune_neighbors
)

import warnings
warnings.filterwarnings("ignore")


class CPU_Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        else:
            return super().find_class(module, name)

class BaseTravLR(ABC):
    
    def __init__(self, adata, fields_to_keep=['cell_type', 'cell_type_int', 'cell_thresholds']):
        assert 'normalized_count' in adata.layers
        
        self.settings = EasyDict()
        
        self.adata = adata.copy()
        # self.adata.layers['normalized_count'] = self.adata.X.copy()
        self.gene2index = dict(zip(self.adata.var_names, range(len(self.adata.var_names))))
        self.pcs = None
        
        if 'imputed_count' not in self.adata.layers:
            self.pcs = self.perform_PCA(self.adata)
            self.knn_imputation(self.adata, self.pcs, method='MAGIC')

        clean_up_adata(self.adata, fields_to_keep=fields_to_keep)

    ## cannibalized from CellOracle
    @staticmethod
    def perform_PCA(adata, n_components=None, div_by_std=False):
        X = _adata_to_matrix(adata, "normalized_count")

        pca = PCA(n_components=n_components)
        if div_by_std:
            pcs = pca.fit_transform(X.T / X.std(0))
        else:
            pcs = pca.fit_transform(X.T)
        
        n_comps = np.where(np.diff(np.diff(np.cumsum(pca.explained_variance_ratio_))>0.002))[0][0]

        return pcs[:, :n_comps]
    
    
    @staticmethod
    def impute_clusterwise(adata, annot='cell_type', layer='normalized_count', layer_added='imputed_count'):
        import magic
        import warnings
        import enlighten
        warnings.filterwarnings("ignore")
        
        X = _adata_to_matrix(adata, layer)
        X = X.T
        X = pd.DataFrame(X, columns=adata.var_names, index=adata.obs_names)

        X_magic_list = []
        pbar = enlighten.get_manager().counter(
            total=len(adata.obs[annot].unique()),
            desc='Imputing clusterwise',
            unit='clusters',
            color='green',
            auto_refresh=True
        )

        for cell_type in adata.obs[annot].unique():
            magic_operator = magic.MAGIC(verbose=0)
            
            mask = adata.obs[annot] == cell_type
            X_subset = X.loc[mask]
            X_magic_subset = magic_operator.fit_transform(X_subset, genes='all_genes')
            X_magic_list.append(X_magic_subset)
            pbar.update()
            
        X_magic = pd.concat(X_magic_list)
        X_magic = X_magic.loc[adata.obs_names]
        
        adata.layers[layer_added] = X_magic.values

    @staticmethod
    @deprecated(instructions="Use impute_clusterwise instead")
    def knn_imputation(adata, pcs, k=None, metric="euclidean", diag=1,
                       n_pca_dims=50, maximum=False,
                       balanced=True, b_sight=None, b_maxl=None,
                       method='MAGIC', n_jobs=8) -> None:
        
        supported_methods = ['CellOracle', 'MAGIC', 'knn-smoothing']
        assert method in supported_methods, f'method is not implemented, choose from {supported_methods}'
        
        X = _adata_to_matrix(adata, "normalized_count")

        N = adata.shape[0] # cell number

        if k is None:
            k = int(N * 0.025)
        if b_sight is None and balanced:
            b_sight = int(k * 8)
        if b_maxl is None and balanced:
            b_maxl = int(k * 4)

        n_pca_dims = min(n_pca_dims, pcs.shape[1])
        space = pcs[:, :n_pca_dims]

        if method == 'CellOracle':
            if balanced:
                nn = NearestNeighbors(n_neighbors=b_sight + 1, metric=metric, n_jobs=n_jobs, leaf_size=30)
                nn.fit(space)

                dist, dsi = nn.kneighbors(space, return_distance=True)
                knn = prune_neighbors(dsi, dist, b_maxl)
            
            else:
                knn = knn_distance_matrix(space, metric=metric, k=k,
                                                mode="distance", n_jobs=n_jobs)
            
            connectivity = (knn > 0).astype(float)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                connectivity.setdiag(diag)
            knn_smoothing_w = connectivity_to_weights(connectivity)

            Xx = convolve_by_sparse_weights(X, knn_smoothing_w)
            adata.layers["imputed_count"] = Xx.transpose().copy()
            
        elif method == 'MAGIC':
            import magic
            
            X = X.T
            magic_operator = magic.MAGIC()
            X = pd.DataFrame(X, columns=adata.var_names, index=adata.obs_names)
            X_magic = magic_operator.fit_transform(X, genes='all_genes')

            adata.layers['imputed_count'] = X_magic
        
        elif method == 'knn-smoothing':

            d = 10          # n pcs default 10
            dither = 0.03   # default 0.03 
            k = 32          # number of neighbors 

            matrix = adata.layers['raw_count'].T 
            S = knn_smoothing(matrix, k, d=d, dither=dither, seed=1334)

            adata.layers['imputed_count'] = S.T

            

[docs] class OracleQueue: """ A jobs manager for training gene models in parallel Ideal for HPC environments """
[docs] def __init__(self, model_dir, all_genes, priority_genes=None,lock_timeout=3600): """ Args: model_dir (str): Directory to store the trained model weights. all_genes (list): List of all genes to train models for. priority_genes (list, optional): List of genes to train models for first. lock_timeout (int, optional): Timeout for job locks in seconds. """ if not os.path.exists(model_dir): os.makedirs(model_dir) self.model_dir = model_dir self.all_genes = all_genes self.orphans = [] self.lock_timeout = lock_timeout self.priority_genes = priority_genes self.created_on = datetime.datetime.now() self.last_refresh_on = datetime.datetime.now()
@property def age(self): """Return the age of the queue in seconds.""" return (datetime.datetime.now() - self.created_on).total_seconds()
[docs] def last_refresh_age(self): """Return the age of the last refresh in seconds.""" return (datetime.datetime.now() - self.last_refresh_on).total_seconds()
@property def regulated_genes(self): if not self.orphans: return self.all_genes return list(set(self.all_genes).difference(set(self.orphans))) def __getitem__(self, index): return self.remaining_genes[index] def __iter__(self): return self def __next__(self): """ Get a gene to train for """ if self.is_empty: raise StopIteration if self.priority_genes is not None: self.priority_genes = np.intersect1d(self.priority_genes, self.remaining_genes) if len(self.priority_genes) > 0: return self.priority_genes[0] return np.random.choice(self.remaining_genes) def __len__(self): return len(self.remaining_genes) @property def is_empty(self): return self.__len__() == 0 @property def completed_genes(self): completed_paths = glob.glob(f'{self.model_dir}/*.parquet') return list(filter(None, map(self.extract_gene_name, completed_paths))) @property def num_orphans(self): return len(glob.glob(f'{self.model_dir}/*.orphan')) @property def agents(self): return len(glob.glob(f'{self.model_dir}/*.lock')) @property def remaining_genes(self): completed_paths = glob.glob(f'{self.model_dir}/*.parquet') locked_paths = glob.glob(f'{self.model_dir}/*.lock') orphan_paths = glob.glob(f'{self.model_dir}/*.orphan') completed_genes = list(filter(None, map(self.extract_gene_name, completed_paths))) locked_genes = list(filter(None, map(self.extract_gene_name, locked_paths))) orphan_genes = list(filter(None, map(self.extract_gene_name, orphan_paths))) return list(set(self.regulated_genes).difference( set(completed_genes+locked_genes+orphan_genes)))
[docs] def create_lock(self, gene): """ Create a lock for a gene while the model is being trained. This prevents multiple processes from training the same gene. """ now = str(datetime.datetime.now()) pid = os.getpid() with open(f'{self.model_dir}/{gene}.lock', 'w') as f: f.write(f'{now} {pid}')
[docs] def delete_lock(self, gene): try: assert os.path.exists(f'{self.model_dir}/{gene}.lock') os.remove(f'{self.model_dir}/{gene}.lock') except Exception as e: print(f'Error deleting lock for {gene}: {e}')
[docs] def kill_old_locks(self): locked_paths = glob.glob(f'{self.model_dir}/*.lock') for path in locked_paths: try: with open(path, 'r') as f: data = f.read() lock = datetime.datetime.strptime( ' '.join(data.split()[:2]), "%Y-%m-%d %H:%M:%S.%f") if (datetime.datetime.now() - lock).total_seconds() > self.lock_timeout: gene = self.extract_gene_name(path) self.delete_lock(gene) print(f'Deleted lock for {gene} after {self.lock_timeout} seconds') except Exception as e: print(f'Error deleting lock for {path}: {e}')
[docs] def add_orphan(self, gene): """ Creates a .orphan file to mark a gene as untrainable. Used when a gene-gene network is too sparse to train a model. For example when a gene has no known TFs. """ now = str(datetime.datetime.now()) pid = os.getpid() with open(f'{self.model_dir}/{gene}.orphan', 'w') as f: f.write(f'{now} {pid}') self.orphans.append(gene)
[docs] @staticmethod def extract_gene_name(path): patterns = { 'betadata': r'([^/]+)_betadata\.parquet$', 'lock': r'([^/]+)\.lock$', 'orphan': r'([^/]+)\.orphan$', 'perturbed': r'([^/]+)_\d+n_[\d\.]+x\.parquet$', 'maxx': r'([^/]+)_\d+n_maxx\.parquet$' } for pattern in patterns.values(): match = re.search(pattern, path) if match: return match.group(1) return None
def __str__(self): return f'OracleQueue with {len(self.remaining_genes)} remaining genes' def __repr__(self): return self.__str__()
[docs] class SpaceTravLR(BaseTravLR):
[docs] def __init__( self, adata, save_dir='./models', annot='cell_type_int', grn=None, max_epochs=50, spatial_dim=64, learning_rate=5e-3, batch_size=512, rotate_maps=True, layer='imputed_count', alpha=0.05, threshold_lambda=1e-6, tflinks=None, tf_ligand_cutoff=0.01, radius=200, contact_distance=30, skip_clusters=None, scale_factor=1): super().__init__(adata, fields_to_keep=[annot, 'cell_thresholds']) if grn is None: self.grn = DayThreeRegulatoryNetwork() # CellOracle GRN else: self.grn = grn self.save_dir = save_dir self.queue = OracleQueue(save_dir, all_genes=self.adata.var_names) self.annot = annot self.max_epochs = max_epochs self.spatial_dim = spatial_dim self.learning_rate = learning_rate self.batch_size = batch_size self.rotate_maps = rotate_maps self.layer = layer self.alpha = alpha self.threshold_lambda = threshold_lambda self.tf_ligand_cutoff = tf_ligand_cutoff self.beta_dict = None self.coef_matrix = None self.radius = radius self.contact_distance = contact_distance self.scale_factor = scale_factor self.estimator_models = {} self.ligands = set() self.genes = list(self.adata.var_names) self.trained_genes = [] self.skip_clusters = skip_clusters if not os.path.exists(self.save_dir+'/run_params.json'): with open(self.save_dir+'/run_params.json', 'w') as f: json.dump({ 'timestamp': str(datetime.datetime.now()), 'max_epochs': max_epochs, 'spatial_dim': spatial_dim, 'learning_rate': learning_rate, 'batch_size': batch_size, 'rotate_maps': rotate_maps, 'threshold_lambda': threshold_lambda, 'tf_ligand_cutoff': tf_ligand_cutoff, 'radius': radius, 'contact_distance': contact_distance, 'annot': annot, 'layer': layer, 'save_dir': save_dir, 'n_genes': len(self.genes), 'scale_factor': scale_factor }, f, indent=4)
[docs] def run(self): _manager = enlighten.get_manager() gene_bar = _manager.counter( total=len(self.queue.all_genes), desc=f'... initializing ...', unit='genes', color='green', autorefresh=True, ) train_bar = _manager.counter( total=self.adata.shape[0]*self.max_epochs, desc=f'Ready...', unit='cells', color='red', auto_refresh=True ) while not self.queue.is_empty and not os.path.exists(self.save_dir+'/process.kill'): # Remove old locks from other models self.queue.kill_old_locks() gene = next(self.queue) estimator = SpatialCellularProgramsEstimator( adata=self.adata, target_gene=gene, layer=self.layer, cluster_annot=self.annot, spatial_dim=self.spatial_dim, radius=self.radius, contact_distance=self.contact_distance, tf_ligand_cutoff=self.tf_ligand_cutoff, grn=self.grn, scale_factor=self.scale_factor ) estimator.test_mode = False if len(estimator.regulators) == 0: self.queue.add_orphan(gene) continue else: gene_bar.count = len(self.queue.all_genes) - len(self.queue.remaining_genes) gene_bar.desc = f'🕵️️ {self.queue.agents+1} workers' gene_bar.refresh() if os.path.exists(f'{self.queue.model_dir}/{gene}.lock'): continue self.queue.create_lock(gene) estimator.fit( num_epochs=self.max_epochs, threshold_lambda=self.threshold_lambda, learning_rate=self.learning_rate, batch_size=self.batch_size, pbar=train_bar, skip_clusters=self.skip_clusters ) ## filter out columns with all zeros betadata = estimator.betadata if betadata.shape[1] > 1: betadata.loc[:, (betadata != 0).any(axis=0)].to_parquet( f'{self.save_dir}/{gene}_betadata.parquet') else: self.queue.add_orphan(gene) self.trained_genes.append(gene) self.queue.delete_lock(gene) if self.queue.last_refresh_age() > self.queue.lock_timeout: self.queue.kill_old_locks() self.queue.last_refresh_on = datetime.datetime.now() gene_bar.count = len(self.queue.all_genes) - len(self.queue.remaining_genes) gene_bar.refresh() train_bar.count = 0 train_bar.start = time.time() gene_bar.desc = 'All done! 🎉️' gene_bar.refresh()
[docs] @staticmethod def imbue_adata_with_space(adata, annot='cell_type_int', spatial_dim=64, in_place=False, method='fast'): """ Generate and cache 2D spatial maps for each cell location """ clusters = np.array(adata.obs[annot]) xy = np.array(adata.obsm['spatial']) if method == 'fast': sp_maps = xyc2spatial_fast( xyc = np.column_stack([xy, clusters]), m=spatial_dim, n=spatial_dim, ).astype(np.float32) else: sp_maps = xyc2spatial( xy[:, 0], xy[:, 1], clusters, spatial_dim, spatial_dim, disable_tqdm=False ).astype(np.float32) if in_place: adata.obsm['spatial_maps'] = sp_maps return return sp_maps