# Source code for spatialmeta.model._model

# Pytorch
from collections import Counter
import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.distributions import kl_divergence as kld


# Third Party

import numpy as np


# Built-in
import numpy as np

from anndata import AnnData
from scipy.sparse import issparse
import scipy.sparse

from copy import deepcopy
import json
from typing import Callable, Mapping, Union, Iterable, Tuple, Optional, Mapping
import os
import warnings


# Package
from ._primitives import *
from ..util.loss import LossFunction
from ..util.logger import get_tqdm
from ..util._classes import AnnDataSM, AnnDataST, AnnDataJointSMST

def get_k_elements(arr: Iterable, k: int):
    """Return a list containing item ``k`` of every element of ``arr``.

    :param arr: iterable of indexable elements (lists, tuples, arrays, ...).
    :param k: index applied to each element.
    :return: ``[x[k] for x in arr]``.
    """
    return [element[k] for element in arr]

def get_last_k_elements(arr: Iterable, k: int):
    """Return the tail of every element of ``arr`` starting at index ``k``.

    Note: despite the name, this slices ``x[k:]`` (everything from position ``k``
    onwards), which equals "the last k items" only when ``k`` is negative.

    :param arr: iterable of sliceable elements.
    :param k: start index of the slice taken from each element.
    :return: ``[x[k:] for x in arr]``.
    """
    tails = []
    for element in arr:
        tails.append(element[k:])
    return tails

def get_elements(arr: Iterable, a: int, b: int):
    """Return, for every element of ``arr``, the slice of length ``b`` starting at ``a``.

    :param arr: iterable of sliceable elements.
    :param a: start index of the slice.
    :param b: slice length, so each result is ``x[a:a + b]``.
    :return: list of slices, one per element of ``arr``.
    """
    stop = a + b
    return [element[a:stop] for element in arr]

