Source code for spatialtis.plotting.interaction

from collections import Counter

import numpy as np
import pandas as pd
from anndata import AnnData
from itertools import product
from matplotlib.colors import CenteredNorm
from milkviz import anno_clustermap, dot_heatmap
from natsort import natsorted
from scipy.stats import pearsonr
from typing import List

from spatialtis import get_result
from spatialtis.utils import doc, log_print, df2adata_uns
from .utils import pairs_to_adj


def count_size_side(pdata, type_order, groupby_keys, value_key):
    dot_size, dot_hue = {}, {}
    for comb, df in pdata.groupby(groupby_keys):
        count = {1: 0, 0: 0, -1: 0, **Counter(df[value_key])}
        v = [1, -1]
        arr = [count[i] for i in v]
        sig_count = np.sum(arr)
        if sig_count == 0:
            sig_num = 0
        else:
            norm = arr / sig_count
            sig_num = v[np.argmax(norm)] * np.amax(norm)
        dot_size[comb] = sig_count
        dot_hue[comb] = sig_num
    dot_size = pairs_to_adj(
        pd.DataFrame(dot_size, index=[0]).T.reset_index(), type_order
    )
    dot_hue = pairs_to_adj(pd.DataFrame(dot_hue, index=[0]).T.reset_index(), type_order)
    return dot_size, dot_hue


def count_size_side_for_enrichment(pdata, type_order):
    dot_size, dot_hue = {}, {}
    for comb, arr in pdata.items():
        count = {1: 0, 0: 0, -1: 0, **Counter(arr)}
        v = [1, -1]
        arr = [count[i] for i in v]
        sig_count = np.sum(arr)
        if sig_count == 0:
            sig_num = 0
        else:
            norm = arr / sig_count
            sig_num = v[np.argmax(norm)] * np.amax(norm)
        dot_size[comb] = sig_count
        dot_hue[comb] = sig_num
    dot_size = pairs_to_adj(
        pd.DataFrame(dot_size, index=[0]).T.reset_index(), type_order
    )
    dot_hue = pairs_to_adj(pd.DataFrame(dot_hue, index=[0]).T.reset_index(), type_order)
    return dot_size, dot_hue


