from __future__ import annotations
import numpy as np
from anndata import AnnData
from itertools import cycle
from legendkit import CatLegend
from matplotlib import pyplot as plt
from matplotlib.colors import to_hex
from milkviz import point_map, polygon_map
from typing import List, Optional, Tuple, Dict
from spatialtis.abc import AnalysisBase
from spatialtis.utils import doc
from .utils import COLOR_POOL
def _fig_layout(count, ncol):
if count <= ncol:
nrow = 1
ncol = count
else:
nrow = count // ncol + (count % ncol > 0)
return nrow, ncol
def _sep_plot_options(plot_options, hijack="legend_kw"):
# hijack the legend configuration
legend_kw = {}
if hijack in plot_options.keys():
legend_kw = plot_options[hijack]
del plot_options[hijack]
# don't allow use to set ax to prevent overlay
if 'ax' in plot_options.keys():
del plot_options['ax']
return legend_kw, plot_options
def _color_mapper(ab,
data,
masked_type_name,
masked_type_color,
types_colors=None,
selected_types=None,
):
# assign cell colors for each cell type
color_mapper = {}
legend_color_mapper = {}
store_key = "cell_colors"
unique_types = ab.cell_types
if ab.has_cell_type:
# if user specific a colormap, we use user one
# otherwise, create a new one or read from anndata
if types_colors is not None:
color_mapper = types_colors
else:
if store_key in data.uns_keys():
# alloc new to prevent modified the stored version
color_mapper = {**data.uns[store_key]}
else:
# Create a global color mapper to ensure that the cell color is the same across ROI
color_mapper = dict(zip(unique_types, cycle(COLOR_POOL)))
# alloc new to prevent modified the stored version
data.uns[store_key] = {**color_mapper}
if selected_types is not None:
unique_types = np.unique(selected_types)
for t in selected_types:
legend_color_mapper[t] = color_mapper[t]
masked_type_color = to_hex(masked_type_color, keep_alpha=True)
color_mapper[masked_type_name] = masked_type_color
legend_color_mapper[masked_type_name] = masked_type_color
else:
legend_color_mapper = {**color_mapper}
return color_mapper, legend_color_mapper, unique_types
def _masked_cell_type(cell_types, unique_types, selected_types,
masked_type_name):
if (cell_types is not None) & (selected_types is not None):
cell_mask = np.isin(cell_types, unique_types)
cell_types[~cell_mask] = masked_type_name
return cell_types
def _neighbor_links(labels, neighbors):
    """Convert per-cell neighbor labels of one ROI into point-index links.

    Parameters
    ----------
    labels : sequence
        Label of each cell in the ROI, in the same order as its points.
    neighbors : sequence of sequence
        For each cell, the labels of its neighbors.

    Returns
    -------
    list of tuple of int
        ``(neighbor_position, cell_position)`` pairs, where a position is the
        cell's index in ``labels``. Each undirected edge is emitted once, from
        the cell with the smaller label. Neighbor labels that do not belong to
        this ROI are skipped.
    """
    # Look positions up by label rather than computing ``label - min(labels)``,
    # which only works when the labels form a contiguous range.
    position = {label: i for i, label in enumerate(labels)}
    links = []
    for label, neigh in zip(labels, neighbors):
        for n in neigh:
            n = int(n)
            if n > label and n in position:
                links.append((position[n], position[label]))
    return links