class ConditionalVAE(nn.Module):
    """
    This class implements a Conditional Variational Autoencoder (CVAE) for vertical and horizontal integration of ST and SM.

    Features are split into an SM block and an ST block using ``adata.var['type']``.
    Each block is encoded by its own stacked encoder; the two encodings are
    concatenated and mapped to a shared Gaussian latent space, and a single
    (optionally batch-conditioned) decoder produces the parameters of the SM and ST
    reconstruction distributions.

    :param adata: AnnDataJointSMST object containing the spatial multi-omics data. ``adata.var['type']`` labels each feature as "SM" or "ST".
    :param hidden_stacks: List of integers specifying the number of hidden units in each stack of the encoders, default is [128]. The last entry is also the decoder output width. The list is copied on construction.
    :param batch_keys: Optional list of strings (or a single string) naming the ``adata.obs`` columns used for batch correction.
    :param n_latent: Integer specifying the dimensionality of the latent space, default is 10.
    :param bias: Boolean indicating whether to include bias terms in the encoder layers, default is True.
    :param use_batch_norm: Boolean indicating whether to use batch normalization in the encoder layers, default is True. The decoder always uses batch normalization.
    :param use_layer_norm: Boolean indicating whether to use layer normalization in the encoder layers, default is False. The decoder never uses layer normalization.
    :param dropout_rate: Float specifying the dropout rate for the encoder layers, default is 0.1. The decoder uses no dropout.
    :param activation_fn: Callable specifying the activation function used in the encoder layers, default is nn.ReLU.
    :param device: String or torch.device specifying the device to use for computation, default is "cpu".
    :param batch_embedding: Literal["embedding", "onehot"] specifying the type of batch embedding used by the decoder, default is "onehot".
    :param encode_libsize: Boolean; when True, log(1 + ST library size) is appended as an extra input column of the ST encoder, default is False.
    :param batch_hidden_dim: Integer specifying the dimensionality of the batch embedding in the decoder, default is 8.
    :param reconstruction_method_st: Literal['mse', 'zg', 'zinb'] specifying the reconstruction method for the ST data, default is 'zinb'. mse is mean squared error, zg is zero-inflated Gaussian, and zinb is zero-inflated negative binomial.
    :param reconstruction_method_sm: Literal['mse', 'zg', 'g'] specifying the reconstruction method for the SM data, default is 'g'. mse is mean squared error, zg is zero-inflated Gaussian, and g is Gaussian.
    :raises ValueError: if a reconstruction method is unsupported or a batch key is missing from ``adata.obs``.
    """

    # Supported reconstruction likelihoods for each modality.
    _ST_METHODS = ('mse', 'zg', 'zinb')
    _SM_METHODS = ('mse', 'zg', 'g')

    def __init__(
        self,
        adata: AnnDataJointSMST,
        hidden_stacks: List[int] = [128],
        batch_keys: Optional[List[str]] = None,
        n_latent: int = 10,
        bias: bool = True,
        use_batch_norm: bool = True,
        use_layer_norm: bool = False,
        dropout_rate: float = 0.1,
        activation_fn: Callable = nn.ReLU,
        device: Union[str, torch.device] = "cpu",
        batch_embedding: Literal["embedding", "onehot"] = "onehot",
        encode_libsize: bool = False,
        batch_hidden_dim: int = 8,
        reconstruction_method_st: Literal['mse', 'zg', 'zinb'] = 'zinb',
        reconstruction_method_sm: Literal['mse', 'zg', 'g'] = 'g'
    ):
        super(ConditionalVAE, self).__init__()

        # Fail fast. Otherwise an unsupported method only surfaces later, in
        # forward(), as an UnboundLocalError on the loss variable.
        if reconstruction_method_st not in self._ST_METHODS:
            raise ValueError(
                f"reconstruction_method_st must be one of {self._ST_METHODS}, "
                f"got {reconstruction_method_st!r}"
            )
        if reconstruction_method_sm not in self._SM_METHODS:
            raise ValueError(
                f"reconstruction_method_sm must be one of {self._SM_METHODS}, "
                f"got {reconstruction_method_sm!r}"
            )

        self.adata = adata

        # Copy so that neither the shared mutable default nor a caller's list is aliased.
        self.hidden_stacks = list(hidden_stacks)
        self.n_hidden = self.hidden_stacks[-1]
        self.n_latent = n_latent
        self.device = device
        self.reconstruction_method_st = reconstruction_method_st
        self.reconstruction_method_sm = reconstruction_method_sm
        self.encode_libsize = encode_libsize

        self.batch_hidden_dim = batch_hidden_dim
        self.batch_embedding = batch_embedding

        self.batch_keys = [batch_keys] if isinstance(batch_keys, str) else batch_keys

        # Populates self.X, self._type, self.in_dim_SM/ST, self._indices and batch codes.
        self.initialize_dataset()

        # Layer options shared by both encoders (not the decoder).
        self.fcargs = dict(
            bias           = bias,
            dropout_rate   = dropout_rate,
            use_batch_norm = use_batch_norm,
            use_layer_norm = use_layer_norm,
            activation_fn  = activation_fn,
            device         = device
        )

        # One extra input column carries the ST library size when encode_libsize is set.
        self.encoder_ST = SAE(
            self.in_dim_ST + 1 if self.encode_libsize else self.in_dim_ST,
            stacks = self.hidden_stacks,
            encode_only = True,
            **self.fcargs
        )

        self.encoder_SM = SAE(
            self.in_dim_SM,
            stacks = self.hidden_stacks,
            encode_only = True,
            **self.fcargs
        )

        # Shared decoder trunk, conditioned on batch categories when batch_keys is set.
        self.decoder = FCLayer(
            in_dim = self.n_latent,
            out_dim = self.n_hidden,
            n_cat_list = self.n_batch_keys,
            cat_dim = batch_hidden_dim,
            cat_embedding = batch_embedding,
            use_layer_norm=False,
            use_batch_norm=True,
            dropout_rate=0,
            device=device
        )

        # Posterior heads over the concatenated [SM, ST] encoding. The prior used in
        # forward() is a standard Normal, z ~ N(0, I).
        # NOTE(review): the factor 2 assumes each SAE.encode output has n_hidden
        # features - confirm against _primitives.SAE.
        self.z_mean_fc = nn.Linear(self.n_hidden * 2, self.n_latent)
        self.z_var_fc = nn.Linear(self.n_hidden * 2, self.n_latent)
        # Single-encoding heads; not referenced by the methods of this class itself.
        self.z_mean_fc_single = nn.Linear(self.n_hidden, self.n_latent)
        self.z_var_fc_single = nn.Linear(self.n_hidden, self.n_latent)

        # ST output heads: dispersion/variance (log scale), mean proportions, dropout logits.
        self.px_rna_rate_decoder = nn.Linear(self.n_hidden, self.in_dim_ST)
        self.px_rna_scale_decoder = nn.Sequential(
            nn.Linear(self.n_hidden, self.in_dim_ST),
            nn.Softmax(dim=-1)
        )
        self.px_rna_dropout_decoder = nn.Linear(self.n_hidden, self.in_dim_ST)

        # SM output heads: variance (log scale), mean, dropout logits.
        self.px_sm_rate_decoder = nn.Linear(self.n_hidden, self.in_dim_SM)
        self.px_sm_scale_decoder = nn.Linear(self.n_hidden, self.in_dim_SM)
        self.px_sm_dropout_decoder = nn.Linear(self.n_hidden, self.in_dim_SM)

        self.to(self.device)

    def as_dataloader(self, batch_size=32, shuffle=True, random_seed: Optional[int] = None):
        """Return a DataLoader over row indices (not data) of the dataset.

        :param batch_size: number of indices per batch.
        :param shuffle: whether to shuffle the indices.
        :param random_seed: optional seed for a dedicated generator, making the
            shuffle order reproducible. None (default) uses torch's global RNG.
        :return: DataLoader yielding 1-tuples of ``torch.long`` index tensors.
        """
        index_dataset = torch.utils.data.TensorDataset(
            torch.tensor(self._indices, dtype=torch.long)
        )
        generator = None
        if random_seed is not None:
            generator = torch.Generator()
            generator.manual_seed(random_seed)
        return DataLoader(
            index_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            generator=generator
        )

    def initialize_dataset(self):
        """Cache the data matrix, feature types, dimensions and batch codes from ``self.adata``.

        :raises ValueError: if a batch key is missing from ``adata.obs``.
        """
        X = self.adata.X
        # Sparse matrices are kept as-is; anything else becomes a NumPy array.
        self.X = X if scipy.sparse.issparse(X) else np.array(X)
        self._type = np.array(self.adata.var['type'].values)
        # Plain Python ints for the layer constructors.
        self.in_dim_SM = int(np.sum(self._type == "SM"))
        self.in_dim_ST = int(np.sum(self._type == "ST"))
        self._n_record = self.X.shape[0]
        self._indices = np.arange(self._n_record)

        if self.batch_keys is not None:
            for key in self.batch_keys:
                if key not in self.adata.obs.columns:
                    raise ValueError(f"batch_key '{key}' not found in AnnData.obs")
            self.batch_categories = [
                pd.Categorical(self.adata.obs[key]) for key in self.batch_keys
            ]
            # Integer category codes per key, in adata.obs row order.
            self.batch_codes = [
                np.array(cat.codes, dtype=np.int64) for cat in self.batch_categories
            ]
            self.n_batch_keys = [len(cat.categories) for cat in self.batch_categories]
        else:
            self.batch_categories = None
            self.batch_codes = None
            self.n_batch_keys = None

    def encode(
        self,
        X: torch.Tensor,
        eps: float = 1e-4
    ) -> Mapping[str, torch.Tensor]:
        """Encode a batch into the latent posterior and draw a reparameterised sample.

        :param X: 2-D tensor of raw feature values whose columns follow ``adata.var``.
        :param eps: constant added to the posterior variance for numerical stability.
        :return: dict with the concatenated encoding ``q``, posterior mean ``q_mu``,
            posterior variance ``q_var`` and sample ``z``.
        """
        X_SM = X[:, self._type == "SM"]
        X_ST = X[:, self._type == "ST"]

        # ST values are log(1 + x)-transformed before encoding; SM values are used as-is.
        X_ST_in = torch.log(X_ST + 1)
        if self.encode_libsize:
            # Fills the extra input column allocated for encoder_ST in __init__.
            libsize = torch.log1p(X_ST.sum(dim=1, keepdim=True))
            X_ST_in = torch.hstack((X_ST_in, libsize))

        q_sm = self.encoder_SM.encode(X_SM)
        q_st = self.encoder_ST.encode(X_ST_in)
        q = torch.hstack((q_sm, q_st))

        q_mu = self.z_mean_fc(q)
        # exp keeps the variance positive; eps keeps it away from zero.
        q_var = torch.exp(self.z_var_fc(q)) + eps
        z = Normal(q_mu, q_var.sqrt()).rsample()
        return dict(
            q = q,
            q_mu = q_mu,
            q_var = q_var,
            z = z
        )

    def decode(
        self,
        H: Mapping[str, torch.Tensor],
        lib_size: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
    ) -> Mapping[str, Optional[torch.Tensor]]:
        """Decode latent samples into SM and ST distribution parameters.

        :param H: output of :meth:`encode`; only ``H["z"]`` is used.
        :param lib_size: per-row ST library size used to rescale the ST mean.
        :param batch_index: optional integer batch codes, one column per batch key.
        :return: dict with the decoder trunk output ``px`` and the ``px_rna_*`` /
            ``px_sm_*`` heads; ``*_dropout`` entries are logits and ``h`` is None.
        """
        z = H["z"]

        if batch_index is not None:
            # Batch codes are appended as trailing columns. Presumably FCLayer splits
            # them off according to n_cat_list - confirm against _primitives.FCLayer.
            z = torch.hstack([z, batch_index])

        px = self.decoder(z)

        px_rna_scale = self.px_rna_scale_decoder(px)
        px_rna_rate = self.px_rna_rate_decoder(px)
        px_rna_dropout = self.px_rna_dropout_decoder(px)  # logits
        px_sm_scale = self.px_sm_scale_decoder(px)
        px_sm_rate = self.px_sm_rate_decoder(px)
        px_sm_dropout = self.px_sm_dropout_decoder(px)  # logits

        # Softmax proportions rescaled to the observed ST library size.
        px_rna_scale = px_rna_scale * lib_size.unsqueeze(1)

        return dict(
            h = None,
            px = px,
            px_rna_scale = px_rna_scale,
            px_rna_rate = px_rna_rate,
            px_rna_dropout = px_rna_dropout,
            px_sm_scale = px_sm_scale,
            px_sm_rate = px_sm_rate,
            px_sm_dropout = px_sm_dropout
        )

    def forward(
        self,
        X: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        reduction: str = "sum",
    ) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]]:
        """Encode, decode and compute the per-modality reconstruction losses and the KL term.

        :param X: 2-D tensor of raw feature values whose columns follow ``adata.var``.
        :param batch_index: optional integer batch codes passed to :meth:`decode`.
        :param reduction: reduction passed to the reconstruction losses (see the NOTE
            on the SM 'mse' branch).
        :return: ``(H, R, loss_record)``. ``loss_record`` has keys
            ``reconstruction_loss_sm``, ``reconstruction_loss_st`` and ``kldiv_loss``
            (KL summed over latent dimensions, one value per row).
        :raises ValueError: if a reconstruction method attribute holds an unsupported value.
        """
        H = self.encode(X)
        q_mu = H["q_mu"]
        q_var = H["q_var"]
        # KL(q(z|x) || N(0, I)), summed over latent dimensions.
        kldiv_loss = kld(
            Normal(q_mu, q_var.sqrt()),
            Normal(torch.zeros_like(q_mu), torch.ones_like(q_var))
        ).sum(dim=1)

        X_SM = X[:, self._type == "SM"]
        X_ST = X[:, self._type == "ST"]

        R = self.decode(H, X_ST.sum(1), batch_index)

        if self.reconstruction_method_st == 'zinb':
            reconstruction_loss_st = LossFunction.zinb_reconstruction_loss(
                X_ST,
                mu = R['px_rna_scale'],
                theta = R['px_rna_rate'].exp(),
                gate_logits = R['px_rna_dropout'],
                reduction = reduction
            )
        elif self.reconstruction_method_st == 'zg':
            reconstruction_loss_st = LossFunction.zi_gaussian_reconstruction_loss(
                X_ST,
                mean=R['px_rna_scale'],
                variance=R['px_rna_rate'].exp(),
                gate_logits=R['px_rna_dropout'],
                reduction=reduction
            )
        elif self.reconstruction_method_st == 'mse':
            reconstruction_loss_st = nn.functional.mse_loss(
                R['px_rna_scale'],
                X_ST,
                reduction=reduction
            )
        else:
            raise ValueError(
                f"Unsupported reconstruction_method_st: {self.reconstruction_method_st!r}"
            )

        if self.reconstruction_method_sm == 'zg':
            reconstruction_loss_sm = LossFunction.zi_gaussian_reconstruction_loss(
                X_SM,
                mean = R['px_sm_scale'],
                variance = R['px_sm_rate'].exp(),
                gate_logits = R['px_sm_dropout'],
                reduction = reduction
            )
        elif self.reconstruction_method_sm == 'mse':
            # NOTE(review): unlike every other branch (including the ST 'mse' branch),
            # this ignores `reduction` and always averages - confirm this is intended.
            reconstruction_loss_sm = nn.MSELoss(reduction='mean')(
                R['px_sm_scale'],
                X_SM,
            )
        elif self.reconstruction_method_sm == "g":
            reconstruction_loss_sm = LossFunction.gaussian_reconstruction_loss(
                X_SM,
                mean = R['px_sm_scale'],
                variance = R['px_sm_rate'].exp(),
                reduction = reduction
            )
        else:
            raise ValueError(
                f"Unsupported reconstruction_method_sm: {self.reconstruction_method_sm!r}"
            )

        loss_record = {
            "reconstruction_loss_sm": reconstruction_loss_sm,
            "reconstruction_loss_st": reconstruction_loss_st,
            "kldiv_loss": kldiv_loss,
        }
        return H, R, loss_record

    def fit(
        self,
        max_epoch: int = 35,
        n_per_batch: int = 128,
        mode: Optional[Literal['single', 'multi']] = None,
        **kwargs
    ):
        """
        Fits the model.

        :param max_epoch: Integer specifying the maximum number of epochs to train the model, default is 35.
        :param n_per_batch: Integer specifying the number of samples per batch, default is 128.
        :param mode: Optional preset, 'single' or 'multi', default is None. 'single' forces kl_weight=2 and n_epochs_kl_warmup=35; 'multi' forces kl_weight=15 and n_epochs_kl_warmup=0. These override the same keys passed in ``kwargs``.
        :param reconstruction_reduction: String specifying the reduction method for the reconstruction loss, default is 'sum'.
        :param kl_weight: Float specifying the weight of the KL divergence loss, default is 2.
        :param reconstruction_st_weight: Float specifying the weight of the reconstruction loss for the ST data, default is 1.
        :param reconstruction_sm_weight: Float specifying the weight of the reconstruction loss for the SM data, default is 1.
        :param n_epochs_kl_warmup: Integer specifying the number of epochs for KL divergence warmup, default is 400.
        :param optimizer_parameters: Iterable specifying the parameters for the optimizer, default is None.
        :param weight_decay: Float specifying the weight decay for the optimizer, default is 1e-6.
        :param lr: Float specifying the learning rate for the optimizer.
        :param random_seed: Integer specifying the random seed, default is 12.
        :param kl_loss_reduction: String specifying the reduction method for the KL divergence loss, default is 'mean'.
        :raises ValueError: if ``mode`` is not None, 'single' or 'multi'.

        :return: Dictionary containing the training loss values.
        """
        if mode == 'single':
            kwargs['kl_weight'] = 2.
            kwargs['n_epochs_kl_warmup'] = 35
        elif mode == 'multi':
            kwargs['kl_weight'] = 15.
            kwargs['n_epochs_kl_warmup'] = 0
        elif mode is not None:
            raise ValueError(f"mode must be 'single', 'multi' or None, got {mode!r}")

        # NOTE(review): fit_core is not defined in this class as shown; presumably it
        # is provided elsewhere - confirm before relying on the documented defaults.
        return self.fit_core(
            max_epoch=max_epoch,
            n_per_batch=n_per_batch,
            **kwargs
        )
                        
