Source code for spatialmeta.model._alignment_module

# ------------------------------------------------------------------------
# Copyright (c) 2024 STDI. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable STAlign (https://github.com/JEFworks-Lab/STalign)
# ------------------------------------------------------------------------

from torch.utils.data.sampler import SubsetRandomSampler
from itertools import zip_longest
import einops
from scipy.spatial import Delaunay
from sklearn.preprocessing import MaxAbsScaler
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
from scipy.spatial import Delaunay
import numpy as np
from scipy.spatial import KDTree
import warnings

from torch.distributions import Normal, kl_divergence as kld
from torch.utils.data import DataLoader, SubsetRandomSampler
from typing import Union, Tuple, List, Callable, Iterable, Literal, Optional
from copy import deepcopy
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from itertools import zip_longest
import numpy as np
from scipy.spatial import Delaunay, KDTree
from einops import rearrange
from sklearn.neighbors import NearestNeighbors
from scipy.sparse import issparse

from ._model import SAE, FCLayer, LossFunction, get_tqdm
from ..external.stalign.STalign import (
    rasterize_with_signal,
    extent_from_x,
    to_A,
    interp
)
from ._primitives import Linear
from ..external.stalign.STalign import *
from ..plotting._utils import get_spatial_image, get_spatial_scalefactors_dict
from ..plotting._plotting import create_subplots, create_fig
from ..util._classes import AnnDataST, AnnDataSM

def filter_spatial_outlier_spots(
    coordinates: np.ndarray, 
    subset: bool = True
) -> np.ndarray:
    '''
    Filter out spatial outlier spots using the nearest neighbor method.

    :param coordinates: np.ndarray. The spatial coordinates of the spots, should be [N, 2]
    :param subset: bool. If True, return the subset of the coordinates without the outliers. 
        If False, return the boolean mask of the outliers.

    :return: np.ndarray. The subset of the coordinates without the outliers or the boolean mask of the outliers.
    '''
    neighbors = NearestNeighbors(n_neighbors=100)
    neighbors.fit(coordinates)
    D,I=neighbors.kneighbors(coordinates)
    distance1 = D[:,1:].mean(1)
    Q3 = np.percentile(distance1 , 99)
    outliers1 = (distance1 > Q3)
    distance2 = D[:,1:].min(1)

    Q3 = np.percentile(distance2 , 75)
    Q1 = np.percentile(distance2 , 25)
    IQR = max(np.mean(distance2) * 0.01, Q3-Q1)
    outliers2 = (distance2 > np.mean(distance2) + IQR)
    outliers = outliers1 & outliers2
    if subset:
        return coordinates[~outliers]
    else:
        outliers


def point_alignment_error(pointsI: torch.Tensor, pointsJ: torch.Tensor) -> torch.Tensor:
    """
    Compute the point alignment error between two sets of points.

    :param pointsI: torch.Tensor. The first set of points, should be [N, 2]
    :param pointsJ: torch.Tensor. The second set of points, should be [N, 2]

    :return: torch.Tensor. The point alignment error.
    """

    tree = KDTree(pointsI.detach().cpu().numpy())
    _, I_indices = tree.query(pointsJ.detach().cpu().numpy())

    tree = KDTree(pointsJ.detach().cpu().numpy())
    _, J_indices = tree.query(pointsI.detach().cpu().numpy())

    error = (torch.mean(torch.pow(pointsI - pointsJ[J_indices], 2)) +
             torch.mean(torch.pow(pointsJ - pointsI[I_indices], 2))) / 2
    return error