@doc
def cell_map(
        data: AnnData,
        rois: List[str],
        ncol: int = 5,
        use_shape: bool = False,
        show_neighbors: bool = False,
        selected_types: Optional[List] = None,
        masked_type_name: str = "Other",
        masked_type_color: str = "#d3d3d3",
        figsize: Tuple[float, float] = None,
        wspace: float = 0,
        hspace: float = 0.1,
        types_colors: Dict = None,
        cell_type_key: str = None,
        shape_key: str = None,
        centroid_key: str = None,
        roi_key: str = None,
        **plot_options,
):
    """Visualize cells and neighbors relationship in ROI

    One panel is drawn per ROI in a grid of at most ``ncol`` columns. When
    cell types are available, a single cell-type legend is attached to the
    last panel.

    Parameters
    ----------
    data : {adata_plotting}
    rois : list of str
        A list of ROI name that you want to plot.
    ncol : int
        The maximum number of columns in the figure layout.
    use_shape : bool
        Plot cells as polygons when shape data is available; otherwise cells
        are plotted as points. Ignored when ``show_neighbors`` is set.
    show_neighbors : bool
        Plot the neighbors' relationship as links between cells.
        Not supported for 3D data.
    selected_types : {selected_types}
    masked_type_name : str, default: 'Other'
        The name of the cell types not in selected_types.
    masked_type_color : color-like, default: '#d3d3d3'
        The color of the cell types not in selected_types.
    figsize : tuple of float
        The size of figure. Defaults to 4 inches per panel in each direction.
    wspace : float, default: 0
        The horizontal space between plots, as a fraction of the average
        axes width.
    hspace : float, default: 0.1
        The vertical space between plots, as a fraction of the average
        axes height.
    types_colors : dict
        Change the color for each cell type.
        Key is the type and value is the color.
    cell_type_key : {cell_type_key}
    shape_key : {shape_key}
    centroid_key : {centroid_key}
    roi_key : {roi_key}
    **plot_options :
        Pass to :func:`milkviz.point_map` or :func:`milkviz.polygon_map`.
        A ``legend_kw`` dict is merged into the legend options; ``ax`` is
        ignored.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, closed in pyplot before being returned.

    Raises
    ------
    NotImplementedError
        If ``show_neighbors`` is requested for 3D data.
    """
    ab = AnalysisBase(data,
                      cell_type_key=cell_type_key,
                      shape_key=shape_key,
                      centroid_key=centroid_key,
                      roi_key=roi_key,
                      verbose=False)
    if show_neighbors:
        if ab.dimension == 3:
            raise NotImplementedError("Does not support 3D neighbor map")
        ab.check_neighbors()
    ab.is_rois_name_unique()
    if isinstance(rois, str):
        rois = [rois]

    color_mapper, legend_color_mapper, unique_types = _color_mapper(
        ab, data, masked_type_name, masked_type_color,
        types_colors, selected_types,
    )

    nrow, ncol = _fig_layout(len(rois), ncol)
    if figsize is None:
        figsize = (ncol * 4, nrow * 4)
    fig = plt.figure(figsize=figsize)
    legend_kw, plot_options = _sep_plot_options(plot_options)

    axes = []
    if show_neighbors:
        roi_iter = ab.iter_roi(
            fields=['centroid', 'cell_type', 'neighbors'],
            filter_rois=rois,
            disable_pbar=True,
        )
        for ax_index, roi in enumerate(roi_iter, start=1):
            roi_name, points, cell_types, labels, neighbors = roi
            cell_types = _masked_cell_type(cell_types, unique_types,
                                           selected_types, masked_type_name)
            ax = fig.add_subplot(nrow, ncol, ax_index)
            point_map(np.array(points), types=cell_types,
                      links=_neighbor_links(labels, neighbors),
                      colors=color_mapper,
                      ax=ax, legend=False,
                      **plot_options)
            ax.set_title(", ".join([str(i) for i in roi_name]))
            axes.append(ax)
    else:
        roi_iter = ab.iter_roi(
            fields=['centroid', 'cell_type', 'shape'],
            filter_rois=rois,
            disable_pbar=True,
        )
        for ax_index, roi in enumerate(roi_iter, start=1):
            roi_name, points, cell_types, polygons = roi
            cell_types = _masked_cell_type(cell_types, unique_types,
                                           selected_types, masked_type_name)
            if use_shape and polygons is not None:
                ax = fig.add_subplot(nrow, ncol, ax_index)
                polygon_map(polygons, types=cell_types,
                            colors=color_mapper, ax=ax,
                            legend=False, **plot_options)
            else:
                subplot_kw = {} if ab.dimension == 2 else {"projection": "3d"}
                ax = fig.add_subplot(nrow, ncol, ax_index, **subplot_kw)
                # draw on the axes we created rather than relying on the
                # plotting function's return value
                point_map(np.array(points), types=cell_types,
                          colors=color_mapper, ax=ax,
                          legend=False, **plot_options)
            ax.set_title(", ".join([str(i) for i in roi_name]))
            axes.append(ax)

    # One shared legend on the last panel; skip it when nothing was drawn
    # (no matching ROI) or there is nothing to list.
    if ab.has_cell_type and axes and legend_color_mapper:
        legend_labels, legend_colors = zip(*legend_color_mapper.items())
        legend_options = {
            "title": "Cell Type",
            "bbox_to_anchor": (1.05, 0.5),
            "loc": "center left",
            **legend_kw,
        }
        CatLegend(colors=legend_colors, labels=legend_labels,
                  handle="circle", ax=axes[-1], **legend_options)
    fig.subplots_adjust(wspace=wspace, hspace=hspace)
    plt.close(fig)
    return fig
