from numba import jit, njit, prange
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import torch
import random
import functools
import inspect
import warnings
import pickle
from sklearn.neighbors import NearestNeighbors
from scipy.sparse import csr_matrix
from scipy import sparse
from tqdm import tqdm
import io
import networkx as nx
from sklearn.neighbors import NearestNeighbors
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that loads pickled torch storages onto the CPU.

    Behaves like ``pickle.Unpickler`` except that a lookup of
    ``torch.storage._load_from_bytes`` is answered with a loader that calls
    ``torch.load(..., map_location='cpu')``.

    NOTE(review): like any unpickler this runs code selected by the pickle
    stream -- only feed it data from a trusted source.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, substituting the CPU-mapping storage loader."""
        wants_storage_loader = (
            module == 'torch.storage' and name == '_load_from_bytes'
        )
        if not wants_storage_loader:
            # Everything else is resolved exactly as the stock unpickler does.
            return super().find_class(module, name)

        def _load_on_cpu(raw_bytes):
            return torch.load(io.BytesIO(raw_bytes), map_location='cpu')

        return _load_on_cpu
def search(query, string_list):
    """Return the entries of ``string_list`` that contain ``query``, ignoring case.

    Args:
        query: Substring to look for.
        string_list: Iterable of strings to scan.

    Returns:
        list: The matching entries, unchanged and in their original order.
    """
    matches = []
    for candidate in string_list:
        if query.lower() in candidate.lower():
            matches.append(candidate)
    return matches
def scale_adata(adata, cell_size=15):
    """Rescale ``adata.obsm['spatial']`` so the median nearest-neighbour
    distance equals ``cell_size``.

    The unscaled coordinates are saved to ``adata.obsm['spatial_unscaled']``
    before ``adata.obsm['spatial']`` is replaced by the scaled version.

    Args:
        adata: Object exposing an ``obsm`` mapping with a ``'spatial'`` entry.
        cell_size: Target median distance between a point and its nearest
            neighbour after scaling.

    Returns:
        The same ``adata`` object, modified in place.

    Raises:
        ValueError: If the median nearest-neighbour distance is zero or not
            finite, in which case no meaningful scale factor exists.
    """
    coords = adata.obsm['spatial']
    # n_neighbors=2 because the first neighbour of every point is itself.
    nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(coords)
    distances, _ = nbrs.kneighbors(coords)

    # The median is used as the typical spacing; a low percentile was
    # considered but would be driven by squished cells.
    nn_distance = np.median(distances[:, 1])
    if not (nn_distance > 0 and np.isfinite(nn_distance)):
        raise ValueError(
            'Cannot scale spatial coordinates: median nearest-neighbour '
            'distance is {!r}'.format(nn_distance)
        )

    # A plain Python float keeps the dtype of floating-point coordinates.
    scale_factor = float(cell_size / nn_distance)

    adata.obsm['spatial_unscaled'] = coords.copy()
    # Out-of-place multiply: an in-place ``*=`` with a float factor raises a
    # casting error when the coordinates are stored as integers.
    adata.obsm['spatial'] = coords * scale_factor
    return adata
[docs]
def knn_distance_matrix(data, metric=None, k=40, mode='connectivity', n_jobs=4):
    """Calculate a nearest neighbour distance matrix

    Notice that k is meant as the actual number of neighbors NOT INCLUDING itself
    To achieve that we call kneighbors_graph with X = None

    Args:
        data: Samples to index, passed to ``NearestNeighbors.fit``.
        metric: Distance metric. ``None`` keeps the ``NearestNeighbors``
            default; ``"correlation"`` is searched with the brute-force
            algorithm; any other value is forwarded as ``metric=``
            (previously such values were silently ignored).
        k: Number of neighbours per sample, excluding the sample itself.
        mode: Forwarded to ``kneighbors_graph`` (e.g. ``'connectivity'`` or
            ``'distance'``).
        n_jobs: Number of parallel jobs for the neighbour search.

    Returns:
        The sparse graph returned by ``NearestNeighbors.kneighbors_graph``.
    """
    nn_kwargs = {"n_neighbors": k, "n_jobs": n_jobs}
    if metric == "correlation":
        nn_kwargs.update(metric="correlation", algorithm="brute")
    elif metric is not None:
        nn_kwargs["metric"] = metric

    nn = NearestNeighbors(**nn_kwargs)
    nn.fit(data)
    # X=None queries the training points themselves without counting each
    # point as its own neighbour.
    return nn.kneighbors_graph(X=None, mode=mode)
[docs]
def connectivity_to_weights(mknn, axis=1):
    """Normalise a sparse graph so its entries sum to 1 along ``axis``.

    Args:
        mknn: Sparse matrix; converted with ``tocsr()`` when it is not
            exactly a ``scipy.sparse.csr_matrix``.
        axis: Axis along which the sums are taken (``1`` = each row).

    Returns:
        Sparse matrix equal to ``mknn`` multiplied element-wise by the
        reciprocal of its sums along ``axis``.
    """
    graph = mknn if type(mknn) is sparse.csr_matrix else mknn.tocsr()
    totals = sparse.csr_matrix.sum(graph, axis=axis)
    inverse_totals = 1. / totals
    return graph.multiply(inverse_totals)
[docs]
def convolve_by_sparse_weights(data, w):
    """Multiply ``data`` by the transpose of the weight matrix ``w``.

    Args:
        data: Left operand of the product.
        w: Weight matrix; the columns of ``w.T`` must each sum to 1.

    Returns:
        The result of ``sparse.csr_matrix.dot(data, w.T)``.

    Raises:
        AssertionError: If the columns of ``w.T`` do not all sum to 1.
    """
    transposed_weights = w.T
    column_totals = transposed_weights.sum(0)
    assert np.allclose(column_totals, 1)
    smoothed = sparse.csr_matrix.dot(data, transposed_weights)
    return smoothed
def _adata_to_matrix(adata, layer_name, transpose=True):
    """Return a layer of ``adata`` as a fresh, C-contiguous dense array.

    Args:
        adata: Object exposing a ``layers`` mapping.
        layer_name: Key of the layer to extract.
        transpose: When true, the layer is transposed before being copied.

    Returns:
        A newly allocated C-ordered array; the stored layer is never
        modified or aliased.
    """
    layer = adata.layers[layer_name]
    # ``toarray()`` works for both sparse matrices and sparse arrays, whereas
    # ``todense().A`` fails when ``todense()`` already returns an ndarray.
    matrix = layer if isinstance(layer, np.ndarray) else layer.toarray()
    if transpose:
        matrix = matrix.transpose()
    # The single copy here both detaches the result from the layer and
    # enforces C order, so no earlier copies are needed.
    return matrix.copy(order="C")
class DeprecatedWarning(UserWarning):
    """Warning category used to flag calls to deprecated functions."""
def deprecated(instructions=''):
    """Build a decorator that marks a function as deprecated.

    Args:
        instructions: A human-friendly string of instructions appended to
            the warning text, such as: 'Please migrate to add_proxy() ASAP.'

    Returns:
        A decorator. The decorated function emits a ``DeprecatedWarning``
        attributed to the caller's file and line on every call, then runs
        the original function and returns its result.
    """
    def decorator(func):
        '''Wrap ``func`` so that each call first emits the deprecation warning.'''
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Attribute the warning to the code that called the wrapper.
            caller = inspect.currentframe().f_back
            text = '{} is a deprecated function. {}'.format(
                func.__name__, instructions)
            warnings.warn_explicit(
                text,
                category=DeprecatedWarning,
                filename=inspect.getfile(caller.f_code),
                lineno=caller.f_lineno,
            )
            return func(*args, **kwargs)
        return wrapper
    return decorator
def set_seed(seed):
    """Seed Python's ``random``, NumPy and torch with the same value.

    CUDA generators are seeded when CUDA is available. The MPS generator is
    seeded when the installed torch exposes both ``torch.backends.mps`` and
    ``torch.mps.manual_seed`` and MPS is available; torch builds lacking
    either are skipped instead of raising ``AttributeError``.

    Args:
        seed: Integer seed applied to every generator.

    NOTE: cuDNN determinism settings (``cudnn.deterministic``,
    ``cudnn.benchmark``, ``torch.use_deterministic_algorithms``) are not
    changed by this function.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    mps_backend = getattr(torch.backends, 'mps', None)
    mps_module = getattr(torch, 'mps', None)
    if (
        mps_backend is not None
        and mps_backend.is_available()
        and hasattr(mps_module, 'manual_seed')
    ):
        mps_module.manual_seed(seed)
