# Source code for SpaceTravLR.models.spatial_map

from numba import jit, prange
from tqdm import tqdm
import numpy as np
from scipy.ndimage import gaussian_filter

from ..tools.utils import deprecated

@jit
def generate_grid_centers(m, n, xmin, xmax, ymin, ymax):
    """Return the centers of an ``m x n`` grid of equal cells.

    The grid spans ``[xmin, xmax]`` horizontally and ``[ymin, ymax]``
    vertically. Centers are produced row by row: the first row is the one
    nearest ``ymax`` and each row runs from ``xmin`` towards ``xmax``. The
    result is a flat list of ``m * n`` ``(x, y)`` tuples, so flat index
    ``row * n + col`` addresses cell ``(row, col)``.
    """
    step_x = (xmax - xmin) / n
    step_y = (ymax - ymin) / m

    grid = []
    for row in range(m):
        # Rows are counted downward from the top edge (ymax).
        cy = ymax - (row + 0.5) * step_y
        for col in range(n):
            cx = xmin + (col + 0.5) * step_x
            grid.append((cx, cy))
    return grid

@jit
def distance(point1, point2):
    """Return the Euclidean (L2) distance between two ``(x, y)`` points."""
    ax, ay = point1
    bx, by = point2
    dx = bx - ax
    dy = by - ay
    return np.sqrt(dx**2 + dy**2)

# @deprecated('`xyc2spatial` is deprecated. Use `xyc2spatial_fast` instead to save the trees 🌴️')
def xyc2spatial(x, y, c, m, n, split_channels=True, disable_tqdm=True):
    """Build per-sample, per-cluster distance maps on an ``m x n`` grid.

    The bounding box of the samples is split into an ``m x n`` grid (see
    ``generate_grid_centers``). For every sample, the Euclidean distance to
    each grid-cell center is computed. Those values are rounded to float32
    and stored in a float64 array.

    A per-cluster mask marks the grid cells that are the nearest cell to at
    least one sample of that cluster. Each sample's distance map is then
    multiplied by every cluster's mask, which yields one channel per cluster.

    Parameters
    ----------
    x, y : array-like, shape (n_samples,)
        Sample coordinates.
    c : array-like, shape (n_samples,)
        Cluster label of each sample. Labels are assigned to channels in the
        sorted order of the unique labels, so they need not be 0..K-1.
    m, n : int
        Number of grid rows and columns.
    split_channels : bool
        If True, return one channel per cluster. Otherwise, sum over the
        cluster axis.
    disable_tqdm : bool
        Suppress the progress bar.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_samples, n_clusters, m, n)`` if ``split_channels``,
        otherwise ``(n_samples, m, n)``. The maps are NOT normalized, unlike
        ``xyc2spatial_fast``, which min-max scales each channel.
    """
    assert len(x) == len(y) == len(c)
    xmin, xmax, ymin, ymax = np.min(x), np.max(x), np.min(y), np.max(y)
    xyc = np.column_stack([x, y, c]).astype(float)

    centers = np.asarray(
        generate_grid_centers(m, n, xmin, xmax, ymin, ymax), dtype=float
    )
    center_x, center_y = centers[:, 0], centers[:, 1]

    # Channel index of each sample = position of its label among the sorted
    # unique labels. Indexing the mask by the raw label broke (IndexError or
    # wrong channel) whenever labels were not exactly 0..K-1.
    labels = np.unique(xyc[:, 2])
    channel_idx = np.searchsorted(labels, xyc[:, 2])
    num_clusters = len(labels)

    spatial_maps = np.zeros((len(x), m, n))
    mask = np.zeros((num_clusters, m, n))

    with tqdm(total=len(xyc), disable=disable_tqdm, desc=f'🌍️ Generating {m}x{n} spatial maps') as pbar:
        for s, (x_, y_) in enumerate(xyc[:, :2]):
            # Same formula as `distance`, evaluated for all centers at once,
            # then rounded to float32 exactly as before.
            dist_map = (
                np.sqrt((center_x - x_) ** 2 + (center_y - y_) ** 2)
                .astype(np.float32)
                .reshape(m, n)
            )

            u, v = np.unravel_index(np.argmin(dist_map), (m, n))
            mask[channel_idx[s], u, v] = 1

            spatial_maps[s] = dist_map
            pbar.update()

    # Broadcast (N, 1, m, n) * (1, K, m, n) -> (N, K, m, n) instead of
    # materializing repeated copies of both operands.
    channel_wise_maps = spatial_maps[:, np.newaxis, :, :] * mask[np.newaxis, :, :, :]

    assert channel_wise_maps.shape == (len(x), num_clusters, m, n)

    if split_channels:
        return channel_wise_maps
    else:
        return channel_wise_maps.sum(axis=1)
    
    
