from __future__ import annotations
from collections import Counter
import numpy as np
import pandas as pd
import warnings
from anndata import AnnData
from ast import literal_eval
from functools import cached_property
from natsort import natsorted
from rich.progress import track
from spatialtis_core import reads_wkt_points
from time import time
from typing import Any, Dict, List, Optional, Union, Sequence
from spatialtis.config import Config, console
from spatialtis.utils import df2adata_uns, doc, log_print, pretty_time, read_exp, read_shapes, default_args
class NeighborsNotFoundError(Exception):
    """Raised when an analysis needs cell neighbors that are not stored yet."""
class CellTypeNotFoundError(Exception):
    """Raised when an analysis needs cell types but no cell type column is set."""
class Timer:
    """Base class that reports how long a task took.

    Call `start_timer` when the task begins and `stop_timer` when it ends.
    Setting `task_name` also derives a human readable `display_name` from it.
    """

    _task_name: Optional[str] = None
    display_name: str
    method: Optional[str] = None
    # The three attributes below are only annotated here; they are created
    # on the instance by `start_timer` / `stop_timer`.
    start_time: float
    end_time: float
    used: str

    def start_timer(self) -> None:
        """Log the task banner (and the method, if any) and record the start time."""
        log_print(
            f":hourglass_not_done: [green]{self.display_name}[/green]", custom=True
        )
        if self.method is not None:
            log_print(f":hammer_and_wrench: Method: {self.method}")
        self.start_time = time()

    def stop_timer(self) -> None:
        """Record the end time and log the elapsed time.

        Does nothing if `start_timer` was never called: `start_time` is not
        set in that case and there is no elapsed time to report.
        """
        if getattr(self, "start_time", None) is None:
            return
        self.end_time = time()
        self.used = pretty_time(self.end_time - self.start_time)
        log_print(f":stopwatch: [bold cyan]{self.used}[/bold cyan]", custom=True)

    @property
    def task_name(self):
        """The raw task name, or None if it has not been set."""
        return self._task_name

    @task_name.setter
    def task_name(self, v: str):
        self._task_name = v
        # "cell_type_analysis" -> "Cell type analysis"
        self.display_name = v.replace("_", " ").capitalize()
def neighbors_pairs(
    labels: List[int], neighbors: List[List[int]], duplicates: bool = False
):
    """Flatten per-cell neighbor lists into an array of (label, neighbor) pairs.

    Parameters
    ----------
    labels : list of int
        The label of each cell.
    neighbors : list of list of int
        The neighbor labels of each cell, in the same order as `labels`.
    duplicates : bool, default: False
        If True, every (label, neighbor) combination is kept.
        If False, only combinations where the neighbor is greater than
        the label are kept.

    Returns
    -------
    numpy.ndarray
        An int32 array with two rows: the first row holds the labels, the
        second row the paired neighbors.

    """
    pairs = [
        (label, nb)
        for label, nbs in zip(labels, neighbors)
        for nb in nbs
        if duplicates or nb > label
    ]
    sources = [src for src, _ in pairs]
    targets = [dst for _, dst in pairs]
    return np.array([sources, targets], dtype=np.int32)