[docs]@doc def cell_interaction( data: AnnData, use: str = "dot", groupby: List = None, key: str = "cell_interaction", type_order: List[str] = None, order: bool = True, plot_value: str = "relationship", **plot_options, ): """Visualization of the cell interaction analysis Parameters ---------- data : {adata_plotting} use : {'dot', 'heatmap'}, default: 'dot' groupby : {groupby} key : {plot_key} type_order: {type_order} order : bool plot_value : {'relationship', 'statistic'} **plot_options : Pass to :func:`milkviz.dot_heatmap` or :func:`milkviz.anno_clustermap`. """ if use == "heatmap": pdata = get_result(data, key) uni_types = pd.unique(pdata[['type1', 'type2']].to_numpy().flatten()) if type_order is None: type_order = natsorted(uni_types) pdata = pdata.pivot_table(columns=['type1', 'type2'], values=plot_value, # the index of [1::] is to remove the index columns index=pdata.index.names[1::], fill_value=0) pdata = pdata[[tuple(c) for c in product(type_order, repeat=2)]] if plot_value == "relationship": options = dict( categorical_cbar=["Avoidance", "Association"], col_legend_split=False, col_legend_title="Cell type", cbar_title="Interaction", col_cluster=False, method="ward", vmin=-1, vmax=1, ) options = {**options, **plot_options} return anno_clustermap( pdata, col_colors=["type1", "type2"], row_colors=groupby, **options ) else: options = dict( col_legend_split=False, col_legend_title="Cell type", cbar_title="Interaction", col_cluster=False, method="ward", ) options = {**options, **plot_options} return anno_clustermap( pdata, col_colors=["type1", "type2"], row_colors=groupby, **options ) else: def _get_cell_components(): try: matrix = get_result(data, "cell_components") combs = {} for i1, i2 in product(matrix.items(), repeat=2): combs[(i1[0], i2[0])] = pearsonr(i1[1], i2[1])[0] matrix = pd.DataFrame(combs, index=[0]).T.reset_index() matrix = pairs_to_adj(matrix, type_order) matrix = matrix.loc[ dot_size.index, dot_size.columns ] # ensure the number match to data except Exception: log_print( "Run spatialtis.cell_components to add " "pearson's R of cell proportion to the visualization" ) matrix = None return matrix pdata = get_result(data, key) uni_types = pd.unique(pdata[['type1', 'type2']].to_numpy().flatten()) if type_order is None: type_order = natsorted(uni_types) dot_size, dot_hue = count_size_side(pdata, type_order, groupby_keys=['type1', 'type2'], value_key="relationship") matrix = _get_cell_components() if matrix is None: matrix_hue = None incomplete = True else: matrix_hue = matrix.to_numpy() dot_size_data = dot_size.to_numpy(dtype=int) dot_hue_data = dot_hue.to_numpy() xticklabels = dot_size.columns yticklabels = dot_size.index if not order: dot_size_data = np.ma.masked_values(np.tril(dot_size_data), 0) dot_hue_data = np.ma.masked_values(np.tril(dot_hue_data), 0) matrix_hue = np.ma.masked_values(np.tril(matrix_hue), 0)\ if matrix is not None else None xticklabels = xticklabels[::-1] options = dict( dot_size_legend_kw={"title": "Sign' ROI"}, dot_hue_cbar_kw={"title": "Interaction"}, matrix_cbar_kw={"title": "Pearson-R"}, sizes=(0, 250), dot_cmap="bwr", matrix_cmap="BrBG" ) options = {**options, **plot_options} return dot_heatmap( dot_size=dot_size_data, dot_hue=dot_hue_data, matrix_hue=matrix_hue, xticklabels=xticklabels, yticklabels=yticklabels, dot_norm=CenteredNorm(vcenter=0, halfrange=1.), matrix_norm=CenteredNorm(vcenter=0, halfrange=1.), **options, )
[docs]@doc def spatial_enrichment( data: AnnData, key: str = "spatial_enrichment", type_order: List[str] = None, **plot_options, ): """Visualization of the spatial enrichment analysis Parameters ---------- data : {adata_plotting} key : {plot_key} type_order : {type_order} **plot_options : Pass to :func:`milkviz.dot_heatmap`. """ store_key = "spatial_enrichment_dot" if store_key in data.uns_keys(): pdata = data.uns[store_key] dot_size_data = pdata['dot_size_data'] dot_hue_data = pdata['dot_hue_data'] xticklabels = pdata['xticklabels'] yticklabels = pdata['yticklabels'] else: rdata = get_result(data, key) dot_size, dot_hue = count_size_side_for_enrichment(rdata, type_order) dot_size_data = dot_size.to_numpy(dtype=int) dot_hue_data = dot_hue.to_numpy() xticklabels = dot_size.columns yticklabels = dot_size.index pdata = dict( dot_size_data=dot_size_data, dot_hue_data=dot_hue_data, xticklabels=xticklabels, yticklabels=yticklabels ) data.uns[store_key] = pdata # if not order: # dot_size_data = mask_triu(dot_size_data) # dot_hue_data = mask_triu(dot_hue_data) # xticklabels = xticklabels[::-1] options = dict( dot_size_legend_kw={"title": "Sign' ROI"}, dot_hue_cbar_kw={"title": "Interaction"}, matrix_cbar_kw={"title": "Pearson-R"}, sizes=(0, 250), dot_cmap="PiYG_r", **plot_options, ) return dot_heatmap( dot_size=dot_size_data, dot_hue=dot_hue_data, xticklabels=xticklabels, yticklabels=yticklabels, **options, )