Source code for spatialtis.spatial.hotspot

from typing import Optional

import numpy as np
import pandas as pd
from anndata import AnnData
from scipy.spatial import cKDTree
from scipy.stats import norm
from shapely.geometry import MultiPoint

from spatialtis.abc import AnalysisBase
from spatialtis.spatial.utils import QuadStats, get_eval
from spatialtis.typing import Array
from spatialtis.utils import col2adata_obs, create_remote, doc, run_ray
from spatialtis.utils.log import pbar_iter


def _getis_ord_hot_grids(quad_count, level, pval):
    """Flag grid squares that are significant Getis-Ord Gi* hot spots.

    Args:
        quad_count: 2D array of per-square cell counts, shape ``(nx, ny)``.
        level: Search level. Squares whose (i, j) indices lie within
            ``level * sqrt(2)`` of a square (inclusive, the square itself
            included) are its neighbours; ``level=1`` is the 3 x 3 block.
        pval: Significance threshold for the one-tailed (upper) p-value.

    Returns:
        A list of booleans, one per square in row-major (``ravel``) order.
    """
    quad_count = np.asarray(quad_count, dtype=float)
    nx, ny = quad_count.shape
    n_grids = nx * ny
    # A single square (or none) has nothing to compare against.
    if n_grids < 2:
        return [False] * n_grids

    counts = quad_count.ravel()
    mean_c = counts.mean()
    # Population standard deviation, i.e. sqrt(sum(x^2) / N - mean^2), computed
    # by np.std to avoid the cancellation (and sqrt of a tiny negative) of the
    # textbook formula when all counts are equal.
    s = counts.std()
    # Uniform counts: no square can stand out.
    if s == 0:
        return [False] * n_grids

    # Square (i, j) is at flat index i * ny + j, matching counts (row-major),
    # so the tree's neighbour indices index straight into counts.
    idx_points = [(i, j) for i in range(nx) for j in range(ny)]
    tree = cKDTree(idx_points)

    hot = []
    for neighbors in tree.query_ball_point(idx_points, r=level * np.sqrt(2)):
        # Binary weights: sum(w) == sum(w^2) == number of neighbours.
        sum_w = len(neighbors)
        u = np.sqrt((n_grids * sum_w - sum_w ** 2) / (n_grids - 1))
        # u == 0 means the neighbourhood is the entire region, so the region
        # is too small for any significant hotspot.
        if u == 0:
            hot.append(False)
            continue
        sum_wc = counts[neighbors].sum()
        # Gi* z-score: (sum(w*x) - mean * sum(w)) / (S * U).
        z = (sum_wc - mean_c * sum_w) / (s * u)
        # One-tailed upper test: only dense neighbourhoods count as "hot";
        # significantly sparse ones (negative z) are not hot spots.
        hot.append(bool(norm.sf(z) < pval))
    return hot


def _hotspot(cells, bbox, grid_size, level, pval):
    """Label each cell "hot" or "cold" using Getis-Ord Gi* on a square grid.

    Args:
        cells: Cell centroids, forwarded to ``QuadStats``.
        bbox: Bounding box forwarded to ``QuadStats`` (the caller in this
            module passes shapely ``MultiPoint(cells).bounds``).
        grid_size: Side length of a grid square, forwarded to ``QuadStats``.
        level: Neighbourhood search level (see ``_getis_ord_hot_grids``).
        pval: Significance threshold.

    Returns:
        A list of ``"hot"``/``"cold"`` labels, one per cell.
    """
    q = QuadStats(cells, bbox, grid_size=grid_size)
    nx, ny = q.nx, q.ny

    # The grid must have at least 3 * 3 squares so that squares have neighbours.
    if nx * ny < 9:
        return ["cold"] * len(cells)

    # NOTE(review): assumes grid_counts() yields counts in the same row-major
    # flat order that q.cells_grid_id indexes into; confirm against QuadStats.
    quad_count = np.asarray(list(q.grid_counts().values())).reshape(nx, ny)
    hot_grids = _getis_ord_hot_grids(quad_count, level, pval)
    return ["hot" if hot_grids[g] else "cold" for g in q.cells_grid_id]


@doc
class hotspot(AnalysisBase):
    """`Getis-Ord hotspot detection <../about/implementation.html#hotspot-detection>`_

    Used to identify cells that cluster together.

    Args:
        data: {adata}
        selected_types: {selected_types}
        search_level: Neighbourhood radius in grid diagonals; 1 means the
            surrounding 3 x 3 block of grid squares
        grid_size: Length of the side of square grid
        pval: {pval}
        kwargs: {analysis_kwargs}

    """

    def __init__(
        self,
        data: AnnData,
        selected_types: Optional[Array] = None,
        search_level: int = 1,
        grid_size: int = 50,
        pval: float = 0.01,
        **kwargs,
    ):
        super().__init__(data, task_name="hotspot", **kwargs)

        df = data.obs[self.exp_obs + [self.cell_type_key, self.centroid_key]]
        if selected_types is not None:
            df = df[df[self.cell_type_key].isin(selected_types)]
        groups = df.groupby(self.exp_obs)
        need_eval = self.is_col_str(self.centroid_key)

        # One Series of "hot"/"cold" labels per (group, cell type) pair.
        hotcells = []

        if self.mp:
            hotspot_mp = create_remote(_hotspot)
            jobs = []
            indexes = []
            for _, group in groups:
                for _, tg in group.groupby(self.cell_type_key):
                    if len(tg) > 1:
                        cells = get_eval(tg, self.centroid_key, need_eval)
                        bbox = MultiPoint(cells).bounds
                        jobs.append(
                            hotspot_mp.remote(
                                cells, bbox, grid_size, search_level, pval
                            )
                        )
                        indexes.append(tg.index)
                    elif len(tg) == 1:
                        # A lone cell of its type is never marked hot.
                        hotcells.append(pd.Series(["cold"], index=tg.index))
            results = run_ray(jobs, desc="Hotspot analysis")
            for hots, index in zip(results, indexes):
                hotcells.append(pd.Series(hots, index=index))
        else:
            for _, group in pbar_iter(groups, desc="Hotspot analysis", ):
                for _, tg in group.groupby(self.cell_type_key):
                    if len(tg) > 1:
                        cells = get_eval(tg, self.centroid_key, need_eval)
                        bbox = MultiPoint(cells).bounds
                        hots = _hotspot(cells, bbox, grid_size, search_level, pval)
                        hotcells.append(pd.Series(hots, index=tg.index))
                    elif len(tg) == 1:
                        # A lone cell of its type is never marked hot.
                        hotcells.append(pd.Series(["cold"], index=tg.index))

        # pd.concat raises on an empty list (e.g. selected_types matched
        # nothing); fall back to an empty Series so every cell becomes "other".
        result = pd.concat(hotcells) if hotcells else pd.Series(dtype=object)
        self.data.obs[self.export_key] = result
        # Cell map will leave blank if fill with None value, so cells not
        # covered by any group (e.g. excluded by selected_types) get "other".
        # Assign the result back: fillna(inplace=True) on the selected column
        # is a chained in-place call that pandas Copy-on-Write does not
        # propagate to obs.
        self.data.obs[self.export_key] = self.data.obs[self.export_key].fillna(
            "other"
        )
        # Call this to invoke the print
        col2adata_obs(self.data.obs[self.export_key], self.data, self.export_key)
        self.stop_timer()

    @property
    def result(self):
        return self.data.obs[self.exp_obs + [self.cell_type_key, self.export_key]]