def seed_worker(worker_id):
    """Seed NumPy and ``random`` from torch's initial seed.

    The seed is ``torch.initial_seed()`` reduced modulo 2**32.

    Args:
        worker_id: Accepted but not used by the body.
            NOTE(review): the signature looks like a DataLoader
            ``worker_init_fn`` -- confirm against the caller.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)
[docs]
def clean_up_adata(adata, fields_to_keep):
    """Strip annotation fields from ``adata`` in place.

    * ``adata.obs``: every column not listed in ``fields_to_keep`` is deleted.
    * ``adata.var``: every column is deleted (the keep-list applied to
      ``var`` is empty).
    * ``adata.uns``: every key not listed in ``fields_to_keep`` is deleted.

    Args:
        adata: Object exposing ``obs`` / ``var`` data frames and an ``uns``
            mapping.
        fields_to_keep: Names to preserve in ``obs`` and ``uns``.

    Returns:
        None.
    """
    # Iterate over snapshots so that deleting does not disturb the iteration.
    for column in adata.obs.columns.tolist():
        if column not in fields_to_keep:
            del adata.obs[column]

    var_fields_to_keep = []
    for column in adata.var.columns.tolist():
        if column not in var_fields_to_keep:
            del adata.var[column]

    stale_uns_keys = set(adata.uns.keys()).difference(set(fields_to_keep))
    for uns_key in stale_uns_keys:
        del adata.uns[uns_key]
[docs]
@jit
def gaussian_kernel_2d(origin, xy_array, radius, eps=0.001):
    """
    Compute Gaussian kernel weights for a set of points around an origin.

    Args:
        origin (np.ndarray): The center point for the Gaussian kernel.
        xy_array (np.ndarray): Array of points to compute weights for
            (one point per row).
        radius (float): Cut-off distance; points farther than this from
            ``origin`` receive weight 0.
        eps (float): Value the kernel takes at distance ``radius``; sigma is
            derived from it as ``radius / sqrt(-2 * log(eps))``.

    Returns:
        np.ndarray: Array of Gaussian weights, one per row of ``xy_array``.
    """
    offsets = xy_array - origin
    distances = np.sqrt(np.sum(offsets**2, axis=1))

    # Sigma is chosen so that exp(-radius**2 / (2 * sigma**2)) == eps.
    sigma = radius / np.sqrt(-2 * np.log(eps))
    two_sigma_sq = 2 * sigma**2
    weights = np.exp(-(distances**2) / two_sigma_sq)

    # Truncate the kernel beyond the radius.
    outside = distances > radius
    weights[outside] = 0
    return weights
def min_max_df(df):
    """Return a copy of ``df`` transformed by ``MinMaxScaler().fit_transform``.

    Args:
        df: Data frame to scale.

    Returns:
        pd.DataFrame: Scaled values carrying the same index and columns as
        ``df``.
    """
    scaled_values = MinMaxScaler().fit_transform(df)
    return pd.DataFrame(scaled_values, index=df.index, columns=df.columns)
def prune_neighbors(dsi, dist, maxl):
    """Build a pruned, symmetrised neighbour graph from kNN indices/distances.

    Steps:
      1. Scatter ``dist`` into a dense ``n x n`` matrix using ``dsi`` as the
         column indices and clear the diagonal.
      2. For every row holding more than ``maxl`` non-zero entries, keep the
         ``maxl`` smallest weights and zero the rest.
      3. Take the element-wise minimum of the matrix and its transpose, so
         an edge survives only if it is present in both directions.

    Args:
        dsi: 2-D integer array; row ``i`` holds the neighbour indices of
            sample ``i``.
        dist: Array of the same shape as ``dsi`` holding the matching weights.
        maxl: Maximum number of outgoing edges kept per sample.

    Returns:
        scipy.sparse.csr_matrix: The pruned graph.
    """
    n_samples, n_neighbors = dsi.shape[0], dsi.shape[1]
    edge_weights = dist.flatten()
    source_idx = np.repeat(np.arange(n_samples), n_neighbors)
    target_idx = dsi.flatten()

    dense_graph = np.zeros((n_samples, n_samples), dtype=edge_weights.dtype)
    dense_graph[source_idx, target_idx] = edge_weights
    np.fill_diagonal(dense_graph, 0)

    for node, node_row in enumerate(dense_graph):
        linked = np.nonzero(node_row)[0]
        if len(linked) <= maxl:
            continue
        # Order this node's edges by ascending weight; drop all past maxl.
        by_weight = linked[np.argsort(node_row[linked])]
        dense_graph[node, by_weight[maxl:]] = 0

    dense_graph = np.minimum(dense_graph, dense_graph.T)
    return csr_matrix(dense_graph)
def lR_to_l(adata, mapper={'leiden_R': 'leiden'}):
    '''
    Move columns of `adata.obs` to new names. By default,
    maps `leiden_R` to `leiden`, typically run after using
    `sc.tl.leiden(restrict_to=)`.

    For every `old -> new` pair in `mapper`, the values of column `old` are
    stored under `new` (overwriting an existing `new` column) and the `old`
    column is removed.

    `adata`: annotated data matrix
    `mapper`: dict of current column name -> new column name; it is only
              read, never mutated, so the shared default is safe

    returns: None, modifies in-place
    '''
    for current_col_name, new_col_name in mapper.items():
        current_col = adata.obs[current_col_name].copy()
        # Delete in place: `DataFrame.drop` returns a new frame, so calling
        # it without using the result left the old column behind.
        del adata.obs[current_col_name]
        adata.obs[new_col_name] = current_col
    return
def reset_colors(adata, key='leiden', use_plt=True):
    '''
    Forget the stored colors for `key`, if any.

    `adata`: annotated data matrix
    `key`: name of the annotation whose colors are removed
    `use_plt`: when True, removes `adata.uns['plt']['color'][key]`;
               otherwise falls back to scanpy color storage and removes
               `adata.uns['<key>_colors']`

    A missing entry is not an error; the call is then a no-op.

    returns: None, modifies in-place
    '''
    try:
        if use_plt:
            del adata.uns['plt']['color'][key]
        else:
            # Fall back to scanpy color storage
            del adata.uns['%s_colors' % key]
    except (KeyError, TypeError):
        # Nothing stored (or the container is not a mapping): nothing to do.
        # Only lookup failures are ignored, so unrelated errors still surface.
        pass
    return
def relabel_clusts(adata, key='leiden'):
    '''
    Rename the categories of `adata.obs[key]` to the strings "0", "1", ...
    following the current category order, then reset the stored colors for
    `key`.

    `adata`: annotated data matrix
    `key`: name of column in `adata.obs` with the clusters

    returns: None, modifies in-place
    '''
    # The `.cat` accessor only exists on categorical columns.
    if not hasattr(adata.obs[key], 'cat'):
        adata.obs[key] = adata.obs[key].astype('category')

    old_labels = adata.obs[key].cat.categories
    renaming = {old: str(position) for position, old in enumerate(old_labels)}
    adata.obs[key] = adata.obs[key].map(renaming).astype('category')

    reset_colors(adata, key=key)
    return
def clean_leiden(adata):
    '''
    Convenience wrapper that tidies the `leiden` column in `adata.obs`:
    runs `lR_to_l` and then `relabel_clusts`, both with their defaults.

    `adata`: annotated data matrix
    '''
    for step in (lR_to_l, relabel_clusts):
        step(adata)
[docs]
def _looks_like_mouse_symbol(gene):
    """True for names longer than one character whose only capital is the first."""
    return (
        len(gene) > 1
        and gene[0].isupper()
        and not any(c.isupper() for c in gene[1:])
    )


def _looks_like_human_symbol(gene):
    """True for names with at least one capital and no lowercase letters."""
    return (
        any(c.isupper() for c in gene)
        and all(c.isupper() or not c.isalpha() for c in gene)
    )


def is_mouse_data(adata):
    """
    Determine if an AnnData object contains mouse or human data based on gene names.

    This function examines gene names to determine if the data is from mouse
    (capitalized first letter only) or human (all caps gene symbols). It
    samples up to 100 names from ``adata.var_names`` without replacement
    using NumPy's global random state. Empty gene names match neither
    pattern and are ignored.

    Parameters
    ----------
    adata : AnnData
        The annotated data matrix to check

    Returns
    -------
    bool
        True if more sampled names match the mouse pattern than the human
        pattern, False otherwise (including ties)
    """
    # Get a sample of gene names to check (up to 100)
    n_sampled = min(100, len(adata.var_names))
    gene_sample = np.random.choice(adata.var_names, size=n_sampled, replace=False)

    # The length check runs before indexing, so an empty name cannot raise.
    mouse_pattern_count = sum(1 for gene in gene_sample if _looks_like_mouse_symbol(gene))
    human_pattern_count = sum(1 for gene in gene_sample if _looks_like_human_symbol(gene))

    # Return True if more genes match mouse pattern than human pattern
    return mouse_pattern_count > human_pattern_count