@doc
class AnalysisBase(Timer):
    """The base class for all analysis functions

    All parameters that apply to this class can be used in any analysis

    Parameters
    ----------
    data : {adata}
    method : str, default: None
        The method used in the run of the analysis.
    export_key : {export_key}
    display_name : str, default: None
        The name used to display the analysis.
    mp : bool, default: Config.mp
        Enable parallel processing, no effect since v0.5.0.
    exp_obs : {exp_obs}
    roi_key : {roi_key}
    cell_type_key : {cell_type_key}
    centroid_key : {centroid_key}
    shape_key : {shape_key}
    marker_key : {marker_key}
    verbose : bool, default: True
        Whether to start the timer (which logs the task name) on creation.

    """

    data: AnnData
    exp_obs: List[str]
    task_name: str
    export_key: str
    mp: bool
    _result: Optional[pd.DataFrame] = None
    method: Optional[str] = None
    params: Optional[Dict] = None
    verbose: bool = True

    roi_key: str
    cell_type_key: str
    centroid_key: str
    marker_key: str
    shape_key: Optional[str]
    neighbors_key: str = "cell_neighbors"
    cell_id_key: str = "cell_id"
    area_key: str = "area"
    eccentricity_key: str = "eccentricity"

    def __repr__(self):
        # NOTE(review): the empty repr looks deliberate (presumably to avoid
        # echoing the object in interactive sessions) -- confirm before changing.
        return ""

    def __init__(
        self,
        data: AnnData,
        method: Optional[str] = None,
        exp_obs: Optional[List[str]] = None,
        roi_key: Optional[str] = None,
        export_key: Optional[str] = None,
        cell_type_key: Optional[str] = None,
        centroid_key: Union[str, Sequence[str], None] = None,
        shape_key: Optional[str] = None,
        marker_key: Optional[str] = None,
        mp: Optional[bool] = None,
        display_name: Optional[str] = None,
        verbose: bool = True
    ):
        self.data = data
        self.dimension = 2
        self.verbose = verbose
        # Setting task_name also derives a default display_name
        self.task_name = self.__class__.__name__
        if display_name is not None:
            self.display_name = display_name
        self.method = method

        # Explicit arguments win, otherwise fall back to the global Config
        self.cell_type_key = default_args(cell_type_key, Config.cell_type_key)
        self.centroid_key = default_args(centroid_key, Config.centroid_key)
        self.marker_key = default_args(marker_key, Config.marker_key)
        self.shape_key = default_args(shape_key, Config.shape_key)
        self.mp = default_args(mp, Config.mp)

        self.has_cell_type = (
            self.cell_type_key is not None
            and self.cell_type_key in self.data.obs_keys()
        )

        if exp_obs is None:
            exp_obs = Config.exp_obs
        if exp_obs is None:
            if roi_key is None:
                raise ValueError("Please set `Config.exp_obs`/`Config.roi_key` or pass `exp_obs=`/`roi_key=`")
            self.exp_obs = [roi_key]
        elif isinstance(exp_obs, (str, int, float)):
            self.exp_obs = [exp_obs]
        else:
            # Always work on a private copy: the list may be re-ordered below
            # and must not mutate the caller's list or the one held by Config.
            self.exp_obs = list(exp_obs)

        if roi_key is None:
            self.roi_key = self.exp_obs[-1]
        else:
            if roi_key not in self.exp_obs:
                raise ValueError("The `roi_key` is not in your `exp_obs`")
            # The ROI level is kept as the last entry of exp_obs
            if self.exp_obs[-1] != roi_key:
                self.exp_obs.remove(roi_key)
                self.exp_obs.append(roi_key)
            self.roi_key = roi_key

        # assign unique id to each cell, in case of someone cut the data afterwards
        # this ensures the analysis still work with non-integrated AnnData
        if self.cell_id_key not in data.obs_keys():
            data.obs[self.cell_id_key] = list(range(len(data.obs)))

        self.export_key = self.task_name if export_key is None else export_key

        if verbose:
            self.start_timer()

    @cached_property
    def markers(self):
        """Naturally sorted unique markers, from `marker_key` in `.var` or the `.var` index."""
        if self.marker_key is not None:
            return natsorted(pd.unique(self.data.var[self.marker_key]))
        else:
            return natsorted(pd.unique(self.data.var.index))

    def selected_markers(self, selected_markers=None):
        """Return the sorted unique `selected_markers`, or all markers if None."""
        if selected_markers is None:
            return self.markers
        else:
            return natsorted(pd.unique(selected_markers))

    @cached_property
    def markers_col(self):
        """The unsorted marker column: `.var[marker_key]` or the `.var` index."""
        if self.marker_key is not None:
            return self.data.var[self.marker_key]
        else:
            return self.data.var.index

    @cached_property
    def cell_types(self):
        """Naturally sorted unique cell types, or an empty list if there are none."""
        if self.has_cell_type:
            return natsorted(pd.unique(self.data.obs[self.cell_type_key]))
        else:
            return []

    def _get_wkt_points(self, key):
        """Parse the `.obs` column `key` as wkt points.

        Raises
        ------
        IOError
            If the column cannot be parsed; the original error is chained.

        """
        wkt_strings = self.data.obs[key].tolist()
        try:
            points = reads_wkt_points(wkt_strings)
        except Exception as err:
            raise IOError("If you have two columns, try `centroid_key=('cell_x', 'cell_y'). "
                          "If you store in one column, the centroid must be in wkt format, "
                          "try `spatialtis.transform_points`") from err
        return points

    def get_centroids(self) -> object | List:
        """Retrieve the centroid of every cell according to `centroid_key`.

        - `None`: read `.obsm['spatial']`, then the wkt column `.obs['centroid']`.
        - a `str`: read it as a wkt column in `.obs`, then as a key in `.obsm`.
        - a sequence of `str`: read those `.obs` columns as coordinates.

        Raises
        ------
        ValueError
            If the requested key(s) cannot be found.

        """
        ckey = self.centroid_key
        # determine the type of centroid
        # by default, read 'spatial' from .obsm
        if ckey is None:
            if 'spatial' in self.data.obsm_keys():
                return self.data.obsm['spatial'].tolist()
            if 'centroid' in self.data.obs_keys():
                return self._get_wkt_points('centroid')
            raise ValueError(
                "Spatial information not found, please set `Config.centroid_key` or pass `centroid_key=`.")

        if isinstance(ckey, str):
            if ckey in self.data.obs_keys():
                return self._get_wkt_points(ckey)
            if ckey in self.data.obsm_keys():
                return self.data.obsm[ckey].tolist()
            raise ValueError(f"The centroid key `{ckey}` not found in either `.obsm` or `.obs`")

        # A sequence of column names, one per axis
        obs_keys = self.data.obs_keys()
        if all(c in obs_keys for c in ckey):
            return self.data.obs[list(ckey)].to_numpy().tolist()
        raise ValueError(f"The centroid keys `{ckey}` not found in `.obs`")

    def iter_roi_data(self,
                      fields: List[str] = None,
                      ):
        """Build the table that `iter_roi` groups by ROI.

        Returns a copy of `.obs`, extended with a `__spatial_centroid` column
        if 'centroid' is requested and a `__cell_neighbors` column if
        'neighbors' is requested. Sets `self.dimension` to 3 when the
        centroids have three coordinates.
        """
        fields = [] if fields is None else fields
        iter_data = self.data.obs.copy()
        for f in fields:
            if f == 'centroid':
                points = self.get_centroids()
                # Guard against empty data before peeking at the first point
                if len(points) > 0 and len(points[0]) == 3:
                    self.dimension = 3
                iter_data['__spatial_centroid'] = points
            if f == 'neighbors':
                self.check_neighbors()
                iter_data['__cell_neighbors'] = self.data.obsm[self.neighbors_key]
        return iter_data

    def iter_roi(self,
                 fields: List[str] = None,
                 filter_rois: List[str] = None,
                 sort: bool = False,
                 desc: str = None,
                 disable_pbar: bool = None,
                 selected_markers: List = None,
                 layer_key: str = None,
                 dtype: Any = None,
                 ):
        """A generator to iterate ROI

        Parameters
        ----------
        fields : list of str, {'centroid', 'exp', 'neighbors', 'cell_type', 'shape', 'label', 'index'}
            What fields to retrieve when iterate ROI.
        filter_rois : list of str
            The roi to be filtered.
        sort : bool
            Whether to sort ROI.
        desc : str
            The description in the progress bar.
        disable_pbar : bool
            Whether to disable progress bar.
        selected_markers : list of str
            The list of markers to be selected.
        layer_key : str
            The layer to use for expression.
        dtype :
            The datatype.

        Yields
        ------
        list
            The ROI name followed by the requested fields, in the order given
            in `fields`. 'exp' contributes two items (markers, expression) and
            'neighbors' contributes two items (labels, neighbors).

        """
        desc = default_args(desc, self.display_name)
        disable_pbar = default_args(disable_pbar, not Config.progress_bar)
        fields = default_args(fields, [])

        iter_data = self.iter_roi_data(fields)
        if 'exp' in fields:
            selected_markers = default_args(selected_markers, self.markers)
            markers_mask = self.markers_col.isin(selected_markers)
            markers = self.markers_col[markers_mask].to_numpy()

        if filter_rois is not None:
            iter_data = iter_data[iter_data[self.roi_key].isin(filter_rois)].copy()

        single_level = len(self.exp_obs) == 1
        for roi_name, roi_data in track(
                iter_data.groupby(self.exp_obs, sort=sort),
                description=f"[green]{desc}",
                disable=disable_pbar,
                console=console,
        ):
            # pandas will show all categories in groupby even if there is no value
            if len(roi_data) == 0:
                continue
            if single_level:
                # Grouping by a one-element list gives a scalar key on older
                # pandas and a 1-tuple on newer pandas; normalise both to a
                # one-element list.
                if isinstance(roi_name, tuple):
                    roi_name = list(roi_name)
                else:
                    roi_name = [roi_name]

            yield_fields = [roi_name]
            for f in fields:
                if f == 'centroid':
                    yield_fields.append(roi_data['__spatial_centroid'].values.tolist())
                elif f == 'exp':
                    exp = read_exp(self.data[roi_data.index, markers_mask], layer_key=layer_key, dtype=dtype)
                    yield_fields.append(markers)  # ndarray
                    yield_fields.append(exp)  # ndarray
                elif f == 'neighbors':
                    # Neighbors are parsed from their literal string form
                    neighbors = [literal_eval(n) for n in roi_data['__cell_neighbors'].values]
                    labels = roi_data[self.cell_id_key].tolist()
                    yield_fields.append(labels)  # list
                    yield_fields.append(neighbors)  # list
                elif f == 'cell_type':
                    if self.has_cell_type:
                        yield_fields.append(roi_data[self.cell_type_key].to_numpy())
                    else:
                        yield_fields.append(None)
                elif f == 'shape':
                    if self.shape_key is None:
                        yield_fields.append(None)
                    else:
                        yield_fields.append(read_shapes(roi_data, self.shape_key))
                elif f == 'label':
                    yield_fields.append(roi_data[self.cell_id_key].tolist())
                elif f == 'index':
                    yield_fields.append(roi_data.index)
            yield yield_fields

    def type_counter(self) -> pd.DataFrame:
        """Count the cells of each cell type in every ROI.

        Returns
        -------
        pandas.DataFrame
            One row per ROI (indexed by `exp_obs`), one column per cell type.

        Raises
        ------
        CellTypeNotFoundError
            If no cell type column is available.

        """
        self.check_cell_type()
        matrix = []
        meta = []
        for roi_name, cell_types in self.iter_roi(fields=['cell_type'], disable_pbar=True):
            c = Counter(cell_types)
            matrix.append([c.get(t, 0) for t in self.cell_types])
            if isinstance(roi_name, (str, int, float)):
                meta.append((roi_name,))
            else:
                meta.append(tuple(roi_name))
        # Passing names lets pandas build the index even when there is no ROI
        index = pd.MultiIndex.from_tuples(meta, names=self.exp_obs)
        return pd.DataFrame(data=matrix, index=index, columns=self.cell_types)

    def export_result(self):
        """Write the result, together with the run parameters, to `data.uns[export_key]`."""
        export_params = {"exp_obs": self.exp_obs, "method": self.method}
        if self.params is not None:
            export_params = {**export_params, **self.params}
        df2adata_uns(self.result, self.data, self.export_key, params=export_params)

    @property
    def neighbors_exists(self) -> bool:
        """Whether `neighbors_key` is present in `.obsm`."""
        return self.neighbors_key in self.data.obsm_keys()

    def check_neighbors(self):
        """Raise NeighborsNotFoundError if neighbors have not been computed."""
        if not self.neighbors_exists:
            raise NeighborsNotFoundError("Neighbors not found! Run `spatialtis.find_neighbors` first.")

    def check_cell_type(self):
        """Raise CellTypeNotFoundError if no cell type column is available."""
        if not self.has_cell_type:
            raise CellTypeNotFoundError("Cell Type not found! Please set `cell_type_key`")

    def is_rois_name_unique(self, warn=True):
        """Check that the number of unique `roi_key` values is not smaller than
        the number of ROI found when grouping by `exp_obs`.

        Returns False (and warns if `warn` is True) when it is smaller.
        """
        key_len = len(self.data.obs[self.roi_key].unique())
        obs_len = sum(1 for _ in self.iter_roi(disable_pbar=True))
        not_unique = key_len < obs_len
        if not_unique and warn:
            msg = "ROI selection may be incorrect, " \
                  "ROI number determined by roi_keys " \
                  "is different from exp_obs, " \
                  "use `spatialtis.make_roi_unique()` to get unique roi"
            warnings.warn(msg)
        return not not_unique

    @property
    def result(self):
        """Return the result of the analysis"""
        return self._result

    @result.setter
    def result(self, v):
        self._result = v
        self.export_result()
        # The timer is only started when verbose=True, so `start_time` may
        # not exist; only stop a timer that was actually started.
        if getattr(self, "start_time", None) is not None:
            self.stop_timer()