# Source code for spatialtis.plotting.regular_spatial

from collections import Counter

import numpy as np
import pandas as pd
import seaborn as sns
from anndata import AnnData
from milkviz import dot_heatmap, anno_clustermap
from natsort import natsorted
from typing import List

from spatialtis import get_result
from spatialtis.utils import doc


@doc
def spatial_heterogeneity(
    data: AnnData,
    groupby: str = None,
    key: str = "heterogeneity",
    **plot_options,
):
    """Visualization of spatial heterogeneity analysis

    A bar plot is drawn when the result has a single leading column, or when
    `groupby` names the first leading column; otherwise a box plot grouped
    by `groupby` is drawn.

    Parameters
    ----------
    data : {adata_plotting}
    groupby : {groupby}
    key : {plot_key}
    **plot_options :
        Pass to :func:`seaborn.barplot` or :func:`seaborn.boxplot`,
        depending on which plot is drawn.

    Returns
    -------
    The object returned by the seaborn plotting call.

    """
    pdata = get_result(data, key).reset_index()
    # Every column except the last one is used as a candidate grouping level
    # (presumably the experiment observation levels; the last column is
    # plotted as "heterogeneity").
    exp_obs = pdata.columns[:-1].tolist()

    if len(exp_obs) == 1 or groupby == exp_obs[0]:
        # Build the options with a dict literal so that a user-supplied
        # `color` overrides the default instead of raising
        # "TypeError: dict() got multiple values for keyword argument".
        options = {"color": "#5DAC81", **plot_options}
        return sns.barplot(data=pdata, y="heterogeneity", x=exp_obs[-1], **options)

    # Fall back to the first level when no grouping was requested
    if groupby is None:
        groupby = exp_obs[0]
    return sns.boxplot(data=pdata, y="heterogeneity", x=groupby, **plot_options)
@doc
def cell_dispersion(
    data: AnnData,
    use: str = "dot",
    groupby: List[str] = None,
    type_order: List[str] = None,
    key: str = "cell_dispersion",
    **plot_options,
):
    """Visualization of cell dispersion analysis

    With ``use="dot"`` a dot heatmap shows, for every cell type, how many
    entries fall into each pattern; any other value of `use` draws an
    annotated clustermap of the pattern table.

    Parameters
    ----------
    data : {adata_plotting}
    use : {'dot', 'heatmap'}
    groupby : {groupby}
    type_order : {type_order}
    key : {plot_key}
    **plot_options :
        Pass to :func:`milkviz.dot` and :func:`milkviz.anno_clustermap`.

    """
    # Pattern codes 0..3 are displayed with these labels, in this order
    pattern_codes = [0, 1, 2, 3]
    pattern_labels = ["No Cell", "Random", "Regular", "Cluster"]

    result = get_result(data, key)
    # One column per cell type; rows are keyed by every index level but the first
    table = pd.pivot_table(
        result,
        values="pattern",
        index=result.index.names[1:],
        columns="cell_type",
    )
    ordered_types = natsorted(table.columns) if type_order is None else type_order
    table = table[ordered_types]

    if use == "dot":
        # Tally the patterns per cell type, starting every code at zero so
        # that all four columns exist even when a pattern never occurs
        tallies = {
            cell_type: {**dict.fromkeys(pattern_codes, 0), **Counter(values)}
            for cell_type, values in table.items()
        }
        counts = pd.DataFrame(tallies).T[pattern_codes]
        # Same four colors repeated for each cell-type row
        palette = np.repeat(
            [["#FFC408", "#c54a52", "#4a89b9", "#5a539d"]], len(counts), axis=0
        )
        return dot_heatmap(
            dot_size=counts.to_numpy(dtype=int),
            dot_hue=palette,
            xticklabels=pattern_labels,
            yticklabels=counts.index,
            **plot_options,
        )

    # Heatmap: missing entries are shown as pattern 0
    table = table.rename_axis(columns={"cell_type": "Cell Type"}).fillna(0.0)
    # Defaults first so that user-supplied options take precedence
    heatmap_kw = {
        "categorical_cbar": pattern_labels,
        "heat_cmap": "tab20",
        "col_legend_split": False,
        "cbar_title": "Pattern",
        "vmin": 0,
        "vmax": 3,
        "col_cluster": False,
        **plot_options,
    }
    return anno_clustermap(
        table, col_colors="Cell Type", row_colors=groupby, **heatmap_kw
    )