class ConditionalVAESM(nn.Module):
    """
    This class implements a Conditional Variational Autoencoder (CVAE) for vertical integration SM.

    :param adata: AnnDataSM object containing the SM data. All columns of ``adata.X`` are fed to the encoder, while ``in_dim`` counts the features whose ``adata.var['type']`` is "SM"; the two are assumed to coincide.
    :param hidden_stacks: List of integers specifying the number of hidden units in each stack of the encoder, default is [128]. The last entry is also the decoder output width. The list is copied on construction.
    :param batch_keys: Optional list of strings (or a single string) naming the ``adata.obs`` columns used for batch correction.
    :param n_latent: Integer specifying the dimensionality of the latent space, default is 10.
    :param bias: Boolean indicating whether to include bias terms in the encoder layers, default is True.
    :param use_batch_norm: Boolean indicating whether to use batch normalization in the encoder layers, default is True. The decoder always uses batch normalization.
    :param use_layer_norm: Boolean indicating whether to use layer normalization in the encoder layers, default is False.
    :param dropout_rate: Float specifying the dropout rate for the encoder layers, default is 0.1. The decoder uses no dropout.
    :param activation_fn: Callable specifying the activation function used in the encoder layers, default is nn.ReLU.
    :param device: String or torch.device specifying the device to use for computation, default is "cpu".
    :param batch_embedding: Literal["embedding", "onehot"] specifying the type of batch embedding used by the decoder, default is "onehot".
    :param encode_libsize: Boolean; when True, log(1 + row sum) is appended as an extra encoder input column, default is False.
    :param batch_hidden_dim: Integer specifying the dimensionality of the batch embedding in the decoder, default is 8.
    :param reconstruction_method: Literal['mse', 'zg', 'zinb', 'g'] specifying the reconstruction method, default is 'g'. mse is mean squared error, zg is zero-inflated Gaussian, zinb is zero-inflated negative binomial, and g is Gaussian.
    :raises ValueError: if ``reconstruction_method`` is unsupported or a batch key is missing from ``adata.obs``.
    """

    # Supported reconstruction likelihoods.
    _METHODS = ('mse', 'zg', 'zinb', 'g')

    def __init__(
        self,
        adata: AnnData,
        hidden_stacks: List[int] = [128],
        batch_keys: Optional[List[str]] = None,
        n_latent: int = 10,
        bias: bool = True,
        use_batch_norm: bool = True,
        use_layer_norm: bool = False,
        dropout_rate: float = 0.1,
        activation_fn: Callable = nn.ReLU,
        device: Union[str, torch.device] = "cpu",
        batch_embedding: Literal["embedding", "onehot"] = "onehot",
        encode_libsize: bool = False,
        batch_hidden_dim: int = 8,
        reconstruction_method: Literal['mse', 'zg', 'zinb', 'g'] = 'g'
    ):
        super(ConditionalVAESM, self).__init__()

        # Fail fast. Otherwise an unsupported method only surfaces in forward().
        if reconstruction_method not in self._METHODS:
            raise ValueError(
                f"reconstruction_method must be one of {self._METHODS}, "
                f"got {reconstruction_method!r}"
            )

        self.adata = adata
        # Copy so that neither the shared mutable default nor a caller's list is aliased.
        self.hidden_stacks = list(hidden_stacks)
        self.n_hidden = self.hidden_stacks[-1]
        self.n_latent = n_latent
        self.device = device
        self.reconstruction_method = reconstruction_method
        self.encode_libsize = encode_libsize
        self.batch_hidden_dim = batch_hidden_dim
        self.batch_embedding = batch_embedding
        self.batch_keys = [batch_keys] if isinstance(batch_keys, str) else batch_keys

        # Populates self.X, self._type, self.in_dim, self._indices and batch codes.
        self.initialize_dataset()

        # Layer options for the encoder (not the decoder).
        self.fcargs = dict(
            bias           = bias,
            dropout_rate   = dropout_rate,
            use_batch_norm = use_batch_norm,
            use_layer_norm = use_layer_norm,
            activation_fn  = activation_fn,
            device         = device
        )

        # One extra input column carries the library size when encode_libsize is set.
        self.encoder = SAE(
            self.in_dim + 1 if self.encode_libsize else self.in_dim,
            stacks = self.hidden_stacks,
            encode_only = True,
            **self.fcargs
        )

        # Decoder trunk, conditioned on batch categories when batch_keys is set.
        self.decoder = FCLayer(
            in_dim = self.n_latent,
            out_dim = self.n_hidden,
            n_cat_list = self.n_batch_keys,
            cat_dim = batch_hidden_dim,
            cat_embedding = batch_embedding,
            use_layer_norm=False,
            use_batch_norm=True,
            dropout_rate=0,
            device=device
        )

        # Posterior heads; the prior used in forward() is a standard Normal, z ~ N(0, I).
        self.z_mean_fc = nn.Linear(self.n_hidden, self.n_latent)
        self.z_var_fc = nn.Linear(self.n_hidden, self.n_latent)

        # Output heads: variance/dispersion (log scale), mean, dropout logits.
        self.px_sm_rate_decoder = nn.Linear(self.n_hidden, self.in_dim)
        self.px_sm_scale_decoder = nn.Linear(self.n_hidden, self.in_dim)
        self.px_sm_dropout_decoder = nn.Linear(self.n_hidden, self.in_dim)

        self.to(self.device)

    def as_dataloader(self, batch_size=32, shuffle=True, random_seed: Optional[int] = None):
        """Return a DataLoader over row indices (not data) of the dataset.

        :param batch_size: number of indices per batch.
        :param shuffle: whether to shuffle the indices.
        :param random_seed: optional seed for a dedicated generator, making the
            shuffle order reproducible. None (default) uses torch's global RNG.
        :return: DataLoader yielding 1-tuples of ``torch.long`` index tensors.
        """
        index_dataset = torch.utils.data.TensorDataset(
            torch.tensor(self._indices, dtype=torch.long)
        )
        generator = None
        if random_seed is not None:
            generator = torch.Generator()
            generator.manual_seed(random_seed)
        return DataLoader(
            index_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            generator=generator
        )

    def initialize_dataset(self):
        """Cache the data matrix, feature types, dimensions and batch codes from ``self.adata``.

        :raises ValueError: if a batch key is missing from ``adata.obs``.
        """
        X = self.adata.X
        # Sparse input is converted to CSR so that _get_batch_X can slice rows directly;
        # anything else becomes a NumPy array.
        self.X = X.tocsr() if scipy.sparse.issparse(X) else np.array(X)
        self._type = np.array(self.adata.var['type'].values)
        self.in_dim = int(np.sum(self._type == "SM"))
        self._n_record = self.X.shape[0]
        self._indices = np.arange(self._n_record)

        if self.batch_keys is not None:
            for key in self.batch_keys:
                if key not in self.adata.obs.columns:
                    raise ValueError(f"batch_key '{key}' not found in AnnData.obs")
            self.batch_categories = [
                pd.Categorical(self.adata.obs[key]) for key in self.batch_keys
            ]
            # Integer category codes per key, in adata.obs row order.
            self.batch_codes = [
                np.array(cat.codes, dtype=np.int64) for cat in self.batch_categories
            ]
            self.n_batch_keys = [len(cat.categories) for cat in self.batch_categories]
        else:
            self.batch_categories = None
            self.batch_codes = None
            self.n_batch_keys = None

    def _get_batch_X(self, indices: np.ndarray) -> torch.Tensor:
        """Gather rows ``indices`` of ``self.X`` as a dense float32 tensor on ``self.device``."""
        rows = self.X[indices]
        if scipy.sparse.issparse(rows):
            rows = rows.toarray()
        return torch.tensor(np.asarray(rows), dtype=torch.float32).to(self.device)

    def _get_batch_index(self, indices: np.ndarray) -> Optional[torch.Tensor]:
        """Return the batch codes of rows ``indices`` (one column per batch key), or None."""
        if self.batch_codes is None:
            return None
        return torch.hstack([
            torch.tensor(code[indices], dtype=torch.long).unsqueeze(1).to(self.device)
            for code in self.batch_codes
        ])

    def encode(
        self,
        X: torch.Tensor,
        eps: float = 1e-4
    ) -> Mapping[str, torch.Tensor]:
        """Encode a batch into the latent posterior and draw a reparameterised sample.

        :param X: 2-D tensor of raw feature values.
        :param eps: constant added to the posterior variance for numerical stability.
        :return: dict with encoder output ``q``, posterior mean ``q_mu``, posterior
            variance ``q_var`` and sample ``z``.
        """
        X_in = torch.log(X + 1)
        if self.encode_libsize:
            # Fills the extra input column allocated for the encoder in __init__.
            X_in = torch.hstack((X_in, torch.log1p(X.sum(dim=1, keepdim=True))))
        q = self.encoder.encode(X_in)
        q_mu = self.z_mean_fc(q)
        # exp keeps the variance positive; eps keeps it away from zero.
        q_var = torch.exp(self.z_var_fc(q)) + eps
        z = Normal(q_mu, q_var.sqrt()).rsample()
        return dict(
            q = q,
            q_mu = q_mu,
            q_var = q_var,
            z = z
        )

    def decode(
        self,
        H: Mapping[str, torch.Tensor],
        lib_size: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
    ) -> Mapping[str, Optional[torch.Tensor]]:
        """Decode latent samples into distribution parameters.

        :param H: output of :meth:`encode`; only ``H["z"]`` is used.
        :param lib_size: accepted but not used by this decoder.
        :param batch_index: optional integer batch codes, one column per batch key.
        :return: dict with decoder trunk output ``px`` and the ``px_sm_*`` heads
            (``px_sm_dropout`` are logits); ``h`` is None.
        """
        z = H["z"]
        if batch_index is not None:
            # Batch codes are appended as trailing columns. Presumably FCLayer splits
            # them off according to n_cat_list - confirm against _primitives.FCLayer.
            z = torch.hstack([z, batch_index])
        px = self.decoder(z)
        px_sm_scale = self.px_sm_scale_decoder(px)
        px_sm_rate = self.px_sm_rate_decoder(px)
        px_sm_dropout = self.px_sm_dropout_decoder(px)
        return dict(
            h = None,
            px = px,
            px_sm_scale = px_sm_scale,
            px_sm_rate = px_sm_rate,
            px_sm_dropout = px_sm_dropout
        )

    def forward(
        self,
        X: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        reduction: str = "sum",
    ) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]]:
        """Encode, decode and compute the reconstruction, KL and MMD losses.

        :param X: 2-D tensor of raw feature values.
        :param batch_index: optional integer batch codes; when given, an MMD loss on
            ``q_mu`` across batches is computed and the decoder is conditioned on them.
        :param reduction: reduction passed to the reconstruction loss.
        :return: ``(H, R, loss_record)``. ``loss_record`` has keys
            ``reconstruction_loss``, ``kldiv_loss`` (summed over latent dimensions,
            one value per row) and ``mmd_loss`` (0 without batch_index).
        :raises ValueError: if ``self.reconstruction_method`` holds an unsupported value.
        """
        H = self.encode(X)
        q_mu = H["q_mu"]
        q_var = H["q_var"]

        if batch_index is not None:
            mmd_loss = LossFunction.mmd_loss(
                z = H['q_mu'],
                cat = batch_index.detach().cpu().numpy(),
                dim=1,
            )
        else:
            mmd_loss = torch.tensor(0.0, device=self.device)

        # KL(q(z|x) || N(0, I)), summed over latent dimensions.
        kldiv_loss = kld(
            Normal(q_mu, q_var.sqrt()),
            Normal(torch.zeros_like(q_mu), torch.ones_like(q_var))
        ).sum(dim = 1)

        R = self.decode(H, X.sum(1), batch_index)

        if self.reconstruction_method == 'zinb':
            reconstruction_loss = LossFunction.zinb_reconstruction_loss(
                X,
                mu = R['px_sm_scale'],
                theta = R['px_sm_rate'].exp(),
                gate_logits = R['px_sm_dropout'],
                reduction = reduction
            )
        elif self.reconstruction_method == 'zg':
            reconstruction_loss = LossFunction.zi_gaussian_reconstruction_loss(
                X,
                mean=R['px_sm_scale'],
                variance=R['px_sm_rate'].exp(),
                gate_logits=R['px_sm_dropout'],
                reduction=reduction
            )
        elif self.reconstruction_method == 'mse':
            reconstruction_loss = nn.functional.mse_loss(
                R['px_sm_scale'],
                X,
                reduction=reduction
            )
        elif self.reconstruction_method == 'g':
            reconstruction_loss = LossFunction.gaussian_reconstruction_loss(
                X,
                mean = R['px_sm_scale'],
                variance = R['px_sm_rate'].exp(),
                reduction = reduction
            )
        else:
            raise ValueError(
                f"Unsupported reconstruction_method: {self.reconstruction_method!r}"
            )

        loss_record = {
            "reconstruction_loss": reconstruction_loss,
            "kldiv_loss": kldiv_loss,
            "mmd_loss": mmd_loss
        }
        return H, R, loss_record

    def fit(
        self,
        max_epoch: int = 35,
        n_per_batch: int = 128,
        reconstruction_reduction: str = 'sum',
        kl_weight: float = 15.,
        reconstruction_weight: float = 1.,
        n_epochs_kl_warmup: Union[int, None] = 0,
        optimizer_parameters: Optional[Iterable] = None,
        weight_decay: float = 1e-6,
        lr: float = 5e-5,
        random_seed: int = 12,
        kl_loss_reduction: str = 'mean',
        mmd_weight: float = 1.,
    ):
        """
        Fits the model.

        :param max_epoch: Integer specifying the maximum number of epochs to train the model, default is 35.
        :param n_per_batch: Integer specifying the number of samples per batch, default is 128.
        :param reconstruction_reduction: String specifying the reduction method for the reconstruction loss, default is 'sum'.
        :param kl_weight: Float specifying the weight of the KL divergence loss, default is 15.
        :param reconstruction_weight: Float specifying the weight of the reconstruction loss, default is 1.
        :param n_epochs_kl_warmup: Integer specifying the number of epochs over which the KL weight rises linearly from 0 to ``kl_weight`` (capped at ``max_epoch``), default is 0 (no warm-up).
        :param optimizer_parameters: Iterable specifying the parameters for the optimizer, default is None (all model parameters).
        :param weight_decay: Float specifying the weight decay for the AdamW optimizer, default is 1e-6.
        :param lr: Float specifying the learning rate for the optimizer, default is 5e-5.
        :param random_seed: Integer seeding the shuffle order of the first epoch; incremented every epoch, default is 12.
        :param kl_loss_reduction: 'mean' or 'sum', the reduction applied to the per-row KL divergence, default is 'mean'.
        :param mmd_weight: Float specifying the weight of the MMD batch loss, default is 1.
        :raises ValueError: if ``kl_loss_reduction`` is not 'mean' or 'sum'.

        :return: Dictionary containing the per-epoch training loss values and SM gate logits.
        """
        if kl_loss_reduction not in ('mean', 'sum'):
            raise ValueError(
                f"kl_loss_reduction must be 'mean' or 'sum', got {kl_loss_reduction!r}"
            )
        self.train()

        if n_epochs_kl_warmup:
            n_epochs_kl_warmup = min(max_epoch, n_epochs_kl_warmup)
        # Re-check after capping: max_epoch == 0 would otherwise divide by zero.
        if n_epochs_kl_warmup:
            kl_warmup_gradient = kl_weight / n_epochs_kl_warmup
            kl_weight_max = kl_weight
            kl_weight = 0.

        params = self.parameters() if optimizer_parameters is None else optimizer_parameters
        optimizer = optim.AdamW(params, lr, weight_decay=weight_decay)

        pbar = get_tqdm()(range(max_epoch), desc="Epoch", bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')

        epoch_reconstruction_loss_list = []
        epoch_kldiv_loss_list = []
        epoch_total_loss_list = []
        epoch_mmd_loss_list = []
        epoch_sm_gate_logits_list = []

        for epoch in range(1, max_epoch + 1):
            self._trained = True
            pbar.desc = "Epoch {}".format(epoch)
            epoch_total_loss = 0
            epoch_reconstruction_loss = 0
            epoch_kldiv_loss = 0
            epoch_mmd_loss = 0
            epoch_sm_gate_logits = []

            X_train = self.as_dataloader(
                batch_size=n_per_batch,
                shuffle=True,
                random_seed=random_seed
            )
            for batch_idx in X_train:
                indices = batch_idx[0].cpu().numpy()
                X_batch = self._get_batch_X(indices)
                batch_index = self._get_batch_index(indices)

                H, R, L = self.forward(
                    X_batch,
                    batch_index=batch_index,
                    reduction=reconstruction_reduction,
                )
                epoch_sm_gate_logits.append(
                    R['px_sm_dropout'].detach().cpu().numpy()
                )

                # All terms are scaled by the nominal batch size.
                avg_reconstruction_loss = L['reconstruction_loss'].mean() / n_per_batch
                avg_mmd_loss = L['mmd_loss'].mean() / n_per_batch
                if kl_loss_reduction == 'mean':
                    avg_kldiv_loss = L['kldiv_loss'].mean() / n_per_batch
                else:
                    avg_kldiv_loss = L['kldiv_loss'].sum() / n_per_batch

                loss = (avg_reconstruction_loss) * reconstruction_weight + \
                       (avg_kldiv_loss * kl_weight) + \
                       (avg_mmd_loss * mmd_weight)

                epoch_reconstruction_loss += avg_reconstruction_loss.item()
                epoch_kldiv_loss += avg_kldiv_loss.item()
                epoch_total_loss += loss.item()
                epoch_mmd_loss += avg_mmd_loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            pbar.set_postfix({
                'reconst': '{:.2e}'.format(epoch_reconstruction_loss),
                'kldiv': '{:.2e}'.format(epoch_kldiv_loss),
                'total_loss': '{:.2e}'.format(epoch_total_loss),
                'mmd_loss': '{:.2e}'.format(epoch_mmd_loss)
            })
            pbar.update(1)

            epoch_reconstruction_loss_list.append(epoch_reconstruction_loss)
            epoch_kldiv_loss_list.append(epoch_kldiv_loss)
            epoch_total_loss_list.append(epoch_total_loss)
            epoch_sm_gate_logits_list.append(np.vstack(epoch_sm_gate_logits))
            epoch_mmd_loss_list.append(epoch_mmd_loss)

            if n_epochs_kl_warmup:
                kl_weight = min(kl_weight + kl_warmup_gradient, kl_weight_max)
            random_seed += 1

        pbar.close()
        self.trained_state_dict = deepcopy(self.state_dict())

        return dict(
            epoch_reconstruction_loss_list=epoch_reconstruction_loss_list,
            epoch_kldiv_loss_list=epoch_kldiv_loss_list,
            epoch_sm_gate_logits_list=epoch_sm_gate_logits_list,
            epoch_total_loss_list=epoch_total_loss_list,
            epoch_mmd_loss_list=epoch_mmd_loss_list
        )

    @torch.no_grad()
    def get_latent_embedding(
        self,
        latent_key: Literal["z", "q_mu"] = "q_mu",
        n_per_batch: int = 128,
        show_progress: bool = True
    ) -> np.ndarray:
        """
        Get the latent embedding of the data.

        :param latent_key: String specifying the key of the latent variable to return, default is "q_mu".
        :param n_per_batch: Integer specifying the number of samples per batch, default is 128.
        :param show_progress: Boolean indicating whether to show the progress bar, default is True.
        :return: Numpy array of shape (n_obs, n_latent) containing the latent embedding, reordered by ``self._shuffle_indices`` when that attribute is set.
        """
        self.eval()
        dataloader = self.as_dataloader(batch_size=n_per_batch, shuffle=False)
        Zs = []
        if show_progress:
            pbar = get_tqdm()(dataloader, desc="Latent Embedding", bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
        for batch_idx in dataloader:
            indices = batch_idx[0].cpu().numpy()
            H = self.encode(self._get_batch_X(indices))
            Zs.append(H[latent_key].detach().cpu().numpy())
            if show_progress:
                pbar.update(1)
        if show_progress:
            pbar.close()

        # Always return an ndarray (previously a list unless _shuffle_indices was set).
        Zs = np.vstack(Zs) if Zs else np.empty((0, self.n_latent), dtype=np.float32)
        if hasattr(self, '_shuffle_indices'):
            # If shuffle indices are set, use them to reorder the results
            Zs = Zs[self._shuffle_indices]
        return Zs

    @torch.no_grad()
    def get_normalized_expression(
        self,
        latent_key: Literal["z", "q_mu"] = "q_mu",
        n_per_batch: int = 128,
        show_progress: bool = True
    ) -> np.ndarray:
        """
        Get the normalized expression of the data.

        :param latent_key: Accepted for API symmetry but currently unused: decoding always uses the sampled ``z`` from :meth:`forward`, so results vary between calls.
        :param n_per_batch: Integer specifying the number of samples per batch, default is 128.
        :param show_progress: Boolean indicating whether to show the progress bar, default is True.
        :return: Numpy array of shape (n_obs, in_dim) holding ``px_sm_scale``, reordered by ``self._shuffle_indices`` when that attribute is set.
        """
        self.eval()
        dataloader = self.as_dataloader(batch_size=n_per_batch, shuffle=False)
        Zs = []
        if show_progress:
            pbar = get_tqdm()(dataloader, desc="Normalized Expression", bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
        for batch in dataloader:
            indices = batch[0].cpu().numpy()
            X_batch = self._get_batch_X(indices)
            batch_index = self._get_batch_index(indices)
            _, R, _ = self.forward(X_batch, batch_index=batch_index)
            Zs.append(R['px_sm_scale'].detach().cpu().numpy())
            if show_progress:
                pbar.update(1)
        if show_progress:
            pbar.close()

        Zs = np.vstack(Zs) if Zs else np.empty((0, self.in_dim), dtype=np.float32)
        if hasattr(self, '_shuffle_indices'):
            # If shuffle indices are set, use them to reorder the results
            Zs = Zs[self._shuffle_indices]
        return Zs