import logging
from itertools import product
from pathlib import Path
from typing import Optional, Union, List, Tuple
from time import time
import pandas as pd
from anndata import AnnData
from spatialtis.abc import AnalysisBase
from spatialtis.utils import read_neighbors, log_print, doc, pbar_iter
from .preprocess import overlap_genes, train_test_split, neighbors_pairs, \
graph_data_loader, predict_data_loader
# Key under which the trained model's state_dict is persisted in ``AnnData.uns``
# (written after training, read back when ``load_model=True``).
MODEL_SAVE_KEY = "SpatialTis-GCNG-Model-State"
@doc
class GCNG(AnalysisBase):
    """A pytorch reimplementation of GCNG.

    Used to identify directional gene-gene interactions. The trained model is
    automatically saved into the anndata object (``uns[MODEL_SAVE_KEY]``).

    .. note::
        To perform this analysis, you need `pytorch <https://pytorch.org/get-started/locally/#start-locally>`_,
        `pytorch-geometry <https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html>`_ and
        `pytorch-lightning <https://www.pytorchlightning.ai/>`_ installed. If you have a GPU, make sure you
        install pytorch with GPU support; it is much faster than CPU.

    Args:
        data: {adata}
        known_pairs: The input data for training, should be a dataframe with three columns:
            ligand, receptor, relationship; 0 means not interact, 1 means interact.
        predict_pairs: The pairs that you are interested in
        train_partition: The ratio to split the dataset for training
        gpus: Number of gpu to use, can be auto-detected
        max_epochs: Number of epochs
        lr: Learning rate
        batch_size: The batch size
        random_seed: The random seed
        load_model: To load a pretrained model from anndata
        **kwargs: {analysis_kwargs}

    Attributes:
        model: the trained :class:`GCNGModel` (LightningModule)
        trainer: the pytorch-lightning trainer used for fit/predict
        result: DataFrame with columns ``Gene1``, ``Gene2``, ``relationship``

    Raises:
        ImportError: if the optional torch stack is not installed.
        NotImplementedError: if ``known_pairs`` is not supplied.
        ValueError: if ``predict_pairs`` is missing/too short, if the known
            pairs share no genes with the data, or if ``load_model=True`` but
            no saved model is found.
    """

    def __init__(
        self,
        data: AnnData,
        known_pairs: Optional[pd.DataFrame] = None,
        predict_pairs: Optional[List[Tuple]] = None,
        train_partition: float = 0.9,
        gpus: Optional[int] = None,
        max_epochs: int = 10,
        lr: float = 1e-4,
        batch_size: int = 32,
        random_seed: int = 42,
        load_model: bool = False,
        **kwargs,
    ):
        # Import the torch stack lazily so it remains an optional dependency
        # of spatialtis; fail with an actionable message otherwise.
        try:
            import torch
            import torch.nn.functional as F
            from torch.nn import Flatten, Linear
            from torch_geometric.nn import GCNConv, global_max_pool
            import pytorch_lightning as pl
            from pytorch_lightning.core.lightning import LightningModule
        except ImportError:
            raise ImportError("To run GCNG, please install pytorch, pytorch-lightning, "
                              "torch-geometric, torch_sparse and torch_scatter.")

        # --- input validation ---------------------------------------------
        if known_pairs is None:
            raise NotImplementedError("Currently, you need to supply the training pairs yourself")
        if predict_pairs is None:
            # NOTE: fixed missing space at the string-concatenation boundary
            # and the "must specific" typo in the original message.
            raise ValueError("To run the model, you must specify the `predict_pairs` "
                             "and tell spatialtis the ligand-receptor pairs you want to predict.")
        if len(predict_pairs) < batch_size:
            # The model emits exactly `batch_size` scores per step, so the
            # prediction set must fill at least one batch.
            raise ValueError("The predict_pairs must be longer than batch size")

        super().__init__(data, display_name="GCNG", **kwargs)

        # --- device selection ---------------------------------------------
        device = "cpu"
        if gpus is None:
            gpus = torch.cuda.device_count()  # auto-detect
        if gpus > 0:
            device = "cuda"

        # The model class is defined inside __init__ on purpose: it closes
        # over the lazily-imported torch names, keeping pytorch optional.
        class GCNGModel(LightningModule):
            """Two GCN layers + two dense layers; one sigmoid score per pair."""

            def __init__(self, node_size, output_features, lr=lr):
                super().__init__()
                self.conv1 = GCNConv(2, 32)
                self.conv2 = GCNConv(32, 32)
                # conv output is one 32-vector per node per pair in the batch
                self.dense1 = Linear(output_features * node_size * 32, 512)
                self.dense2 = Linear(512, output_features)
                # unused in forward (torch.flatten is used instead); kept so the
                # module structure matches previously saved models
                self.flatten = Flatten()
                self.lr = lr
                # running test statistics, filled in by test_step
                self.correct = 0
                self.test_data_len = 0
                self.acc = 0
                # last batch of rounded predictions, read back after trainer.predict
                self.pred = []

            def forward(self, x, edge_index):
                x = self.conv1(x, edge_index)
                x = F.elu(x)
                x = self.conv2(x, edge_index)
                x = F.elu(x)
                x = torch.flatten(x)
                x = self.dense1(x)
                x = F.elu(x)
                x = self.dense2(x)
                # per-pair interaction probability
                return torch.sigmoid(x)

            def configure_optimizers(self):
                return torch.optim.Adam(self.parameters(), lr=self.lr)

            def training_step(self, train_data, batch_idx):
                x, edge_index, batch = train_data.x, train_data.edge_index, train_data.batch
                x = self(x, edge_index)
                loss_in = x.flatten()
                loss_out = train_data.y
                return F.binary_cross_entropy(loss_in, loss_out)

            def test_step(self, test_data, batch_idx):
                x, edge_index, batch = test_data.x, test_data.edge_index, test_data.batch
                x = self(x, edge_index)
                # round() thresholds the sigmoid output at 0.5
                pred = x.detach().cpu().numpy().flatten().round()
                truth_y = test_data.y.cpu().numpy()
                self.correct += (pred == truth_y).sum()
                self.test_data_len += len(test_data.y)
                self.acc = self.correct / self.test_data_len
                return self.acc

            def predict_step(self, predict_data, batch_idx):
                x, edge_index, batch = predict_data.x, predict_data.edge_index, predict_data.batch
                x = self(x, edge_index)
                self.pred = x.detach().cpu().numpy().flatten().round().tolist()
                return self.pred

            def release_gpu_mem(self):
                # best-effort: empty_cache may fail on CPU-only builds
                try:
                    torch.cuda.empty_cache()
                except Exception:
                    pass

            def on_train_end(self, *args, **kwargs):
                self.release_gpu_mem()

            def on_predict_batch_end(self, *args, **kwargs):
                self.release_gpu_mem()

        # --- model and trainer --------------------------------------------
        # output dimension is tied to batch_size: one score per pair per batch
        model = GCNGModel(data.n_obs, batch_size, lr=lr)
        pl.seed_everything(random_seed, workers=True)
        trainer = pl.Trainer(gpus=gpus,
                             max_epochs=max_epochs,
                             deterministic=True,
                             progress_bar_refresh_rate=0,
                             weights_summary=None,
                             precision=16)

        # graph edges from the precomputed spatial neighbors
        npairs = neighbors_pairs(data.obs[self.cell_id_key],
                                 read_neighbors(data.obs, self.neighbors_key))

        # expression matrix and a marker-name -> row-index mapper;
        # marker names are lower-cased for case-insensitive matching
        exp = data.X.T
        markers = data.var.markers
        markers = pd.Series([i.lower() for i in markers], index=markers.index)
        markers_mapper = dict(zip(markers.tolist(), range(len(markers))))

        if load_model:  # load pre-trained model
            try:
                state = self.data.uns[MODEL_SAVE_KEY]
            except KeyError:
                raise ValueError("Pre-trained model not found, please retrain the model")
            model.load_state_dict(state)
        else:
            # keep only known pairs whose both genes exist in the data
            lr_genes = known_pairs.iloc[:, [0, 1]].to_numpy().flatten()
            overlap_sets = overlap_genes(self.markers, lr_genes)
            filtered_pairs = known_pairs[known_pairs.iloc[:, 0].isin(overlap_sets) &
                                         known_pairs.iloc[:, 1].isin(overlap_sets)].iloc[:, [0, 1, 2]]
            if len(filtered_pairs) == 0:
                raise ValueError("The gene in `known_pairs` has no overlap with genes in data")

            # train the model
            train, test = train_test_split(filtered_pairs, train_partition)
            log_print(f"Training set: {len(train)}, Test set: {len(test)}")
            train_loader = graph_data_loader(train, exp, markers_mapper, npairs, device, batch_size, shuffle=True)
            test_loader = graph_data_loader(test, exp, markers_mapper, npairs, device, batch_size, shuffle=False)
            log_print("Finish loading data, start training")
            trainer.fit(model, train_loader)
            trainer.test(dataloaders=test_loader, verbose=False)
            log_print(f"Model accuracy {model.acc}")
            self.data.uns[MODEL_SAVE_KEY] = model.state_dict()  # save model

        self.model = model  # allow user to access model
        self.trainer = trainer

        # --- prediction ----------------------------------------------------
        # The model output size is fixed to batch_size, so the pairs are padded
        # to a multiple of batch_size. `(-n) % b` is 0 when n is already a
        # multiple of b (the original `b - n % b` padded a whole superfluous
        # batch in that case, making the padded frame longer than the collected
        # predictions and crashing the column assignment below). Padding a
        # *copy* also avoids mutating the caller's `predict_pairs` list.
        predict_size = len(predict_pairs)
        append_amount = (-predict_size) % batch_size
        padded_pairs = list(predict_pairs) + list(predict_pairs[:append_amount])
        predict = pd.DataFrame(padded_pairs)

        pred = []
        for i in pbar_iter(range(0, predict_size, batch_size), desc="Fetching predict result"):
            batch_pairs = pd.DataFrame(padded_pairs[i: i + batch_size])
            predict_loader = predict_data_loader(batch_pairs, exp, markers_mapper, npairs, device, batch_size)
            trainer.predict(dataloaders=predict_loader)
            pred += model.pred

        # completely release gpu memory when finished
        model.release_gpu_mem()
        predict['relationship'] = pred
        predict.columns = ['Gene1', 'Gene2', 'relationship']
        # drop the padding rows before exposing the result
        self.result = predict.iloc[:predict_size, :].copy()