@jit(nopython=True, parallel=True)
def xyc2spatial_fast(xyc, m, n):
    """
    Converts spatial coordinates (x, y) and cluster labels (c) to a spatial
    distance map with grid sizes m x n.
    Each channel encodes the distance map for a unique cluster.
    Return (n_samples, n_clusters, m, n)

    Parameters
    ----------
    xyc : 2-D float array
        Columns 0, 1 and 2 hold x, y and the cluster label of each sample.
        Labels are assigned to channels in the sorted order of the unique
        labels, so they need not be contiguous integers.
    m, n : int
        Number of grid rows and columns.

    Returns
    -------
    float32 array, shape (n_samples, n_clusters, m, n)
        For each (sample, channel), the sample's distance map multiplied by
        that channel's mask, then min-max scaled per channel. The denominator
        is floored at 1e-15, so a constant channel becomes all zeros.
    """
    n_samples = xyc.shape[0]
    x, y, c = xyc[:, 0], xyc[:, 1], xyc[:, 2]
    xmin, xmax, ymin, ymax = np.min(x), np.max(x), np.min(y), np.max(y)
    centers = generate_grid_centers(m, n, xmin, xmax, ymin, ymax)

    # Channel index of each sample = position of its label among the sorted
    # unique labels. Indexing by the raw label wrote out of bounds whenever
    # labels were not exactly 0..K-1, and nopython mode does not bounds-check.
    labels = np.unique(c)
    num_clusters = len(labels)
    channel_idx = np.searchsorted(labels, c)

    # Pass 1 (parallel): distance from each sample to every grid-cell center,
    # plus the flat index of the nearest cell.
    dist_maps = np.zeros((n_samples, m, n), dtype=np.float32)
    nearest = np.zeros(n_samples, dtype=np.int64)
    for s in prange(n_samples):
        point = (xyc[s, 0], xyc[s, 1])
        dist_map = np.array([distance(point, ctr) for ctr in centers]).reshape(m, n)
        nearest[s] = np.argmin(dist_map)
        dist_maps[s] = dist_map

    # Pass 2 (serial): mark each cluster's nearest cells. This runs outside
    # prange so that no two parallel iterations write the shared mask.
    # NOTE(review): the mask starts as all ones, so these writes are no-ops
    # and every channel equals the raw distance map. A zeros init would
    # restrict each channel to its cluster's cells. Confirm which is intended.
    mask = np.ones((num_clusters, m, n), dtype=np.float32)
    for s in range(n_samples):
        flat = nearest[s]
        mask[channel_idx[s], flat // n, flat % n] = 1

    # Pass 3 (parallel): apply each cluster's mask and min-max normalize
    # every (sample, channel) map.
    channel_wise_maps_norm = np.zeros((n_samples, num_clusters, m, n), dtype=np.float32)
    for s in prange(n_samples):
        for i in range(num_clusters):
            channel = dist_maps[s] * mask[i]
            lo = np.min(channel)
            hi = np.max(channel)
            spread = hi - lo
            if spread < 1e-15:
                # Avoid division by zero on constant channels.
                spread = 1e-15
            for j in range(m):
                for k in range(n):
                    channel_wise_maps_norm[s, i, j, k] = (channel[j, k] - lo) / spread

    return channel_wise_maps_norm