Source code for spatialmeta.model._moe

# Pytorch
from collections import Counter
import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.distributions import kl_divergence as kld


# Third Party
import numpy as np
import scanpy as sc
import scipy

# Built-in
from anndata import AnnData
from scipy.sparse import issparse

from copy import deepcopy
import json
from typing import Callable, Mapping, Union, Iterable, Tuple, Optional, Mapping
import os
import warnings


# Package
from ._primitives import *
from ..util.loss import LossFunction
from ..util.logger import get_tqdm
from ..util._classes import AnnDataSM, AnnDataST, AnnDataJointSMST

from ._model import ConditionalVAE

def get_k_elements(arr: Iterable, k:int):
    return list(map(lambda x: x[k], arr))

def get_last_k_elements(arr: Iterable, k:int):
    return list(map(lambda x: x[k:], arr))

def get_elements(arr: Iterable, a:int, b:int):
    return list(map(lambda x: x[a:a+b], arr))

[docs]class ConditionalVAESTSM(ConditionalVAE): """ This class implements a Conditional Variational Autoencoder with Mixture of Experts (MoE) for vertical and horizontal integration of ST and SM. :param adata: AnnDataJointSMST object containing the spatial multi-omics data. :param hidden_stacks: List of integers specifying the number of hidden units in each stack of the encoder and decoder, default is [128]. :param batch_keys: Optional list of strings specifying the batch keys for batch correction. :param n_latent: Integer specifying the dimensionality of the latent space, default is 10. :param bias: Boolean indicating whether to include bias terms in the linear layers, default is True. :param use_batch_norm: Boolean indicating whether to use batch normalization in the linear layers, default is True. :param use_layer_norm: Boolean indicating whether to use layer normalization in the linear layers, default is False. :param dropout_rate: Float specifying the dropout rate for the linear layers, default is 0.1. :param activation_fn: Callable specifying the activation function to use in the linear layers, default is nn.ReLU. :param device: String or torch.device specifying the device to use for computation, default is "cpu". :param batch_embedding: Literal["embedding", "onehot"] specifying the type of batch embedding to use, default is "onehot". :param encode_libsize: Boolean indicating whether to encode library size information, default is False. :param batch_hidden_dim: Integer specifying the dimensionality of the batch hidden layer, default is 8. :param reconstruction_method_st: Literal['mse', 'zg', 'zinb'] specifying the reconstruction method for the spatial data, default is 'zinb'. :param reconstruction_method_sm: Literal['mse', 'zg', 'g'] specifying the reconstruction method for the single-cell multi-omics data, default is 'g'. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.v = nn.Parameter( torch.rand(self.n_hidden,device=self.device) ) self.to_k = nn.Sequential( nn.Linear(self.n_hidden, self.n_hidden), nn.Tanh() ).to(self.device) self.decoder_st = FCLayer( in_dim = self.n_latent, out_dim = self.n_hidden, n_cat_list = self.n_batch_keys, cat_dim = self.batch_hidden_dim, cat_embedding = self.batch_embedding, use_layer_norm=False, use_batch_norm=True, dropout_rate=0, device=self.device ) self.decoder_sm = FCLayer( in_dim = self.n_latent, out_dim = self.n_hidden, n_cat_list = self.n_batch_keys, cat_dim = self.batch_hidden_dim, cat_embedding = self.batch_embedding, use_layer_norm=False, use_batch_norm=True, dropout_rate=0, device=self.device ) self.to(self.device) def get_latent_from_z(self, z_st, z_sm): z = 0.5 * (z_st + z_sm) return z def encode( self, X: torch.Tensor, eps: float = 1e-4 ) -> Tuple[torch.Tensor, torch.Tensor]: X_SM = X[:, self._type == "SM"] X_ST = X[:, self._type == "ST"] X_ST = torch.log(X_ST + 1) q_sm = self.encoder_SM.encode(X_SM) q_st = self.encoder_ST.encode(X_ST) q = torch.hstack((q_sm,q_st)) mu_sm = self.z_mean_fc_single(q_sm) mu_st = self.z_mean_fc_single(q_st) q_mu = self.z_mean_fc(q) #var_sm = torch.exp(self.z_var_fc_single(q_sm)) + eps #var_st = torch.exp(self.z_var_fc_single(q_st)) +eps q_var = torch.exp(self.z_var_fc(q)) + eps #z_sm = Normal(mu_sm, var_sm.sqrt()).rsample() #z_st = Normal(mu_st, var_st.sqrt()).rsample() z = Normal(q_mu, q_var.sqrt()).rsample() H = dict( st = dict( q = q_st, q_mu = mu_st, #q_var = var_st, #z = z_st ), sm = dict( q = q_sm, q_mu = mu_sm, #q_var = var_sm, #z = z_sm ), q = q, q_mu = q_mu, q_var = q_var, z = z ) return H def decode(self, H: Mapping[str, torch.tensor], lib_size: torch.tensor, batch_index: Optional[torch.Tensor] = None, ) -> torch.Tensor: z_st = H["st"]["q_mu"] z_sm = H["sm"]["q_mu"] z = H["z"] if batch_index is not None: z_st = torch.hstack([ z_st, batch_index ]) z_sm = torch.hstack([ z_sm, batch_index ]) z = torch.hstack([ z, batch_index ]) R = [] px_st = self.decoder_st(z.to(self.device)) px_sm = self.decoder_sm(z.to(self.device)) px_st_corr = self.decoder_st(z_st.to(self.device)) px_sm_corr = self.decoder_sm(z_sm.to(self.device)) px_rna_scale = self.px_rna_scale_decoder(px_st) px_rna_rate = self.px_rna_rate_decoder(px_st) px_rna_dropout = self.px_rna_dropout_decoder(px_st) ## In logits px_rna_scale = px_rna_scale * lib_size.unsqueeze(1) px_sm_scale = self.px_sm_scale_decoder(px_sm) #px_sm_scale = F.softmax(px_sm_scale, dim=1) * 1e4 px_sm_rate = self.px_sm_rate_decoder(px_sm) px_sm_dropout = self.px_sm_dropout_decoder(px_sm) ## In logits px_rna_corr_scale = self.px_rna_scale_decoder(px_st_corr) px_rna_corr_rate = self.px_rna_rate_decoder(px_st_corr) px_rna_corr_dropout = self.px_rna_dropout_decoder(px_st_corr) ## In logits px_rna_corr_scale = px_rna_corr_scale * lib_size.unsqueeze(1) px_sm_corr_scale = self.px_sm_scale_decoder(px_sm_corr) #px_sm_corr_scale = F.softmax(px_sm_corr_scale, dim=1) * 1e4 px_sm_corr_rate = self.px_sm_rate_decoder(px_sm_corr) px_sm_corr_dropout = self.px_sm_dropout_decoder(px_sm_corr) ## In logits R = dict( latent = dict( px_st = px_st, px_sm = px_sm, px_rna_scale = px_rna_scale, px_rna_rate = px_rna_rate, px_rna_dropout = px_rna_dropout, px_sm_scale = px_sm_scale, px_sm_rate = px_sm_rate, px_sm_dropout = px_sm_dropout ), corr = dict( px_st = px_st_corr, px_sm = px_sm_corr, px_rna_scale = px_rna_corr_scale, px_rna_rate = px_rna_corr_rate, px_rna_dropout = px_rna_corr_dropout, px_sm_scale = px_sm_corr_scale, px_sm_rate = px_sm_corr_rate, px_sm_dropout = px_sm_corr_dropout ) ) return R def forward( self, X: torch.Tensor, batch_index: Optional[torch.Tensor] = None, reduction: str = "sum", ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: H=self.encode(X) mu_st = H["st"]["q_mu"] #var_st = H["st"]["q_var"] mu_sm = H["sm"]["q_mu"] #var_sm = H["sm"]["q_var"] mean_st = torch.zeros_like(mu_st) #scale_st = torch.ones_like(var_st) mean_sm = torch.zeros_like(mu_sm) #scale_sm = torch.ones_like(var_sm) q_mu = H["q_mu"] q_var = H["q_var"] mean = torch.zeros_like(q_mu) scale = torch.ones_like(q_var) if batch_index is not None: mmd_loss = LossFunction.mmd_loss( z = H['q_mu'], cat = batch_index.detach().cpu().numpy(), dim=1, ) else: mmd_loss = torch.tensor(0.0, device=self.device) kldiv_loss = kld(Normal(q_mu, q_var.sqrt()),Normal(mean, scale)).sum(dim = 1) #kldiv_loss_st = kld(Normal(mu_st, var_st.sqrt()),Normal(mean_st, scale_st)).sum(dim = 1) #kldiv_loss_sm = kld(Normal(mu_sm, var_sm.sqrt()),Normal(mean_sm, scale_sm)).sum(dim = 1) X_SM = X[:,self._type=="SM"] X_ST = X[:,self._type=="ST"] Rs=self.decode(H, X_ST.sum(1), batch_index) R_latent = Rs['latent'] R_corr = Rs['corr'] reconstruction_loss_st = 0 reconstruction_loss_sm = 0 reconstruction_loss_st_corr = 0 reconstruction_loss_sm_corr = 0 if self.reconstruction_method_st == 'zinb': reconstruction_loss_st = LossFunction.zinb_reconstruction_loss( X_ST, mu = R_latent['px_rna_scale'], theta = R_latent['px_rna_rate'].exp(), gate_logits = R_latent['px_rna_dropout'], reduction = reduction ) reconstruction_loss_st_corr = LossFunction.zinb_reconstruction_loss( X_ST, mu = R_corr['px_rna_scale'], theta = R_corr['px_rna_rate'].exp(), gate_logits = R_corr['px_rna_dropout'], reduction = reduction ) elif self.reconstruction_method_st == 'zg': reconstruction_loss_st = LossFunction.zi_gaussian_reconstruction_loss( X_ST, mean=R_latent['px_rna_scale'], variance=R_latent['px_rna_rate'].exp(), gate_logits=R_latent['px_rna_dropout'], reduction=reduction ) reconstruction_loss_st_corr = LossFunction.zi_gaussian_reconstruction_loss( X_ST, mean=R_corr['px_rna_scale'], variance=R_corr['px_rna_rate'].exp(), gate_logits=R_corr['px_rna_dropout'], reduction=reduction ) elif self.reconstruction_method_st == 'mse': reconstruction_loss_st = nn.functional.mse_loss( R_latent['px_rna_scale'], X_ST, reduction=reduction ) reconstruction_loss_st_corr = nn.functional.mse_loss( R_corr['px_rna_scale'], X_ST, reduction=reduction ) if self.reconstruction_method_sm == 'zg': reconstruction_loss_sm = LossFunction.zi_gaussian_reconstruction_loss( X_SM, mean = R_latent['px_sm_scale'], variance = R_latent['px_sm_rate'].exp(), gate_logits = R_latent['px_sm_dropout'], reduction = reduction ) reconstruction_loss_sm_corr = LossFunction.zi_gaussian_reconstruction_loss( X_SM, mean = R_corr['px_sm_scale'], variance = R_corr['px_sm_rate'].exp(), gate_logits = R_corr['px_sm_dropout'], reduction = reduction ) elif self.reconstruction_method_sm == 'mse': reconstruction_loss_sm = nn.MSELoss()( R_latent['px_sm_scale'], X_SM, ) reconstruction_loss_sm_corr = nn.MSELoss()( R_corr['px_sm_scale'], X_SM, ) elif self.reconstruction_method_sm == "g": reconstruction_loss_sm = LossFunction.gaussian_reconstruction_loss( X_SM, mean = R_latent['px_sm_scale'], variance = R_latent['px_sm_rate'].exp(), reduction = reduction ) reconstruction_loss_sm_corr = LossFunction.gaussian_reconstruction_loss( X_SM, mean = R_corr['px_sm_scale'], variance = R_corr['px_sm_rate'].exp(), reduction = reduction ) loss_record = { "reconstruction_loss_sm": reconstruction_loss_sm, "reconstruction_loss_st": reconstruction_loss_st, "reconstruction_loss_sm_corr": reconstruction_loss_sm_corr, "reconstruction_loss_st_corr": reconstruction_loss_st_corr, "kldiv_loss": kldiv_loss, #"kldiv_loss_st":kldiv_loss_st, #"kldiv_loss_sm": kldiv_loss_sm, "mmd_loss": mmd_loss } return H, Rs, loss_record def fit( self, max_epoch:int = 35, n_per_batch:int = 128, mode: Optional[Literal['single','multi']] = None, **kwargs ): """ Fits the model. :param max_epoch: Integer specifying the maximum number of epochs to train the model, default is 35. :param n_per_batch: Integer specifying the number of samples per batch, default is 128. :param mode: Optional string specifying the mode of training. Can be either 'single' or 'multi', default is None. :param reconstruction_reduction: String specifying the reduction method for the reconstruction loss, default is 'sum'. :param kl_weight: Float specifying the weight of the KL divergence loss, default is 1. :param reconstruction_st_weight: Float specifying the weight of the reconstruction loss for spatial transcriptomics, default is 1. :param reconstruction_sm_weight: Float specifying the weight of the reconstruction loss for single-cell multi-omics, default is 1. :param reconstruction_st_corr_weight: Float specifying the weight of the correlation reconstruction loss for spatial transcriptomics, default is 1. :param reconstruction_sm_corr_weight: Float specifying the weight of the correlation reconstruction loss for single-cell multi-omics, default is 1. :param n_epochs_kl_warmup: Integer specifying the number of epochs for KL divergence warmup, default is 400. :param optimizer_parameters: Iterable specifying the parameters for the optimizer, default is None. :param weight_decay: Float specifying the weight decay for the optimizer, default is 1e-6. :param lr: Float specifying the learning rate for the optimizer. :param random_seed: Integer specifying the random seed, default is 12. :param kl_loss_reduction: String specifying the reduction method for the KL divergence loss, default is 'mean'. :param mmd_weight: Float specifying the weight of the MMD loss, default is 1. :return: Dictionary containing the training loss values. """ if mode == 'single': kwargs['reconstruction_st_weight'] = 5 kwargs['reconstruction_sm_weight'] = 1 kwargs['reconstruction_st_corr_weight'] = 5 kwargs['reconstruction_sm_corr_weight'] = 1 kwargs['kl_weight'] = 0.5 elif mode == 'multi': kwargs['reconstruction_st_weight'] = 8 kwargs['reconstruction_sm_weight'] = 2 kwargs['reconstruction_st_corr_weight'] = 8 kwargs['reconstruction_sm_corr_weight'] = 2 kwargs['kl_weight'] = 1 kwargs['mmd_weight'] = 10 return self.fit_core( max_epoch=max_epoch, n_per_batch=n_per_batch, **kwargs ) def fit_core(self, max_epoch:int = 35, n_per_batch:int = 128, reconstruction_reduction: str = 'sum', kl_weight: float = 1., reconstruction_st_weight: float = 1., reconstruction_sm_weight: float = 1., reconstruction_st_corr_weight: float = 1., reconstruction_sm_corr_weight: float = 1., n_epochs_kl_warmup: Union[int, None] = 400, optimizer_parameters: Iterable = None, weight_decay: float = 1e-6, lr: bool = 5e-5, random_seed: int = 12, kl_loss_reduction: str = 'mean', mmd_weight: float = 1., ): self.train() if n_epochs_kl_warmup: n_epochs_kl_warmup = min(max_epoch, n_epochs_kl_warmup) kl_warmup_gradient = kl_weight / n_epochs_kl_warmup kl_weight_max = kl_weight kl_weight = 0. if optimizer_parameters is None: optimizer = optim.AdamW(self.parameters(), lr, weight_decay=weight_decay) else: optimizer = optim.AdamW(optimizer_parameters, lr, weight_decay=weight_decay) pbar = get_tqdm()(range(max_epoch), desc="Epoch", bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') loss_record = { "reconstruction_loss_sm": 0, "reconstruction_loss_st": 0, "reconstruction_loss_sm_corr": 0, "reconstruction_loss_st_corr": 0, "kldiv_loss": 0, #"kldiv_loss_st": 0, #"kldiv_loss_sm": 0, "total_loss_sm": 0, "total_loss_st": 0, "mmd_loss": 0 } epoch_reconstruction_loss_st_list = [] epoch_reconstruction_loss_sm_list = [] epoch_reconstruction_loss_st_corr_list = [] epoch_reconstruction_loss_sm_corr_list = [] epoch_kldiv_loss_list = [] #epoch_kldiv_loss_st_list = [] #epoch_kldiv_loss_sm_list = [] epoch_total_loss_list = [] epoch_mmd_loss_list = [] #epoch_sm_gate_logits_list = [] for epoch in range(1, max_epoch+1): self._trained = True pbar.desc = "Epoch {}".format(epoch) epoch_total_loss = 0 epoch_reconstruction_loss_sm = 0 epoch_reconstruction_loss_st = 0 epoch_reconstruction_loss_sm_corr = 0 epoch_reconstruction_loss_st_corr = 0 epoch_kldiv_loss = 0 #epoch_kldiv_loss_st = 0 #epoch_kldiv_loss_sm = 0 epoch_mmd_loss = 0 #epoch_sm_gate_logits = [] X_train = self.as_dataloader( batch_size=n_per_batch, shuffle=True ) for batch_idx in X_train: indices = batch_idx[0].cpu().numpy() X_batch = [] for idx in indices: if scipy.sparse.issparse(self.X): x_row = self.X.getrow(idx).toarray().squeeze() else: x_row = self.X[idx] X_batch.append(x_row) X_batch = torch.tensor(np.stack(X_batch), dtype=torch.float32).to(self.device) if self.batch_codes is not None: batch_index = [ torch.tensor(code[indices], dtype=torch.long).unsqueeze(1).to(self.device) for code in self.batch_codes ] batch_index = torch.hstack(batch_index) else: batch_index = None H, Rs, L = self.forward( X_batch, batch_index=batch_index, reduction=reconstruction_reduction, ) #for R in Rs: # epoch_sm_gate_logits.append( # R['px_sm_dropout'].detach().cpu().numpy() # ) reconstruction_loss_st = L['reconstruction_loss_st'] reconstruction_loss_sm = L['reconstruction_loss_sm'] reconstruction_loss_st_corr = L['reconstruction_loss_st_corr'] reconstruction_loss_sm_corr = L['reconstruction_loss_sm_corr'] kldiv_loss = L['kldiv_loss'] #kldiv_loss_st = L['kldiv_loss_st'] #kldiv_loss_sm = L['kldiv_loss_sm'] mmd_loss = L['mmd_loss'] #loss = 1*reconstruction_loss_sm.mean() + 0.5*reconstruction_loss_st.mean() + kldiv_loss.mean() avg_reconstruction_loss_st = reconstruction_loss_st.mean() / n_per_batch avg_reconstruction_loss_sm = reconstruction_loss_sm.mean() / n_per_batch avg_reconstruction_loss_st_corr = reconstruction_loss_st_corr.mean() / n_per_batch avg_reconstruction_loss_sm_corr = reconstruction_loss_sm_corr.mean() / n_per_batch avg_mmd_loss = mmd_loss.mean() / n_per_batch if kl_loss_reduction == 'mean': avg_kldiv_loss = kldiv_loss.mean() / n_per_batch #avg_kldiv_loss_st = kldiv_loss_st.mean() / n_per_batch #avg_kldiv_loss_sm = kldiv_loss_sm.mean() / n_per_batch elif kl_loss_reduction == 'sum': avg_kldiv_loss = kldiv_loss.sum() / n_per_batch #avg_kldiv_loss_st = kldiv_loss_st.sum() / n_per_batch #avg_kldiv_loss_sm = kldiv_loss_sm.sum() / n_per_batch loss = (avg_reconstruction_loss_sm * reconstruction_sm_weight) + \ (avg_reconstruction_loss_st * reconstruction_st_weight) + \ (avg_reconstruction_loss_sm_corr * reconstruction_sm_corr_weight) + \ (avg_reconstruction_loss_st_corr * reconstruction_st_corr_weight) + \ (avg_kldiv_loss * kl_weight) + \ (avg_mmd_loss * mmd_weight) #+ \ #(avg_kldiv_loss_sm * kl_weight_sm) + \ #(avg_kldiv_loss_st * kl_weight_st) #loss = avg_reconstruction_loss_sm + avg_reconstruction_loss_st + (avg_kldiv_loss * kl_weight) + avg_mmd_loss epoch_reconstruction_loss_sm += avg_reconstruction_loss_sm.item() epoch_reconstruction_loss_st += avg_reconstruction_loss_st.item() epoch_reconstruction_loss_sm_corr += avg_reconstruction_loss_sm_corr.item() epoch_reconstruction_loss_st_corr += avg_reconstruction_loss_st_corr.item() epoch_mmd_loss += avg_mmd_loss.item() epoch_kldiv_loss += avg_kldiv_loss.item() #epoch_kldiv_loss_st += avg_kldiv_loss_st.item() #epoch_kldiv_loss_sm += avg_kldiv_loss_sm.item() epoch_total_loss += loss.item() optimizer.zero_grad() loss.backward() optimizer.step() pbar.set_postfix({ 'reconst_sm': '{:.2e}'.format(epoch_reconstruction_loss_sm), 'reconst_st': '{:.2e}'.format(epoch_reconstruction_loss_st), 'reconst_sm_corr': '{:.2e}'.format(epoch_reconstruction_loss_sm_corr), 'reconst_st_corr': '{:.2e}'.format(epoch_reconstruction_loss_st_corr), 'kldiv': '{:.2e}'.format(epoch_kldiv_loss), #'kldiv_st': '{:.2e}'.format(epoch_kldiv_loss_st), #'kldiv_sm': '{:.2e}'.format(epoch_kldiv_loss_sm), 'total_loss': '{:.2e}'.format(epoch_total_loss), 'mmd_loss': '{:.2e}'.format(epoch_mmd_loss) }) pbar.update(1) epoch_reconstruction_loss_sm_list.append(epoch_reconstruction_loss_sm) epoch_reconstruction_loss_st_list.append(epoch_reconstruction_loss_st) epoch_reconstruction_loss_sm_corr_list.append(epoch_reconstruction_loss_sm_corr) epoch_reconstruction_loss_st_corr_list.append(epoch_reconstruction_loss_st_corr) epoch_kldiv_loss_list.append(epoch_kldiv_loss) #epoch_kldiv_loss_st_list.append(epoch_kldiv_loss_st) #epoch_kldiv_loss_sm_list.append(epoch_kldiv_loss_sm) epoch_total_loss_list.append(epoch_total_loss) epoch_mmd_loss_list.append(epoch_mmd_loss) #epoch_sm_gate_logits = np.vstack(epoch_sm_gate_logits) #epoch_sm_gate_logits_list.append(epoch_sm_gate_logits) if n_epochs_kl_warmup: kl_weight = min( kl_weight + kl_warmup_gradient, kl_weight_max) random_seed += 1 pbar.close() self.trained_state_dict = deepcopy(self.state_dict()) return dict( epoch_reconstruction_loss_st_list=epoch_reconstruction_loss_st_list, epoch_reconstruction_loss_sm_list=epoch_reconstruction_loss_sm_list, epoch_reconstruction_loss_st_corr_list=epoch_reconstruction_loss_st_corr_list, epoch_reconstruction_loss_sm_corr_list=epoch_reconstruction_loss_sm_corr_list, epoch_kldiv_loss_list=epoch_kldiv_loss_list, #epoch_kldiv_loss_st_list=epoch_kldiv_loss_st_list, #epoch_kldiv_loss_sm_list=epoch_kldiv_loss_sm_list, #epoch_sm_gate_logits_list=epoch_sm_gate_logits_list, epoch_total_loss_list=epoch_total_loss_list, epoch_mmd_loss_list=epoch_mmd_loss_list ) @torch.no_grad() def get_latent_embedding( self, latent_key: Literal["z", "q_mu"] = "q_mu", n_per_batch: int = 128, show_progress: bool = True, ) -> np.ndarray: """ Get the latent embedding of the data. :param latent_key: String specifying the key of the latent variable to return, default is "q_mu". :param n_per_batch: Integer specifying the number of samples per batch, default is 128. :param show_progress: Boolean indicating whether to show the progress bar, default is True. :return: Numpy array containing the latent embedding. """ self.eval() dataloader = self.as_dataloader(batch_size=n_per_batch, shuffle=False) Zs = [] Zs_st = [] Zs_sm = [] if show_progress: pbar = get_tqdm()(dataloader, desc="Latent Embedding", bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') for batch_idx in dataloader: indices = batch_idx[0].cpu().numpy() # from X obtain the batch data X_batch = [] for idx in indices: if scipy.sparse.issparse(self.X): x_row = self.X.getrow(idx).toarray().squeeze() else: x_row = self.X[idx] X_batch.append(x_row) X_batch = torch.tensor(np.stack(X_batch), dtype=torch.float32).to(self.device) # deal with batch codes if self.batch_codes is not None: batch_index = [ torch.tensor(code[indices], dtype=torch.long).unsqueeze(1).to(self.device) for code in self.batch_codes ] batch_index = torch.hstack(batch_index) else: batch_index = None H = self.encode(X_batch) Zs.append(H[latent_key].detach().cpu().numpy()) Zs_st.append(H['st'][latent_key].detach().cpu().numpy()) Zs_sm.append(H['sm'][latent_key].detach().cpu().numpy()) if show_progress: pbar.update(1) if show_progress: pbar.close() return np.vstack(Zs)#, np.vstack(Zs_st)[self._shuffle_indices], np.vstack(Zs_sm)[self._shuffle_indices] @torch.no_grad() def _get_latent_embedding( self, latent_key: Literal["z", "q_mu"] = "q_mu", n_per_batch: int = 128, show_progress: bool = True, ) -> np.ndarray: """ Get the latent embedding of the data. :param latent_key: String specifying the key of the latent variable to return, default is "q_mu". :param n_per_batch: Integer specifying the number of samples per batch, default is 128. :param show_progress: Boolean indicating whether to show the progress bar, default is True. :return: Numpy array containing the latent embedding. """ self.eval() dataloader = self.as_dataloader(batch_size=n_per_batch, shuffle=False) Zs = [] Zs_st = [] Zs_sm = [] if show_progress: pbar = get_tqdm()(dataloader, desc="Latent Embedding", bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') for batch_idx in dataloader: indices = batch_idx[0].cpu().numpy() # from X obtain the batch data X_batch = [] for idx in indices: if scipy.sparse.issparse(self.adata.X): x_row = self.adata.X.getrow(idx).toarray().squeeze() else: x_row = self.adata.X[idx] X_batch.append(x_row) X_batch = torch.tensor(np.stack(X_batch), dtype=torch.float32).to(self.device) # deal with batch codes if self.batch_codes is not None: batch_index = [ torch.tensor(code[indices], dtype=torch.long).unsqueeze(1).to(self.device) for code in self.batch_codes ] batch_index = torch.hstack(batch_index) else: batch_index = None H = self.encode(X_batch) Zs.append(H[latent_key].detach().cpu().numpy()) Zs_st.append(H['st'][latent_key].detach().cpu().numpy()) Zs_sm.append(H['sm'][latent_key].detach().cpu().numpy()) if show_progress: pbar.update(1) if show_progress: pbar.close() Zs_all = np.vstack(Zs) Zs_st_all = np.vstack(Zs_st) Zs_sm_all = np.vstack(Zs_sm) if hasattr(self, '_shuffle_indices'): Zs_all = Zs_all[self._shuffle_indices] Zs_st_all = Zs_st_all[self._shuffle_indices] Zs_sm_all = Zs_sm_all[self._shuffle_indices] return Zs_all, Zs_st_all, Zs_sm_all @torch.no_grad() def get_normalized_expression( self, latent_key: Literal["z", "q_mu"] = "q_mu", n_per_batch: int = 128, show_progress: bool = True ) -> np.ndarray: """ Get the normalized expression of the data. :param latent_key: String specifying the key of the latent variable to return, default is "q_mu". :param n_per_batch: Integer specifying the number of samples per batch, default is 128. :param show_progress: Boolean indicating whether to show the progress bar, default is True. :return: Numpy array containing the normalized expression. """ self.eval() dataloader = self.as_dataloader(batch_size=n_per_batch, shuffle=False) Zs = [] Zs_rate = [] Zs_dropout = [] if show_progress: pbar = get_tqdm()(dataloader, desc="Normalized Expression", bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') for batch_idx in dataloader: indices = batch_idx[0].cpu().numpy() # Read batch from adata.X X_batch = [] for idx in indices: if scipy.sparse.issparse(self.adata.X): x_row = self.adata.X.getrow(idx).toarray().squeeze() else: x_row = self.adata.X[idx] X_batch.append(x_row) X_batch = torch.tensor(np.stack(X_batch), dtype=torch.float32).to(self.device) # Handle batch codes if applicable if self.batch_codes is not None: batch_index = [ torch.tensor(code[indices], dtype=torch.long).unsqueeze(1).to(self.device) for code in self.batch_codes ] batch_index = torch.hstack(batch_index) else: batch_index = None H,Rs,_ = self.forward(X_batch, batch_index=batch_index) Zs.append( np.hstack([ Rs['latent']['px_sm_scale'].detach().cpu().numpy(), Rs['latent']['px_rna_scale'].detach().cpu().numpy() ]) ) Zs_rate.append( np.hstack([ Rs['latent']['px_sm_rate'].detach().cpu().numpy(), Rs['latent']['px_rna_rate'].detach().cpu().numpy() ]) ) Zs_dropout.append( np.hstack([ Rs['latent']['px_sm_dropout'].detach().cpu().numpy(), Rs['latent']['px_rna_dropout'].detach().cpu().numpy() ]) ) if show_progress: pbar.update(1) if show_progress: pbar.close() Zs_all = np.vstack(Zs) Zs_rate_all = np.vstack(Zs_rate) Zs_dropout_all = np.vstack(Zs_dropout) if hasattr(self, '_shuffle_indices'): Zs_all = Zs_all[self._shuffle_indices] Zs_rate_all = Zs_rate_all[self._shuffle_indices] Zs_dropout_all = Zs_dropout_all[self._shuffle_indices] return Zs_all#, Zs_rate_all, Zs_dropout_all @torch.no_grad() def get_normalized_expression_corr( self, latent_key: Literal["z", "q_mu"] = "q_mu", n_per_batch: int = 128, show_progress: bool = True ) -> np.ndarray: """ Get the normalized expression of the data. :param latent_key: String specifying the key of the latent variable to return, default is "q_mu". :param n_per_batch: Integer specifying the number of samples per batch, default is 128. :param show_progress: Boolean indicating whether to show the progress bar, default is True. :return: Numpy array containing the normalized expression. """ self.eval() dataloader = self.as_dataloader(batch_size=n_per_batch, shuffle=False) Zs = [] Zs_rate = [] Zs_dropout = [] if show_progress: pbar = get_tqdm()(dataloader, desc="Normalized Expression Corr", bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') for batch in dataloader: indices = batch[0].cpu().numpy() # Read batch from adata.X X_batch = [] for idx in indices: if scipy.sparse.issparse(self.adata.X): x_row = self.adata.X.getrow(idx).toarray().squeeze() else: x_row = self.adata.X[idx] X_batch.append(x_row) X_batch = torch.tensor(np.stack(X_batch), dtype=torch.float32).to(self.device) # Handle batch codes if applicable if self.batch_codes is not None: batch_index = [ torch.tensor(code[indices], dtype=torch.long).unsqueeze(1).to(self.device) for code in self.batch_codes ] batch_index = torch.hstack(batch_index) else: batch_index = None H,Rs,_ = self.forward(X_batch, batch_index=batch_index) Zs.append( np.hstack([ Rs['corr']['px_sm_scale'].detach().cpu().numpy(), Rs['corr']['px_rna_scale'].detach().cpu().numpy() ]) ) Zs_rate.append( np.hstack([ Rs['corr']['px_sm_rate'].detach().cpu().numpy(), Rs['corr']['px_rna_rate'].detach().cpu().numpy() ]) ) Zs_dropout.append( np.hstack([ Rs['corr']['px_sm_dropout'].detach().cpu().numpy(), Rs['corr']['px_rna_dropout'].detach().cpu().numpy() ]) ) if show_progress: pbar.update(1) if show_progress: pbar.close() Zs_all = np.vstack(Zs) Zs_rate_all = np.vstack(Zs_rate) Zs_dropout_all = np.vstack(Zs_dropout) if hasattr(self, '_shuffle_indices'): Zs_all = Zs_all[self._shuffle_indices] Zs_rate_all = Zs_rate_all[self._shuffle_indices] Zs_dropout_all = Zs_dropout_all[self._shuffle_indices] return Zs_all#, Zs_rate_all, Zs_dropout_all def get_modality_contribution( self, latent_key: Literal["z", "q_mu"] = "q_mu", ): """ Get the contribution of each modality to the joint latent space. :param latent_key: Which latent representation to use, either "z" for the sampled latent or "q_mu" for the mean of the latent distribution """ method = "cos" def angular_similarity(x, y): dot = np.dot(x, y) norms = np.linalg.norm(x) * np.linalg.norm(y) cos_similarity = dot / norms angular = (1 - np.arccos(cos_similarity) / np.pi) return angular joint_latent, st_latent, sm_latent = self._get_latent_embedding( latent_key = latent_key ) st_latent = np.vstack(st_latent) st_latent = sc.AnnData(X=st_latent) sm_latent = np.vstack(sm_latent) sm_latent = sc.AnnData(X=sm_latent) joint_latent = np.vstack(joint_latent) joint_latent = sc.AnnData(X=joint_latent) if method == "cos": ang_sm = np.array([angular_similarity(x, y) for x, y in zip(joint_latent.X, sm_latent.X)]) ang_st = np.array([angular_similarity(x, y) for x, y in zip(joint_latent.X, st_latent.X)]) contribution_st_sm = ang_st - ang_sm + 0.5 return contribution_st_sm