def nearest_neighbor_torch(
    src: torch.Tensor, 
    dst: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Find the nearest (Euclidean distance) neighbor in dst for each point in src.

    :param src: torch.Tensor. The source points, should be [N, 2]
    :param dst: torch.Tensor. The destination points, should be [N, 2]

    :return: Tuple[torch.Tensor, torch.Tensor]. The indices of the nearest neighbor in dst for each point in src.
    """
    dist = torch.cdist(src, dst)
    min_indices = torch.argmin(dist, dim=1)
    dist_2 = torch.cdist(dst, src)
    min_indices_2 = torch.argmin(dist_2, dim=1)
    return min_indices, min_indices_2


def point_alignment_error_2(
    pointsI: torch.Tensor, 
    pointsJ: torch.Tensor, 
    wi: float = 1, 
    wj: float = 1
) -> torch.Tensor:
    '''
    Compute the point alignment error between two sets of points.

    :param pointsI: torch.Tensor. The first set of points, should be [N, 2]
    :param pointsJ: torch.Tensor. The second set of points, should be [N, 2]
    :param wi: float. The weight for the first set of points.
    :param wj: float. The weight for the second set of points.

    :return: torch.Tensor. The point alignment error.
    '''
    indices1, indices2 = nearest_neighbor_torch(pointsI, pointsJ)
    error = (
        torch.mean(torch.sqrt(torch.pow(pointsI - pointsJ[indices1], 2)))
        * wi
        / pointsI.shape[0]
        + torch.mean(torch.sqrt(torch.pow(pointsJ - pointsI[indices2], 2)))
        * wj
        / pointsJ.shape[0]
    )
    return error


def toR(theta: torch.tensor) -> torch.tensor:
    """
    Construct a 2D rotation matrix from the given angle theta.
    
    :param theta: torch.tensor. A scalar value of angle of rotation.

    :return: torch.tensor. The 2D rotation matrix.
    """
    # Construct rotation matrix
    cos_theta = torch.cos(theta)
    sin_theta = torch.sin(theta)
    # Construct the rotation matrix
    rot_matrix = torch.stack([cos_theta, -sin_theta, sin_theta, cos_theta], dim=0)
    rot_matrix = rot_matrix.view(2, 2).t()
    return rot_matrix


def alpha_shape(
    points: np.ndarray, 
    alpha: float, 
    only_outer: bool = True
):
    """
    Compute the alpha shape (concave hull) of a set of points.
    :param points: np.array of shape (n,2) points.
    :param alpha: alpha value.
    :param only_outer: boolean value to specify if we keep only the outer border or also inner edges.

    :return: set of (i,j) pairs representing edges of the alpha-shape. (i,j) are
        the indices in the points array.
    """
    assert points.shape[0] > 3, "Need at least four points"
    
    def add_edge(edges, i, j):
        """
        Add an edge between the i-th and j-th points,
        if not in the list already
        """
        if (i, j) in edges or (j, i) in edges:
            # already added
            assert (j, i) in edges, "Can't go twice over same directed edge right?"
            if only_outer:
                # if both neighboring triangles are in shape, it's not a boundary edge
                edges.remove((j, i))
            return
        edges.add((i, j))
    tri = Delaunay(points)
    edges = set()
    # Loop over triangles:
    # ia, ib, ic = indices of corner points of the triangle
    for ia, ib, ic in tri.vertices:
        pa = points[ia]
        pb = points[ib]
        pc = points[ic]
        # Computing radius of triangle circumcircle
        # www.mathalino.com/reviewer/derivation-of-formulas/derivation-of-formula-for-radius-of-circumcircle
        a = np.sqrt((pa[0] - pb[0]) ** 2 + (pa[1] - pb[1]) ** 2)
        b = np.sqrt((pb[0] - pc[0]) ** 2 + (pb[1] - pc[1]) ** 2)
        c = np.sqrt((pc[0] - pa[0]) ** 2 + (pc[1] - pa[1]) ** 2)
        s = (a + b + c) / 2.0
        area = np.sqrt(s * (s - a) * (s - b) * (s - c))
        circum_r = a * b * c / (4.0 * area)
        if circum_r < alpha:
            add_edge(edges, ia, ib)
            add_edge(edges, ib, ic)
            add_edge(edges, ic, ia)
    out = np.array([ (points[[i, j], 0],points[[i, j], 1]) for i,j in edges])
    if out.shape[0] == 0:
        return None
    return out


[docs]class AlignmentModule(nn.Module): ''' AlignmentModule is a class for aligning spatial transcriptomics and metabolomics datasets. :param adata_st: AnnDataST. The spatial transcriptomics dataset. :param adata_sm: AnnDataSM. The spatial metabolomics dataset. :param hidden_stacks: List[int]. The hidden layer sizes of the encoder and decoder. :param n_latent: int. The latent dimension. :param bias: bool. If True, use bias in the linear layers. :param use_batch_norm: bool. If True, use batch normalization. :param use_layer_norm: bool. If True, use layer normalization. :param dropout_rate: float. The dropout rate. :param activation_fn: Callable. The activation function. :param device: Union[str, torch.device]. The device to run the model. :param batch_embedding: Literal["embedding", "onehot"]. The batch embedding method. :param encode_libsize: bool. If True, encode the library size. :param batch_hidden_dim: int. The batch hidden dimension. :param reconstruction_method_st: Literal['mse', 'zg', 'zinb']. The reconstruction method for the spatial transcriptomics dataset. :param reconstruction_method_sm: Literal['mse', 'zg', 'g']. The reconstruction method for the spatial metabolomics dataset. ''' def __init__( self, *, adata_st: AnnDataST, adata_sm: AnnDataSM, hidden_stacks: List[int] = [128], n_latent: int = 64, bias: bool = True, use_batch_norm: bool = True, use_layer_norm: bool = False, dropout_rate: float = 0.1, activation_fn: Callable = nn.ReLU, device: Union[str, torch.device] = "cpu", batch_embedding: Literal["embedding", "onehot"] = "onehot", encode_libsize: bool = False, batch_hidden_dim: int = 8, reconstruction_method_st: Literal['mse', 'zg', 'zinb'] = 'zinb', reconstruction_method_sm: Literal['mse', 'zg','g'] = 'g' ): super(AlignmentModule, self).__init__() self.adata_st = adata_st self.adata_sm = adata_sm self.in_dim_st = adata_st.n_vars self.in_dim_sm = adata_sm.n_vars self.hidden_stacks = hidden_stacks self.n_hidden = hidden_stacks[-1] self.n_latent = n_latent self.device = device self.reconstruction_method_st = reconstruction_method_st self.reconstruction_method_sm = reconstruction_method_sm self.encode_libsize = encode_libsize self.initialize_dataset() self.fcargs = dict( bias = bias, dropout_rate = dropout_rate, use_batch_norm = use_batch_norm, use_layer_norm = use_layer_norm, activation_fn = activation_fn, device = device ) self.encoder_st = SAE( self.in_dim_st if not self.encode_libsize else self.in_dim_st + 1, stacks = hidden_stacks, # n_cat_list = [self.n_batch] if self.n_batch > 0 else None, cat_dim = batch_hidden_dim, cat_embedding = batch_embedding, encode_only = True, **self.fcargs ) self.encoder_sm = SAE( self.in_dim_sm, stacks = hidden_stacks, cat_dim = batch_hidden_dim, cat_embedding = batch_embedding, encode_only = True, **self.fcargs ) # self.decoder_n_cat_list = decoder_n_cat_list self.decoder = FCLayer( in_dim = self.n_latent, out_dim = self.n_hidden, #n_cat_list = decoder_n_cat_list, cat_dim = batch_hidden_dim, cat_embedding = batch_embedding, use_layer_norm=False, use_batch_norm=True, dropout_rate=0, device=device ) self.encode_libsize = encode_libsize # The latent cell representation z ~ Logisticnormal(0, I) self.z_mean_st_fc = nn.Linear(self.n_hidden, self.n_latent) self.z_var_st_fc = nn.Linear(self.n_hidden, self.n_latent) self.z_mean_sm_fc = nn.Linear(self.n_hidden, self.n_latent) self.z_var_sm_fc = nn.Linear(self.n_hidden, self.n_latent) self.px_rna_rate_decoder = nn.Linear( self.n_hidden, self.in_dim_st ) self.px_rna_scale_decoder = nn.Sequential( nn.Linear(self.n_hidden, self.in_dim_st), nn.Softmax(dim=-1) ) self.px_rna_dropout_decoder = nn.Linear( self.n_hidden, self.in_dim_st ) self.px_sm_rate_decoder = nn.Linear( self.n_hidden, self.in_dim_sm ) self.px_sm_scale_decoder = nn.Linear(self.n_hidden, self.in_dim_sm) self.px_sm_dropout_decoder = nn.Linear( self.n_hidden, self.in_dim_sm ) self.st_spatial_coord = adata_st.obsm['spatial'] self.sm_spatial_coord = adata_sm.obsm['spatial'] # Alignments self.to(self.device) def initialize_dataset(self): X_st = self.adata_st.X X_sm = self.adata_sm.X self.n_record_st = X_st.shape[0] self.n_record_sm = X_sm.shape[0] self._indices_st = np.arange(self.n_record_st) self._indices_sm = np.arange(self.n_record_sm) _dataset_st = list(X_st) _dataset_sm = list(X_sm) _shuffle_indices_st = list(range(len(_dataset_st))) _shuffle_indices_sm = list(range(len(_dataset_sm))) np.random.shuffle(_shuffle_indices_st) np.random.shuffle(_shuffle_indices_sm) self._dataset_st = np.array([_dataset_st[i] for i in _shuffle_indices_st]) self._dataset_sm = np.array([_dataset_sm[i] for i in _shuffle_indices_sm]) self._shuffle_indices_st = np.array( [x for x,_ in sorted(zip(range(len(_dataset_st)), _shuffle_indices_st), key=lambda x: x[1])] ) self._shuffle_indices_sm = np.array( [x for x,_ in sorted(zip(range(len(_dataset_sm)), _shuffle_indices_sm), key=lambda x: x[1])] ) def as_multi_dataloader( self, n_per_batch: int = 128, subset_indices_st: Union[torch.tensor, np.ndarray] = None, subset_indices_sm: Union[torch.tensor, np.ndarray] = None, train_test_split: bool = False, random_seed: bool = 42, validation_split: bool = .2, shuffle: bool = True, ): indices_st = self._indices_st if subset_indices_st is None else subset_indices_st indices_sm = self._indices_sm if subset_indices_sm is None else subset_indices_sm np.random.seed(random_seed) if shuffle: np.random.shuffle(indices_st) np.random.shuffle(indices_sm) if train_test_split: split_st = int(np.floor(validation_split * self.n_record_st)) split_sm = int(np.floor(validation_split * self.n_record_sm)) if split_st % n_per_batch == 1: n_per_batch += 1 if split_sm % n_per_batch == 1: n_per_batch += 1 train_indices_st, val_indices_st = indices_st[split_st:], indices_st[:split_st] train_indices_sm, val_indices_sm = indices_sm[split_sm:], indices_sm[:split_sm] train_sampler_st = SubsetRandomSampler(train_indices_st) train_sampler_sm = SubsetRandomSampler(train_indices_sm) return { "st": ( DataLoader(indices_st, n_per_batch, sampler=train_sampler_st), DataLoader(indices_st, n_per_batch, sampler=train_sampler_st) ), "sm": ( DataLoader(indices_sm, n_per_batch, sampler=train_sampler_sm), DataLoader(indices_sm, n_per_batch, sampler=train_sampler_sm) ), } else: return { "st": ( DataLoader(indices_st, n_per_batch), ), "sm": ( DataLoader(indices_sm, n_per_batch), ), } def encode( self, batch_data, eps: float = 1e-8 ): st_dict, sm_dict = None, None if batch_data['st'] is not None: q_st = self.encoder_st.encode(batch_data['st']['X']) q_mu_st = self.z_mean_st_fc(q_st) q_var_st = torch.exp(self.z_var_st_fc(q_st)) + eps z_st = Normal(q_mu_st, q_var_st.sqrt()).rsample() st_dict = dict( q = q_st, q_mu = q_mu_st, q_var = q_var_st, z = z_st ) if batch_data['sm'] is not None: q_sm = self.encoder_sm.encode(batch_data['sm']['X']) q_mu_sm = self.z_mean_sm_fc(q_sm) q_var_sm = torch.exp(self.z_var_sm_fc(q_sm)) + eps z_sm = Normal(q_mu_sm, q_var_sm.sqrt()).rsample() sm_dict = dict( q = q_sm, q_mu = q_mu_sm, q_var = q_var_sm, z = z_sm ) H = dict( st = st_dict, sm = sm_dict ) return H def decode( self, H, lib_size:torch.tensor ): if H['st'] is not None: z_vq_st = H['st']['z'] px_st = self.decoder(z_vq_st) px_rna_scale_orig = self.px_rna_scale_decoder(px_st) px_rna_rate = self.px_rna_rate_decoder(px_st) px_rna_dropout = self.px_rna_dropout_decoder(px_st) ## In logits px_rna_scale = px_rna_scale_orig * lib_size.unsqueeze(1) if H['sm'] is not None: z_vq_sm = H['sm']['z'] px_sm = self.decoder(z_vq_sm) px_sm_scale = self.px_sm_scale_decoder(px_sm) px_sm_rate = self.px_sm_rate_decoder(px_sm) px_sm_dropout = self.px_sm_dropout_decoder(px_sm) ## In logits R = dict( st = dict( px_rna_scale_orig = px_rna_scale_orig, px_rna_scale = px_rna_scale, px_rna_rate = px_rna_rate, px_rna_dropout = px_rna_dropout, ) if H['st'] is not None else None, sm = dict( px_sm_scale = px_sm_scale, px_sm_rate = px_sm_rate, px_sm_dropout = px_sm_dropout ) if H['sm'] is not None else None ) return R def forward( self, batch_data, reduction: str = "sum", **kwargs ): reconstruction_loss_sm = torch.tensor(0., device=self.device) reconstruction_loss_st = torch.tensor(0., device=self.device) kldiv_loss_st = torch.tensor(0., device=self.device) kldiv_loss_sm = torch.tensor(0., device=self.device) H=self.encode(batch_data) R=self.decode(H, batch_data['st']['lib_size']) if H['st'] is not None: q_mu_st = H['st']["q_mu"] q_var_st= H['st']["q_var"] mean_st = torch.zeros_like(q_mu_st) scale_st = torch.ones_like(q_var_st) kldiv_loss_st = kld(Normal(q_mu_st, q_var_st.sqrt()), Normal(mean_st, scale_st)).sum(dim = 1) X_st = batch_data['st']['X'] if self.reconstruction_method_st == 'zinb': reconstruction_loss_st = LossFunction.zinb_reconstruction_loss( X_st, mu = R['st']['px_rna_scale'], theta = R['st']['px_rna_rate'].exp(), gate_logits = R['st']['px_rna_dropout'], reduction = reduction ) elif self.reconstruction_method_st == 'zg': reconstruction_loss_st = LossFunction.zi_gaussian_reconstruction_loss( X_st, mean=R['st']['px_rna_scale'], variance=R['st']['px_rna_rate'].exp(), gate_logits=R['st']['px_rna_dropout'], reduction=reduction ) elif self.reconstruction_method_st == 'mse': reconstruction_loss_st = nn.functional.mse_loss( R['st']['px_rna_scale'], X_st, reduction=reduction ) if H['sm'] is not None: q_mu_sm = H['sm']["q_mu"] q_var_sm = H['sm']["q_var"] mean_sm = torch.zeros_like(q_mu_sm) scale_sm = torch.ones_like(q_var_sm) kldiv_loss_sm = kld(Normal(q_mu_sm, q_var_sm.sqrt()), Normal(mean_sm, scale_sm)).sum(dim = 1) X_sm = batch_data['sm']['X'] if self.reconstruction_method_sm == 'zg': reconstruction_loss_sm = LossFunction.zi_gaussian_reconstruction_loss( X_sm, mean = R['sm']['px_sm_scale'], variance = R['sm']['px_sm_rate'].exp(), gate_logits = R['sm']['px_sm_dropout'], reduction = reduction ) elif self.reconstruction_method_sm == 'mse': reconstruction_loss_sm = nn.MSELoss(reduction='mean')( R['sm']['px_sm_scale'], X_sm, ) elif self.reconstruction_method_sm == "g": reconstruction_loss_sm = LossFunction.gaussian_reconstruction_loss( X_sm, mean = R['sm']['px_sm_scale'], variance = R['sm']['px_sm_rate'].exp(), reduction = reduction ) loss_record = { "reconstruction_loss_sm": reconstruction_loss_sm, "reconstruction_loss_st": reconstruction_loss_st, "kldiv_loss_st": kldiv_loss_st, "kldiv_loss_sm": kldiv_loss_sm, } return H, R, loss_record def fit_vae( self, max_epoch:int = 30, n_per_batch:int = 128, kl_weight: float = 2., n_epochs_kl_warmup: Union[int, None] = 400, optimizer_parameters: Iterable = None, weight_decay: float = 1e-6, lr: bool = 5e-5, random_seed: int = 12, validation_split: float = 0.1, ): """ Fit the two VAE models independently for the spatial transcriptomics and metabolomics datasets. :param max_epoch: int. The maximum number of epochs. :param n_per_batch: int. The number of samples per batch. :param kl_weight: float. The weight of the KL divergence loss. :param n_epochs_kl_warmup: Union[int, None]. The number of epochs for KL divergence warmup. :param optimizer_parameters: Iterable. The optimizer parameters. :param weight_decay: float. The weight decay. :param lr: float. The learning rate. :param random_seed: int. The random seed. :param validation_split: float. The validation split. :return: Dict. The loss record. """ self.train() if n_epochs_kl_warmup: n_epochs_kl_warmup = min(max_epoch, n_epochs_kl_warmup) kl_warmup_gradient = kl_weight / n_epochs_kl_warmup kl_weight_max = kl_weight kl_weight = 0. if optimizer_parameters is None: optimizer = optim.AdamW(self.parameters(), lr, weight_decay=weight_decay) else: optimizer = optim.AdamW(optimizer_parameters, lr, weight_decay=weight_decay) pbar = get_tqdm()(range(max_epoch), desc="Epoch", bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') loss_record = { "reconstruction_loss_sm": 0, "reconstruction_loss_st": 0, "kldiv_loss": 0, } epoch_reconstruction_loss_st_list = [] epoch_reconstruction_loss_sm_list = [] epoch_kldiv_loss_list = [] for epoch in range(1, max_epoch+1): dataloaders = self.as_multi_dataloader( n_per_batch=n_per_batch, train_test_split = True, validation_split = validation_split, random_seed=random_seed, ) X_train_st = dataloaders['st'][0] X_test_st = dataloaders['st'][1] X_train_sm = dataloaders['sm'][0] X_test_sm = dataloaders['sm'][1] for b, (X_st, X_sm) in enumerate(zip(X_train_st, X_train_sm)): epoch_reconstruction_loss_sm = 0 epoch_reconstruction_loss_st = 0 epoch_kldiv_loss = 0 epoch_total_loss = 0 batch_data = self._prepare_batch(X_st, X_sm) H, R, L = self.forward(batch_data) reconstruction_loss_st = L['reconstruction_loss_st'] reconstruction_loss_sm = L['reconstruction_loss_sm'] kldiv_loss = kl_weight * (L['kldiv_loss_st'].mean() + L['kldiv_loss_sm'].mean()) loss = reconstruction_loss_sm.mean() + reconstruction_loss_st.mean() + kldiv_loss avg_reconstruction_loss_st = reconstruction_loss_st.mean() / n_per_batch avg_reconstruction_loss_sm = reconstruction_loss_sm.mean() / n_per_batch avg_kldiv_loss = kldiv_loss.mean() / n_per_batch epoch_reconstruction_loss_sm += avg_reconstruction_loss_sm.item() epoch_reconstruction_loss_st += avg_reconstruction_loss_st.item() epoch_kldiv_loss += avg_kldiv_loss.item() epoch_total_loss += loss.item() optimizer.zero_grad() loss.backward() optimizer.step() pbar.set_postfix({ 'reconst_sm': '{:.2e}'.format(epoch_reconstruction_loss_sm), 'reconst_st': '{:.2e}'.format(epoch_reconstruction_loss_st), 'kldiv': '{:.2e}'.format(epoch_kldiv_loss), }) pbar.update(1) epoch_reconstruction_loss_sm_list.append(epoch_reconstruction_loss_sm) epoch_reconstruction_loss_st_list.append(epoch_reconstruction_loss_st) epoch_kldiv_loss_list.append(epoch_kldiv_loss) if n_epochs_kl_warmup: kl_weight = min( kl_weight + kl_warmup_gradient, kl_weight_max) random_seed += 1 pbar.close() self.trained_state_dict = deepcopy(self.state_dict()) return dict( epoch_reconstruction_loss_st_list=epoch_reconstruction_loss_st_list, epoch_reconstruction_loss_sm_list=epoch_reconstruction_loss_sm_list, epoch_kldiv_loss_list=epoch_kldiv_loss_list, ) @torch.no_grad() def get_latent_embedding( self, latent_key: Literal["z", "q_mu"] = "q_mu", n_per_batch: int = 128, show_progress: bool = True ) -> np.ndarray: self.eval() Zs_st = [] Zs_sm = [] dataloaders = self.as_multi_dataloader( subset_indices_st=list(range(self.n_record_st)), subset_indices_sm=list(range(self.n_record_sm)), n_per_batch=n_per_batch, train_test_split = False, shuffle = False ) X_train_st = dataloaders['st'][0] X_train_sm = dataloaders['sm'][0] if show_progress: pbar = get_tqdm()(max(len(X_train_st), len(X_train_sm)), desc="Latent Embedding", bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') for b, (X_st, X_sm) in enumerate(zip_longest(X_train_st, X_train_sm)): batch_data = self._prepare_batch(X_st, X_sm) H = self.encode(batch_data) if H['st'] is not None: Zs_st.append(H['st'][latent_key].detach().cpu().numpy()) if H['sm'] is not None: Zs_sm.append(H['sm'][latent_key].detach().cpu().numpy()) if show_progress: pbar.update(1) if show_progress: pbar.close() return np.vstack(Zs_st)[self._shuffle_indices_st], np.vstack(Zs_sm)[self._shuffle_indices_sm] @torch.no_grad() def get_st_sm_reconstruction( self, n_per_batch: int = 128, reconstruction_key: Literal["px_scale", "px_rate", "px_dropout"] = "px_scale", show_progress: bool = True ): self.eval() Zs_st = [] Zs_sm = [] dataloaders = self.as_multi_dataloader( subset_indices_st=list(range(self.n_record_st)), subset_indices_sm=list(range(self.n_record_sm)), n_per_batch=n_per_batch, train_test_split = False, shuffle=False ) X_train_st = dataloaders['st'][0] X_train_sm = dataloaders['sm'][0] if show_progress: pbar = get_tqdm()(max(len(X_train_st), len(X_train_sm)), desc="Latent Embedding", bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') for b, (X_st, X_sm) in enumerate(zip_longest(X_train_st, X_train_sm)): batch_data = self._prepare_batch(X_st, X_sm) H,R,L = self.forward(batch_data) if H['st'] is not None: Zs_st.append( R['st'][reconstruction_key.split("_")[0] + '_rna_' + reconstruction_key.split("_")[1] + '_orig'].detach().cpu().numpy() ) if H['sm'] is not None: Zs_sm.append(R['sm'][reconstruction_key.split("_")[0] + '_sm_' + reconstruction_key.split("_")[1]].detach().cpu().numpy()) if show_progress: pbar.update(1) if show_progress: pbar.close() return np.vstack(Zs_st)[self._shuffle_indices_st], np.vstack(Zs_sm)[self._shuffle_indices_sm] def get_rasterized_feature_map(self) -> dict: adata_st = self.adata_st adata_sm = self.adata_sm J, scale = get_spatial_image(adata_st) x_st, y_st = self.st_spatial_coord[:, 0] * scale, self.st_spatial_coord[:, 1] * scale x_sm, y_sm = self.sm_spatial_coord[:, 0], self.sm_spatial_coord[:, 1] z_st, z_sm = self.get_latent_embedding(latent_key="q_mu") output_sm = rasterize_with_signal(x_sm, y_sm, z_sm, dx=0.4, blur=0.01) xJ = [ torch.arange(J.shape[1], device=self.device), torch.arange(J.shape[0], device=self.device), ] xI = [ torch.tensor(output_sm[0], device=self.device), torch.tensor(output_sm[1], device=self.device), ] I = ( torch.nn.functional.normalize( torch.tensor(output_sm[2], dtype=torch.float32, device=self.device), p=2.0, dim=0, ) * 100 ) pointsI = torch.tensor(np.vstack([x_sm, y_sm]).T, dtype=torch.float32) pointsJ = torch.tensor(np.vstack([x_st, y_st]).T, dtype=torch.float32) maskI = torch.from_numpy(self.inside_delaunay_mask([xI[0].detach().cpu().numpy(),xI[1].detach().cpu().numpy()], pointsI)).to(self.device) maskJ = torch.from_numpy(self.inside_delaunay_mask([xJ[0].detach().cpu().numpy(),xJ[1].detach().cpu().numpy()], pointsJ)).to(self.device) pointsI = pointsI.to(self.device) pointsJ = pointsJ.to(self.device) J = J / 255 J = torch.tensor( np.dot(J[...,:3], [0.2989, 0.5870, 0.1140]) ) # Gray scale J = J.unsqueeze(0) J = J.to(torch.float32).to(self.device) coordinates = np.indices(np.array(J.shape)[[2,1]]).transpose(2,1,0) coordinates_flat = einops.rearrange(coordinates, 'a b c -> (a b) c') neighbors = NearestNeighbors(n_neighbors=1) neighbors.fit(self.st_spatial_coord * scale) distances, indices = neighbors.kneighbors(coordinates_flat) G = einops.rearrange(z_st[indices.flatten()], '(w h) c-> h w c', w = J.shape[-2], h = J.shape[-1]) maskJ = einops.rearrange(distances>10,'(w h) c-> h w c', w = J.shape[-2], h = J.shape[-1])[:,:,0] G[maskJ]=0 maskJ = torch.from_numpy(maskJ).to(self.device) G = torch.from_numpy(G).to(self.device) G = einops.rearrange(G, 'w h c -> c h w') fig, axes = create_subplots(2,2,figsize=(10,10)) axes=axes.flatten() axes[0].imshow(I.mean(0).detach().cpu().numpy()) axes[0].set_title('SM rasterized latent feature') axes[1].scatter( adata_sm.obsm['spatial'][:,0],-adata_sm.obsm['spatial'][:,1], c=z_sm.mean(-1), s=0.7 ) axes[1].set_title('SM latent feature') axes[2].imshow(J.mean(0).detach().cpu().numpy()) axes[2].set_title('ST spatial image') axes[3].imshow(G.detach().cpu().numpy().mean(0)) axes[3].set_title('ST latent feature') return dict( data = dict( I = I, # intend to be the SM feature in the image space J = J, # intend to be the the histology image G = G, # intend to be the ST feature in the image space as J maskI = maskI, # mask for the SM feature maskJ = maskJ, # mask for the ST feature pointsI = pointsI, # spatial coordinates for the SM feature pointsJ = pointsJ, # spatial coordinates for the ST feature, z_st = z_st, z_sm = z_sm, x_st = x_st, y_st = y_st, x_sm = x_sm, y_sm = y_sm, scale = scale, xI = xI, xJ = xJ, ), fig = dict( feature = fig ) ) def random_sample_inside_image_spot( self, data, threshold1: float = 0.65, threshold2: float = 0.2 ): J = data['data']['J'] maskJ = data['data']['maskJ'] x_sm, x_st = data['data']['x_sm'], data['data']['x_st'] y_sm, y_st = data['data']['y_sm'], data['data']['y_st'] scale = data['data']['scale'] z_st, z_sm = data['data']['z_st'], data['data']['z_sm'] pointsI = data['data']['pointsI'] pointsJ = data['data']['pointsJ'] adata_st = self.adata_st Jtrue = (J[0] <= J[0][maskJ.T].mean() * threshold1).detach().cpu().numpy() fig,axes=create_subplots(1,3, figsize=(10,5)) fig.set_size_inches=(15,5) axes[2].imshow(Jtrue) w,h = J.shape[1], J.shape[2] spot_diamter_in_pixel = (get_spatial_scalefactors_dict(adata_st)['spot_diameter_fullres'] * scale) spot_gap_in_pixel = spot_diamter_in_pixel * (110 / 65) ws = np.round(w / spot_gap_in_pixel) hs = np.round(h / spot_gap_in_pixel) xv,yv=np.meshgrid(np.arange(ws) * spot_gap_in_pixel, np.arange(hs) * spot_gap_in_pixel) for i in range(0,yv.shape[1],2): yv[:,i] += spot_gap_in_pixel / 2 xv,yv=xv.flatten(),yv.flatten() spot_coordinate = np.array([xv,yv]).T coordinates = np.indices(( w,h )).transpose(1, 2, 0) coordinates_flat = einops.rearrange(coordinates, 'a b c -> (b a) c') neighbors = NearestNeighbors(n_neighbors=1) neighbors.fit(spot_coordinate) distance, indices = neighbors.kneighbors(coordinates_flat) mask_d = (distance.min(1) < spot_diamter_in_pixel / 2).reshape(( h, w )).astype(np.uint8) indices_mapping = {k:[] for k in range(spot_coordinate.shape[0])} for i,j,k in zip(indices.flatten(),mask_d.flatten(),coordinates_flat): if j: indices_mapping[i].append(k) indices_mapping = {k:np.vstack(v) for k,v in indices_mapping.items()} spot_coordinate = spot_coordinate[[Jtrue[None,None,:,:][:,:,v[:,0],v[:,1]].mean() > threshold2 for v in indices_mapping.values()]] spot_coordinate = spot_coordinate[:,[1,0]] spot_coordinate = filter_spatial_outlier_spots(spot_coordinate) spot_coordinate = torch.tensor(spot_coordinate,dtype=torch.float32).to(self.device) initial_scale = (spot_coordinate[:,0].max() - spot_coordinate[:,1].min()) / (x_sm.max() - x_sm.min()) alpha = 1 pointsI_edge = None while pointsI_edge is None: pointsI_edge = alpha_shape(pointsI.detach().cpu().numpy(), alpha=alpha) alpha += 1 pointsI_edge = torch.from_numpy(pointsI_edge).to(pointsI.dtype).to(pointsI.device) pointsI_edge = einops.rearrange(pointsI_edge, 'a b c -> (a c) b') alpha = 1 pointsJ_edge = None while pointsJ_edge is None: pointsJ_edge = alpha_shape(pointsJ.detach().cpu().numpy(), alpha=alpha) alpha += 1 pointsJ_edge = torch.from_numpy(pointsJ_edge).to(pointsJ.dtype).to(pointsJ.device) pointsJ_edge = einops.rearrange(pointsJ_edge, 'a b c -> (a c) b') alpha = 1 spot_coordinate_edge = None while spot_coordinate_edge is None: spot_coordinate_edge = alpha_shape(spot_coordinate.detach().cpu().numpy(), alpha=alpha) alpha += 1 spot_coordinate_edge = torch.from_numpy(spot_coordinate_edge).to(spot_coordinate.dtype).to(spot_coordinate.device) spot_coordinate_edge = einops.rearrange(spot_coordinate_edge, 'a b c -> (a c) b') axes[0].imshow(J[0].detach().cpu().numpy()) axes[0].set_title('ST spatial image') axes[1].imshow(J[0].detach().cpu().numpy()) axes[0].scatter( adata_st.obsm['spatial'][:,0] * scale, adata_st.obsm['spatial'][:,1] * scale, s=3, lw=0, c=z_st.mean(-1), cmap='Reds', marker='s' ) axes[0].imshow(maskJ.T.detach().cpu().numpy(), alpha=0.3) axes[1].scatter( spot_coordinate[:,0].detach().cpu().numpy(), spot_coordinate[:,1].detach().cpu().numpy(), s=4, lw=0, c='white' ) axes[1].set_title('Randomly sampled inside-tissue spots') data['fig']['spot_coordinate'] = fig data['data']['spot_coordinate'] = spot_coordinate data['data']['initial_scale'] = initial_scale data['data']['spot_coordinate_edge'] = spot_coordinate_edge data['data']['pointsI_edge'] = pointsI_edge data['data']['pointsJ_edge'] = pointsJ_edge return data def _prepare_batch( self, X_st, X_sm, ): stdict, smdict = None, None if X_st is not None: x_st = self._dataset_st[X_st.cpu().numpy()] x_st = torch.tensor( np.vstack(list(map(lambda x: x.toarray() if issparse(x) else x, x_st))) ) x_st = x_st.to(self.device) lib_size_st = x_st.sum(1).to(self.device) coord_st = torch.tensor( self.st_spatial_coord[X_st.detach().cpu().numpy()], device=self.device ).to(torch.float32) stdict = dict( X=x_st, lib_size=lib_size_st, spatial_coord=coord_st ) if X_sm is not None: x_sm = self._dataset_sm[X_sm.cpu().numpy()] x_sm = torch.tensor( np.vstack(list(map(lambda x: x.toarray() if issparse(x) else x, x_sm))) ) x_sm = x_sm.to(self.device) lib_size_sm = x_sm.sum(1).to(self.device) coord_sm = torch.tensor( self.sm_spatial_coord[X_sm.detach().cpu().numpy()], device=self.device ).to(torch.float32) smdict = dict( X=x_sm, lib_size=lib_size_sm, spatial_coord=coord_sm ) return dict( st=stdict, sm=smdict ) def inside_delaunay_mask(self, xI, pointsI): # Compute Delaunay triangulation tri = Delaunay(pointsI) # Create a mask for the image mask = np.zeros((xI[0].shape[0], xI[1].shape[0]), dtype=bool) # Generate all pixel coordinates xx, yy = np.meshgrid(xI[0],xI[1]) pixel_coordinates = np.vstack((xx.flatten(), yy.flatten())).T # Find simplex for each pixel simplex_indices = tri.find_simplex(pixel_coordinates) # Mark pixels inside the Delaunay hull mask.T[(simplex_indices >= 0).reshape(len(xI[1]),len(xI[0]))] = True return mask def fit_alignment( self, data: dict, initial_scale: bool = None, a: float = 50.0, p: float = 2.0, expand: float = 2.0, nt: float = 3, niter: int = 500, diffeo_start: float = 0, diffeo: bool = False, epV: float = 2e-1, sigmaM: float = 1.0, sigmaR: float = 5e5, align_sm_spot_to: Literal["histology", "ST"] = "histology", align_spot_outline: bool = True, align_sm_feature_to_st_feature: bool = False, align_sm_feature_to_histology_feature: bool = False, debug_path: Optional[str] = None, ): """ Fit the alignment model for spatial transcriptomics (ST) features and spatial metabolomics (SM) features based on computation of Affine Matrix A by stochastic gradient descent See the original implementation at https://github.com/JEFworks-Lab/STalign :param data: dict. The data dictionary containing features of ST and SM to be aligned. :param initial_scale: bool. The initial scale for the image :param a: float. Smoothness scale of velocity field. :param p: float. Power of Laplacian in velocity regularization. :param expand: float. The expansion factor. :param nt: float. Number of timesteps for integrating velocity field. :param niter: int. The number of iterations. :param diffeo_start: float. The starting step of diffeomorphism. :param epV: float. Gradient descent step size for velocity field. The default value was set to a small value (2e-1) to avoid divergence, compared to the original implementation so the user may need to adjust this value to allow velocity field to converge. :param sigmaM: float. Standard deviation of image matching term for Gaussian mixture modeling in cost function. :param sigmaR: float. Standard deviation of regularization term for Gaussian mixture modeling in cost function. :param align_sm_spot_to: bool. Whether to align the SM spot to spots sampled from the histology image or ST spots. Should be either 'histology' or 'ST'. :param align_spot_outline: bool. Whether to align the spot outline or all spots :param align_sm_feature_to_histology_feature: bool. Whether to align the SM feature to histology gray scaled feature. :param align_sm_feature_to_st_feature: bool. Whether to align the ST feature with SM feature. :param debug_path: str. The optional temporary file path that save intermediate results in alignment. """ assert align_sm_spot_to in ['histology','ST'], "align_sm_spot_to must be either 'histology' or 'ST'" with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) return self._fit_alignment_impl( data, initial_scale=initial_scale, a=a, p=p, expand=expand, nt=nt, niter=niter, diffeo_start=diffeo_start, diffeo=diffeo, epV=epV, sigmaM=sigmaM, sigmaR=sigmaR, align_sm_feature_to_histology_feature=align_sm_feature_to_histology_feature, align_sm_spot_to=align_sm_spot_to, align_sm_feature_to_st_feature=align_sm_feature_to_st_feature, align_spot_outline=align_spot_outline, debug_path = debug_path, ) def _fit_alignment_impl( self, data: dict, initial_scale: bool = None, a: float = 50.0, p: float = 2.0, expand: float = 2.0, nt: float = 3, niter: int = 500, diffeo_start: float = 0, diffeo: bool = False, epV: float = 2e-1, sigmaM: float = 1.0, sigmaR: float = 5e5, align_sm_feature_to_histology_feature: bool = False, align_sm_spot_to: Literal['histology','ST'] = 'histology', align_spot_outline: bool = True, align_sm_feature_to_st_feature: bool = False, debug_path: Optional[str] = None ): self.eval() tosave = [0,0,0] if debug_path is not None and os.path.exists(debug_path): os.system(f"rm -rf {debug_path}/*png") initial_scale = data['data']['initial_scale'] if initial_scale is None else initial_scale spot_coordinate = data['data']['spot_coordinate'] # spatial coordinate of the spot sampled from the image spot_coordinate_edge = data['data']['spot_coordinate_edge'] # edge of spatial coordinate from spots sampled from the image pointsI_edge = data['data']['pointsI_edge'] # edge of spatial coordinate from ST spots pointsJ_edge = data['data']['pointsJ_edge'] # edge of spatial coordinate pointsI = data['data']['pointsI'] # spatial coordinate of the ST spots pointsJ = data['data']['pointsJ'] # spatial coordinate of the SM spots spot_coordinate = spot_coordinate.to(self.device) spot_coordinate_edge = spot_coordinate_edge.to(self.device) pointsI_edge = pointsI_edge.to(self.device) pointsJ_edge = pointsJ_edge.to(self.device) pointsI = pointsI.to(self.device) pointsJ = pointsJ.to(self.device) I = data['data']['I'] # SM feature rasterized in the image space J = data['data']['J'] # Image feature G = data['data']['G'] # ST feature rasterized in the image space maskJ = data['data']['maskJ'] # tissue mask for the ST feature init_bias = torch.tensor(1.) xI = data['data']['xI'] # x coordinate for the SM feature xJ = data['data']['xJ'] # x coordinate for the ST feature maskI = data['data']['maskI'] # tissue mask for the SM feature x_sm = data['data']['x_sm'] # x coordinate for the SM feature y_sm = data['data']['y_sm'] # y coordinate for the SM feature x_st = data['data']['x_st'] # x coordinate for the ST feature y_st = data['data']['y_st'] # y coordinate for the ST feature scale = data['data']['scale'] # estimated scale of distortion if align_sm_feature_to_st_feature: if align_sm_feature_to_histology_feature: J = torch.cat([G,J], dim=0) else: J = G # projection layer of the latent space latent_adaptor_last = Linear( I.shape[0], J.shape[0], init='final' ).to(self.device) # project SM feature to histology image feature latent_adaptor_last.bias = torch.nn.Parameter( torch.tensor([init_bias], device=self.device, requires_grad=True) ) latent_adaptor = nn.Sequential( Linear(I.shape[0], I.shape[0], init='normal'), nn.ReLU(), nn.LayerNorm(I.shape[0]), nn.Dropout(0.1), latent_adaptor_last ).to(self.device) latent_optimizer = torch.optim.AdamW( latent_adaptor.parameters(), lr=1e-3 ) if align_sm_feature_to_st_feature: latent_adaptor_1_rev = nn.Linear(J.shape[0],I.shape[0]).to(self.device) latent_adaptor_2_last = Linear( J.shape[0], J.shape[0], init='final' ).to(self.device) latent_adaptor_2_last.bias = torch.nn.Parameter( torch.tensor( [init_bias], device=self.device, requires_grad=True ) ) latent_adaptor_2 = nn.Sequential( Linear(J.shape[0], J.shape[0], init='normal'), nn.ReLU(), nn.LayerNorm(J.shape[0]), nn.Dropout(0.1), latent_adaptor_2_last ).to(self.device) latent_adaptor_2_rev = nn.Linear(J.shape[0],J.shape[0]).to(self.device) from itertools import chain latent_optimizer = torch.optim.AdamW(chain( latent_adaptor.parameters(), latent_adaptor_1_rev.parameters(), latent_adaptor_2.parameters(), latent_adaptor_2_rev.parameters(), ), lr=1e-3) theta = torch.tensor(0,device=self.device, dtype=torch.float32, requires_grad=True) # L = torch.eye(2, device=self.device, dtype=torch.float32, requires_grad=True) T = torch.zeros(2, device=self.device, dtype=torch.float32, requires_grad=True) S = torch.tensor( [ torch.log(torch.tensor(initial_scale)), torch.log(torch.tensor(initial_scale)), ], device=self.device, dtype=torch.float32, requires_grad=True, ) # scale = torch.tensor(1.0, device=self.device, dtype=torch.float32, requires_grad=True) minv = torch.as_tensor([x[0] for x in xI], device=self.device, dtype=torch.float32) maxv = torch.as_tensor([x[-1] for x in xI], device=self.device, dtype=torch.float32) minv, maxv = (minv + maxv) * 0.5 + 0.5 * torch.tensor( [-1.0, 1.0], device=self.device, dtype=torch.float32 )[..., None] * (maxv - minv) * expand xv = [ torch.arange(m, M, a * 0.5, device=self.device, dtype=torch.float32) for m, M in zip(minv, maxv) ] XV = torch.stack(torch.meshgrid(xv), -1) v = torch.zeros( (nt, XV.shape[0], XV.shape[1], XV.shape[2]), device=self.device, dtype=torch.float32, requires_grad=True, ) dv = torch.as_tensor([x[1] - x[0] for x in xv], device=self.device, dtype=torch.float32) fv = [ torch.arange(n, device=self.device, dtype=torch.float32) / n / d for n, d in zip(XV.shape, dv) ] FV = torch.stack(torch.meshgrid(fv), -1) LL = ( 1.0 + 2.0 * a**2 * torch.sum((1.0 - torch.cos(2.0 * np.pi * FV * dv)) / dv**2, -1) ) ** (p * 2.0) K = 1.0 / LL DV = torch.prod(dv) WM = torch.ones(J[0].shape, dtype=J.dtype, device=J.device) * 0.5 WB = torch.ones(J[0].shape, dtype=J.dtype, device=J.device) * 0.4 WA = torch.ones(J[0].shape, dtype=J.dtype, device=J.device) * 0.1 xI = [torch.tensor(x, device=self.device, dtype=torch.float32) for x in xI] xJ = [torch.tensor(x, device=self.device, dtype=torch.float32) for x in xJ] XI = torch.stack(torch.meshgrid(*xI, indexing="ij"), -1) XJ = torch.stack(torch.meshgrid(*xJ, indexing="ij"), -1) dJ = [x[1] - x[0] for x in xJ] extentJ = ( xJ[1][0].item() - dJ[1].item() / 2.0, xJ[1][-1].item() + dJ[1].item() / 2.0, xJ[0][-1].item() + dJ[0].item() / 2.0, xJ[0][0].item() - dJ[0].item() / 2.0, ) if os.path.exists(debug_path): os.system(f"rm {debug_path}" + '/*') fit_history = {} pbar = get_tqdm()(range(niter), desc="LLDDMM", bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}') for it in range(niter): if it % 10 == 0 and it > 0 and (align_sm_feature_to_histology_feature or align_sm_feature_to_st_feature): XII = einops.rearrange(XI.detach().cpu().numpy(), 'w h c -> (h w) c') neighbors = NearestNeighbors(n_neighbors=1) neighbors.fit(XII) _, indices = neighbors.kneighbors(einops.rearrange( Xs.detach().cpu().numpy(), 'w h c -> (w h) c' )) AI = einops.rearrange( einops.rearrange(I, 'c w h -> (w h) c')[indices], '(w h) 1 c -> c h w', w = wj, h = hj ) for _ in range(100): fAIt = latent_adaptor(einops.rearrange(AI.detach(), ' c w h -> h w c')) if align_sm_feature_to_st_feature: JI = einops.rearrange(J.detach(), 'c w h -> h w c') JIt = latent_adaptor_2(JI) fAItr = latent_adaptor_1_rev(fAIt) JItr = latent_adaptor_2_rev(JIt) EM = torch.sum(einops.rearrange(( JIt - fAIt ), 'w h c -> c w h') ** 2 * WM.T.detach() * ~maskJ.detach() ) * 10 / sigmaM**2 EL = (torch.sum(einops.rearrange(( JI - JItr ), 'w h c -> c w h') ** 2 * WM.T.detach() * ~maskJ.detach() ) + torch.sum(einops.rearrange(( einops.rearrange(AI, 'c w h -> h w c') - fAItr ), 'w h c -> c w h') ** 2)) * 10 / sigmaM**2 EML = EM + EL EML.backward() else: EM = torch.sum(einops.rearrange(( einops.rearrange(J.detach(), 'c w h -> h w c') - fAIt ), 'w h c -> c w h') ** 2 * WM.T.detach() * ~maskJ.detach() ) * 10 / sigmaM**2 EM.backward() latent_optimizer.step() latent_optimizer.zero_grad() L = toR(theta) A = to_A(L, T) # Ai Li = torch.linalg.inv(L) wj, hj = XJ.shape[:2] pointsIt = torch.clone(pointsI).to(J.device) pointsIt_bias = pointsIt.mean(0) Xs = (((einops.rearrange(XJ, 'w h c -> (w h) c') - pointsIt_bias) - A[:2, -1]) @ L) Xs = einops.rearrange(Xs, '(w h) c -> w h c', w = wj, h = hj) # now diffeo, not semilagrange here for t in range(nt - 1, -1, -1): Xs = ( Xs + interp(xv, -v[t].permute(2, 0, 1), Xs.permute(2, 0, 1)).permute( 1, 2, 0 ) / nt ) Xs /= (torch.e ** S) Xs += pointsIt_bias # and points if pointsIt.shape[0] > 0: for t in range(nt): pointsIt += ( interp(xv, v[t].permute(2, 0, 1), pointsIt.T[..., None])[..., 0].T / nt ) pointsIt -= pointsIt_bias pointsIt *= (torch.e ** S) if pointsIt.shape[0] > 0: pointsIt = (A[:2, :2] @ pointsIt.T + A[:2, -1][..., None]).T pointsIt += pointsIt_bias pointsI_edget = torch.clone(pointsI_edge).to(J.device) if pointsI_edget.shape[0] > 0: for t in range(nt): pointsI_edget += ( interp(xv, v[t].permute(2, 0, 1), pointsI_edget.T[..., None])[..., 0].T / nt ) pointsI_edget -= pointsIt_bias pointsI_edget *= (torch.e ** S) if pointsI_edget.shape[0] > 0: pointsI_edget = (A[:2, :2] @ pointsI_edget.T + A[:2, -1][..., None]).T pointsI_edget += pointsIt_bias xIs = [ xI[0].detach(), xI[1].detach() ] EM = torch.tensor(0., device=self.device) if it > 0 and it % 10 == 0 and (align_sm_feature_to_histology_feature or align_sm_feature_to_st_feature): XII = einops.rearrange(XI.detach().cpu().numpy(), 'w h c -> (h w) c') neighbors = NearestNeighbors(n_neighbors=1) neighbors.fit(XII) _, indices = neighbors.kneighbors(einops.rearrange( Xs.detach().cpu().numpy(), 'w h c -> (w h) c' )) AI = einops.rearrange( einops.rearrange(I, 'c w h -> (w h) c')[indices], '(w h) 1 c -> c h w', w = wj, h = hj ) # transform the contrast B = torch.ones( 1 + AI.shape[0], AI.shape[1] * AI.shape[2], device=AI.device, dtype=AI.dtype ) B[1 : AI.shape[0] + 1] = AI.reshape(AI.shape[0], -1) with torch.no_grad(): BB = B @ (B * WM.ravel()).T BJ = B @ ((J * WM).reshape(J.shape[0], J.shape[1] * J.shape[2])).T small = 0.1 coeffs = torch.linalg.solve( BB + small * torch.eye(BB.shape[0], device=BB.device, dtype=BB.dtype), BJ, ) fAI = ((B.T @ coeffs).T).reshape(J.shape) fAIt = latent_adaptor(einops.rearrange(AI, ' c w h -> h w c')) if align_sm_feature_to_st_feature: JI = einops.rearrange(J.detach(), 'c w h -> h w c') JIt = latent_adaptor_2(JI) EM = torch.mean(einops.rearrange(( JIt - fAIt ), 'w h c -> c w h') ** 2 * WM.T * ~maskJ ) * 10 / sigmaM**2 else: EM = torch.mean(einops.rearrange(( einops.rearrange(J, 'c w h -> h w c') - fAIt ), 'w h c -> c w h') ** 2 * WM.T * ~maskJ ) * 10 / sigmaM**2 # ER is the regularization term ER = ( torch.mean( torch.sum(torch.abs(torch.fft.fftn(v, dim=(1, 2))) ** 2, dim=(0, -1)) * LL ) * DV / 2.0 / v.shape[1] / v.shape[2] / sigmaR ** 2 ) if align_sm_spot_to == 'histology': if align_spot_outline: EP = point_alignment_error_2( pointsI_edget, spot_coordinate_edge ) * 10 else: EP = point_alignment_error_2( pointsIt, spot_coordinate ) elif align_sm_spot_to == 'ST': if align_spot_outline: EP = point_alignment_error_2( pointsI_edget, pointsJ_edge ) * 10 else: EP = point_alignment_error_2( pointsIt, pointsJ ) if it > 0 and it % 10 == 0: E = EP + ER + EM else: E = EP + ER tosave = [EM.item() if EM.item() != 0 else tosave[0] , ER.item(), EP.item()] if debug_path is not None and it % 10 == 0: fig,axes=plt.subplots(2,2,figsize=(12,12)) axes=axes.flatten() axes[0].scatter( pointsIt[:, 0].detach().cpu().numpy(), pointsIt[:, 1].detach().cpu().numpy(), c="r", s=1, label="I", ) axes[0].scatter( pointsJ[:, 0].detach().cpu().numpy(), pointsJ[:, 1].detach().cpu().numpy(), c="b", s=1, label="J", ) axes[0].set_title(f"point alignment error: {EP.item()}") axes[0].invert_yaxis() if it > 0 and it % 10 == 0 and (align_sm_feature_to_histology_feature or align_sm_feature_to_st_feature): axes[1].imshow( einops.rearrange(fAIt, 'w h c -> h w c').mean(-1).detach().cpu().numpy(), #vmin=J.min(), #vmax=J.max() ) axes[1].set_title(f'EM: {EM.item()}') axes[2].scatter( pointsIt[:, 0].detach().cpu().numpy(), pointsIt[:, 1].detach().cpu().numpy(), c="r", s=1, label="I", ) axes[2].scatter( pointsI_edget[:, 0].detach().cpu().numpy(), pointsI_edget[:, 1].detach().cpu().numpy(), c="orange", s=1, label="I", ) axes[2].scatter( spot_coordinate[:, 0].detach().cpu().numpy(), spot_coordinate[:, 1].detach().cpu().numpy(), c="black", s=1, label="J", ) axes[2].scatter( spot_coordinate_edge[:, 0].detach().cpu().numpy(), spot_coordinate_edge[:, 1].detach().cpu().numpy(), c="blue", s=1, label="J", ) axes[2].set_title(f"point alignment error: {EP.item()}") axes[2].invert_yaxis() axes[3].imshow( J.mean(0).detach().cpu().numpy(), ) axes[3].set_title("J") axes[1].set_xbound(axes[3].get_xbound()) axes[1].set_ybound(axes[3].get_ybound()) if debug_path is not None: fig.savefig( os.path.join( debug_path, f'point_alignment_{it}.png' ) ) plt.close() postfix = dict(zip(['EM','ER','EP'], tosave)) postfix['angle'] = theta.item() pbar.set_postfix(postfix) E.backward() with torch.no_grad(): if it % 10 == 0 and it > 0: latent_optimizer.step() latent_optimizer.zero_grad() if not torch.isnan(theta.grad).any() or not torch.isnan(T.grad).any(): theta -= (5e-2 / (1.0 + (it >= diffeo_start) * 9)) * theta.grad if not torch.isnan(T.grad).any(): T -= 2000 * T.grad if not torch.isnan(S.grad).any(): S -= S.grad * 1e-4 T.grad.zero_() # theta += torch.rand(theta.shape).to(S.device) * 0.05 theta.grad.zero_() # S += torch.rand(S.shape).to(S.device) * 0.05 S.grad.zero_() # v grad vgrad = v.grad if not torch.isnan(vgrad).any() and diffeo: # smooth it vgrad = torch.fft.ifftn( torch.fft.fftn(vgrad, dim=(1, 2)) * K[..., None], dim=(1, 2) ).real if it >= diffeo_start: v -= vgrad * epV v.grad.zero_() A = to_A(L, T) fit_history[it] = dict( loss = tosave, theta = theta.item(), T = T.detach().cpu().numpy(), S = S.detach().cpu().numpy(), v = v.detach().cpu().numpy(), A = A.detach().cpu().numpy(), L = L.detach().cpu().numpy(), pointsIt = pointsIt.detach().cpu().numpy(), ) pbar.update() pbar.close() return fit_history