@doc
def expression_map(
        data: AnnData,
        rois: List[str],
        markers: List[str],
        use_shape: bool = False,
        x_axis: str = "marker",
        figsize: Tuple = None,
        wspace: float = 0,
        hspace: float = 0.2,
        selected_types: List = None,
        cell_type_key: str = None,
        marker_key: str = None,
        shape_key: str = None,
        centroid_key: str = None,
        roi_key: str = None,
        **plot_options,
):
    """Visualize marker expression in ROI

    One panel is drawn per (ROI, marker) pair. By default each ROI fills a
    row and each marker a column; ``x_axis="roi"`` transposes the grid.

    Parameters
    ----------
    data : {adata_plotting}
    rois : list of str
        A list of ROI name that you want to plot.
    markers : list of str
        A list of markers name that you want to plot.
    use_shape : bool
        Plot cells as polygons when shape data is available; otherwise cells
        are plotted as points.
    x_axis : {'marker', 'roi'}, default: 'marker'
        What is on the x-axis, the marker or roi.
    figsize : tuple of float
        The size of figure. Defaults to 4 inches per panel in each direction.
    wspace : float, default: 0
        The horizontal space between plots, as a fraction of the average
        axes width.
    hspace : float, default: 0.2
        The vertical space between plots, as a fraction of the average
        axes height.
    selected_types : {selected_types}
        Cells of other types are left out of the plot.
    cell_type_key : {cell_type_key}
    marker_key : {marker_key}
    shape_key : {shape_key}
    centroid_key : {centroid_key}
    roi_key : {roi_key}
    **plot_options :
        Pass to :func:`milkviz.point_map` or :func:`milkviz.polygon_map`.
        A ``cbar_kw`` dict is merged into the colorbar options; ``ax`` is
        ignored.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, closed in pyplot before being returned.

    Raises
    ------
    ValueError
        If ``x_axis`` is neither ``'marker'`` nor ``'roi'``.
    """
    if x_axis not in ("marker", "roi"):
        raise ValueError(f"x_axis must be 'marker' or 'roi', got {x_axis!r}")
    ab = AnalysisBase(data,
                      cell_type_key=cell_type_key,
                      shape_key=shape_key,
                      centroid_key=centroid_key,
                      roi_key=roi_key,
                      marker_key=marker_key,
                      verbose=False)
    ab.is_rois_name_unique()
    if isinstance(rois, str):
        rois = [rois]
    if isinstance(markers, str):
        markers = [markers]
    unique_types = ab.cell_types
    if ab.has_cell_type and selected_types is not None:
        unique_types = np.unique(selected_types)

    # Subplot index (1-based, row-major) for each panel, in iteration order
    # (ROI by ROI, marker by marker). With x_axis="roi" the grid is transposed
    # so each ROI fills a column instead of a row.
    nrow, ncol = len(rois), len(markers)
    ax_indexes = np.arange(1, nrow * ncol + 1)
    if x_axis == "roi":
        nrow, ncol = ncol, nrow
        ax_indexes = ax_indexes.reshape(nrow, ncol).T.flatten()
    if figsize is None:
        figsize = (ncol * 4, nrow * 4)
    fig = plt.figure(figsize=figsize)
    cbar_kw, plot_options = _sep_plot_options(plot_options, "cbar_kw")
    cbar_options = {
        "orientation": "horizontal",
        "loc": "upper center",
        "bbox_to_anchor": (0.5, -0.01),
        **cbar_kw,
    }

    ax_indexes_iter = iter(ax_indexes)
    ax_at = {}  # subplot index -> Axes that was actually drawn
    roi_names = []
    for roi_name, points, markers_name, exp, cell_types, polygons in ab.iter_roi(
            fields=['centroid', 'exp', 'cell_type', 'shape'],
            filter_rois=rois,
            disable_pbar=True,
            selected_markers=markers,
    ):
        roi_names.append(roi_name)
        cell_mask = None
        if cell_types is not None and selected_types is not None:
            cell_mask = np.isin(cell_types, unique_types)
            # exp is indexed as (marker, cell): one row per subplot below
            exp = exp[:, cell_mask]

        if use_shape and polygons is not None:
            if cell_mask is not None:
                # Polygons have different vertex counts (ragged), so
                # np.asarray(polygons)[mask] fails; filter them in Python.
                polygons = [p for p, keep in zip(polygons, cell_mask) if keep]
            geometry, draw, subplot_kw = polygons, polygon_map, {}
        else:
            geometry = np.array(points)
            if cell_mask is not None:
                geometry = geometry[cell_mask]
            draw = point_map
            subplot_kw = {} if ab.dimension == 2 else {"projection": "3d"}

        # NOTE(review): titles below follow the caller's ``markers`` order; this
        # assumes iter_roi returns the rows of ``exp`` in that same order
        # (``markers_name`` is not checked) — confirm against AnalysisBase.
        for varray in exp:
            ax_index = int(next(ax_indexes_iter))
            ax = fig.add_subplot(nrow, ncol, ax_index, **subplot_kw)
            draw(geometry, values=varray, ax=ax,
                 cbar_kw=cbar_options, **plot_options)
            ax_at[ax_index] = ax

    # Column titles go on the top row, row labels on the left column. Axes are
    # looked up by subplot index, so a requested ROI that was not found simply
    # leaves its slots unlabeled instead of raising IndexError.
    roi_labels = [", ".join([str(i) for i in name]) for name in roi_names]
    if x_axis == "marker":
        col_titles, row_labels = markers, roi_labels
    else:
        col_titles, row_labels = roi_labels, markers
    for col, title in zip(range(ncol), col_titles):
        ax = ax_at.get(col + 1)
        if ax is not None:
            ax.set_title(title)
    for row, label in zip(range(nrow), row_labels):
        ax = ax_at.get(row * ncol + 1)
        if ax is None:
            continue
        # 3D axes need text2D for axes-relative text; 2D axes (including
        # polygon panels drawn for 3D data) only have text.
        text = getattr(ax, "text2D", ax.text)
        text(-0.01, 0.5, label, transform=ax.transAxes,
             fontdict=dict(rotation=90, va="center", ha="center"))
    fig.subplots_adjust(wspace=wspace, hspace=hspace)
    plt.close(fig)
    return fig