from typing import List, NamedTuple, Optional, TYPE_CHECKING, Union, Tuple
import matplotlib
import matplotlib.pyplot as plt
import pyimzml
from pyimzml.ImzMLParser import ImzMLParser
import pandas as pd
import anndata as AnnData
import matplotlib.patches as mpatch
from adjustText import adjust_text
import seaborn as sns
import numpy as np
import tqdm
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
from scipy.signal import convolve2d
from scipy.stats import pearsonr
import networkx as nx
mpl.rcParams["pdf.fonttype"] = 42
mpl.rcParams["ps.fonttype"] = 42
matplotlib.rcParams["font.size"] = "10"
matplotlib.rcParams["font.weight"] = 100
matplotlib.rcParams["axes.linewidth"] = 2
matplotlib.rcParams["axes.edgecolor"] = "#000000"
from colour import Color
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import cm
import scanpy as sc
import dotplot
from svgpathtools import parse_path
from ..util._decorators import ignore_warning
from ._utils import get_spatial_image
from .utils._fit_curve import fit_p_value, fit_curve, valid_curves, normalize_func
from ..preprocess._preprocess import spatial_distance_cluster, calculate_dot_df
from ..util._classes import AnnDataSM, AnnDataST, AnnDataJointSMST
[docs]def make_colormap(
colors: list,
show_palette: bool = False,
):
"""
Make a colormap from a list of colors.
:param colors: list of colors to use in the colormap.
:param show_palette: whether to display the colormap as a palette, default is False.
:return: LinearSegmentedColormap object.
"""
color_ramp = LinearSegmentedColormap.from_list(
"my_list", [Color(c1).rgb for c1 in colors]
)
if show_palette:
plt.figure(figsize=(15, 3))
plt.imshow(
[list(np.arange(0, len(colors), 0.1))],
interpolation="nearest",
origin="lower",
cmap=color_ramp,
)
plt.xticks([])
plt.yticks([])
return color_ramp
def showionimage(p, mz_value=800, tol=200, z=1):
im = pyimzml.ImzMLParser.getionimage(p, mz_value, tol)
plt.imshow(im).set_interpolation("nearest")
plt.colorbar()
plt.show()
[docs]def create_fig(
figsize: tuple = (8, 4),
):
"""
Create a figure with the specified size and axis properties.
:param figsize: tuple specifying the size of the figure, default is (8, 4).
:return: figure and axis objects.
"""
fig, ax = plt.subplots()
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")
# ax.spines['bottom'].set_color('none')
# ax.spines['left'].set_color('none')
for line in ax.yaxis.get_ticklines():
line.set_markersize(5)
line.set_color("#585958")
line.set_markeredgewidth(0.5)
for line in ax.xaxis.get_ticklines():
line.set_markersize(5)
line.set_markeredgewidth(0.5)
line.set_color("#585958")
ax.set_xbound(0, 10)
ax.set_ybound(0, 10)
fig.set_size_inches(figsize)
return fig, ax
[docs]def create_subplots(nrow: int, ncol: int, figsize: tuple = (8, 4)):
"""
Create a figure with the specified size and axis properties.
:param nrow: number of rows in the subplot.
:param ncol: number of columns in the subplot.
:param figsize: tuple specifying the size of the figure, default is (8, 4).
:return: figure and axis objects.
"""
fig, axes = plt.subplots(nrow, ncol)
for ax in axes.flatten():
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")
for line in ax.yaxis.get_ticklines():
line.set_markersize(5)
line.set_color("#585958")
line.set_markeredgewidth(0.5)
for line in ax.xaxis.get_ticklines():
line.set_markersize(5)
line.set_markeredgewidth(0.5)
line.set_color("#585958")
fig.set_size_inches(figsize)
return fig, axes
[docs]def plot_spot_sm_st(
adata_SM: AnnDataSM,
adata_ST: AnnDataST,
SM_spatial_key: str = "spatial",
ST_spatial_key: str = "spatial",
stacked: bool = False,
ST_color: str = "#C7C8CC",
SM_color: str = "#C499BA",
s: int = 10,
**kwargs,
):
"""
Plot the spatial distribution of spots from Spatial Transcriptomics and Spatial Metabolomics data.
:param adata_SM: AnnData object containing the Spatial Metabolomics data.
:param adata_ST: AnnData object containing the Spatial Transcriptomics data.
:param SM_spatial_key: key in adata_SM.obsm where the spatial coordinates are stored, default is "spatial".
:param ST_spatial_key: key in adata_ST.obsm where the spatial coordinates are stored, default is "spatial".
:param stacked: whether to plot the data in a single plot or side-by-side, default is False.
:param ST_color: color to use for the Spatial Transcriptomics data, default is "#C7C8CC".
:param SM_color: color to use for the Spatial Metabolomics data, default is "#C499BA".
:param s: size of the scatter points, default is 10.
:param kwargs: additional keyword arguments to pass to the scatter function.
:return: None.
"""
if stacked:
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.scatter(
adata_ST.obsm[ST_spatial_key][:, 0],
adata_ST.obsm[ST_spatial_key][:, 1],
c=ST_color,
s=s,
**kwargs,
)
ax.scatter(
adata_SM.obsm[SM_spatial_key][:, 0],
adata_SM.obsm[SM_spatial_key][:, 1],
c=SM_color,
s=s,
**kwargs,
)
# add legend
ax.legend(["Spatial Transcriptomics", "Spatial Metabolomics"])
else:
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].scatter(
adata_ST.obsm[ST_spatial_key][:, 0],
adata_ST.obsm[ST_spatial_key][:, 1],
c=ST_color,
s=10,
**kwargs,
)
# add label tittle
ax[0].set_title("Spatial Transcriptomics")
ax[1].scatter(
adata_SM.obsm[SM_spatial_key][:, 0],
adata_SM.obsm[SM_spatial_key][:, 1],
c=SM_color,
s=10,
**kwargs,
)
# add label tittle
ax[1].set_title("Spatial Metabolomics")
[docs]def plot_newdot_sm_st(
new_dot_in_df: pd.DataFrame,
adata_SM: AnnDataSM,
adata_ST: AnnDataST,
ST_spatial_key: str = "spatial",
SM_spatial_key: str = "spatial",
ST_color: str = "#C7C8CC",
SM_color: str = "#C499BA",
new_dot_color: str = "#FF0000",
):
"""
Plot the spatial distribution of new spots in the Spatial Metabolomics and Spatial Transcriptomics data.
:param new_dot_in_df: DataFrame containing the spatial coordinates of the new spots.
:param adata_SM: AnnData object containing the Spatial Metabolomics data.
:param adata_ST: AnnData object containing the Spatial Transcriptomics data.
:param SM_spatial_key: key in adata_SM.obsm where the spatial coordinates are stored, default is "spatial".
:param ST_spatial_key: key in adata_ST.obsm where the spatial coordinates are stored, default is "spatial".
:param ST_color: color to use for the Spatial Transcriptomics data, default is "#C7C8CC".
:param SM_color: color to use for the Spatial Metabolomics data, default is "#C499BA".
:param new_dot_color: color to use for the new spots, default is "#FF0000".
:return: None.
"""
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.scatter(
new_dot_in_df["x_coord"],
new_dot_in_df["y_coord"],
c=new_dot_color,
s=3,
)
ax.scatter(
adata_ST.obsm[ST_spatial_key][:, 0],
adata_ST.obsm[ST_spatial_key][:, 1],
c=ST_color,
s=3,
)
ax.scatter(
adata_SM.obsm[SM_spatial_key][:, 0],
adata_SM.obsm[SM_spatial_key][:, 1],
c=SM_color,
s=3,
)
# add legend
ax.legend(["New Spots", "Spatial Transcriptomics", "Spatial Metabolomics"])
[docs]def plot_markerfeature(
adata: AnnData,
groupby: str,
palette: dict,
marker_feature_list: list,
figsize: tuple = (10, 4),
logfoldchanges_max: int = 10,
logfoldchanges_min: int = -2,
uns_key: str = "rank_genes_groups",
save_path: str = None,
):
"""
Create a scatter plot of the marker features in the AnnData object.
:param adata: AnnData object containing the data.
:param groupby: key in adata.obs to use for grouping the data.
:param palette: dictionary containing the colors to use for each group.
:param marker_feature_list: list of marker features to highlight.
:param figsize: tuple specifying the size of the figure, default is (10, 4).
:param logfoldchanges_max: maximum logfoldchanges value to consider, default is 10.
:param logfoldchanges_min: minimum logfoldchanges value to consider, default is -2.
:param uns_key: key in adata.uns where the data is stored, default is 'rank_genes_groups'.
:param save_path: path to save the figure, default is None.
:return: None.
"""
# Initialize figure and axis
fig, ax = plt.subplots(figsize=figsize)
# Calculate fraction of features for each group
group_frac_dict = {}
for size_obs_group in np.unique(adata.obs[groupby]):
_adata = adata[adata.obs[groupby] == size_obs_group]
frac = dict(
zip(
_adata.var.index,
np.array((_adata.X > 0).sum(0) / _adata.shape[0]).flatten(),
)
)
group_frac_dict[size_obs_group] = frac
# Prepare result DataFrame
result = []
columns = list(pd.DataFrame(adata.uns[uns_key]["logfoldchanges"]).columns)
for i, j, k in zip(
pd.DataFrame(adata.uns[uns_key]["logfoldchanges"]).to_numpy(),
pd.DataFrame(adata.uns[uns_key]["names"]).to_numpy(),
pd.DataFrame(adata.uns[uns_key]["pvals_adj"]).to_numpy(),
):
for e, (m, n, p) in enumerate(zip(i, j, -np.log10(k))):
result.append((columns[e], m, n, p))
result = pd.DataFrame(
result, columns=[groupby, "logfoldchanges", "gene_name", "pvals_adj"]
)
result = result.sort_values(groupby)
result["size"] = list(
map(
lambda x: group_frac_dict[x[0]][x[1]],
zip(result[groupby], result["gene_name"]),
)
)
# Make sure the groupby column uses pandas categorical type and add jitter
result[groupby] = result[groupby].astype("category")
result[groupby + "_codes"] = result[groupby].cat.codes + np.random.uniform(
-0.3, 0.3, len(result)
)
# Create scatterplots for different conditions
# Positive logfoldchanges, low absolute logfoldchanges, low pvals_adj
_data = result[
(result["logfoldchanges"] > 0)
& (np.abs(result["logfoldchanges"]) < logfoldchanges_max)
& (result["logfoldchanges"] > logfoldchanges_min)
& (
(np.abs(result["logfoldchanges"]) < 0.25)
| (np.abs(result["pvals_adj"]) <= 2)
)
]
sns.scatterplot(
data=_data,
x=groupby + "_codes",
y="logfoldchanges",
ax=ax,
marker="o",
size="size",
sizes=(0, 12),
color="#D7D7D7",
zorder=-1,
legend=False,
)
# Positive logfoldchanges, moderate absolute logfoldchanges, high pvals_adj
_data = result[
(result["logfoldchanges"] > 0)
& (result["logfoldchanges"] < logfoldchanges_max)
& (result["logfoldchanges"] > logfoldchanges_min)
& (np.abs(result["logfoldchanges"]) >= 0.25)
& (np.abs(result["pvals_adj"]) > 2)
]
for subtype in np.unique(_data[groupby]):
st = sns.scatterplot(
data=_data[_data[groupby] == subtype],
x=groupby + "_codes",
y="logfoldchanges",
ax=ax,
marker="o",
size="size",
sizes=(0, 12),
color=palette[subtype],
zorder=-1,
legend=False,
)
# Negative logfoldchanges, low absolute logfoldchanges, low pvals_adj
_data = result[
(result["logfoldchanges"] < 0)
& (result["logfoldchanges"] < logfoldchanges_max)
& (result["logfoldchanges"] > logfoldchanges_min)
& (
(np.abs(result["logfoldchanges"]) < 0.25)
| (np.abs(result["pvals_adj"]) <= 2)
)
]
sns.scatterplot(
data=_data,
x=groupby + "_codes",
y="logfoldchanges",
ax=ax,
marker="o",
size="size",
sizes=(0, 12),
color="#D7D7D7",
zorder=-1,
legend=False,
)
# Now you can safely use `st.get_children()` to retrieve the plot children
all_collections = list(
filter(
lambda x: isinstance(x, matplotlib.collections.PathCollection),
st.get_children(),
)
)
all_text = []
# Loop through each subtype and gene list and logfc is small logfoldchanges_max
for e, (logfc_arr, name_arr, pval_arr, gene_list) in enumerate(
zip(
pd.DataFrame(adata.uns[uns_key]["logfoldchanges"]).to_numpy().T,
pd.DataFrame(adata.uns[uns_key]["names"]).to_numpy().T,
pd.DataFrame(adata.uns[uns_key]["pvals_adj"]).to_numpy().T,
(
marker_feature_list
if marker_feature_list is not None
else [[] for _ in range(len(name_arr))]
), # Handle empty list case
)
):
# Loop through each gene name and its corresponding data
for _, (logfc, gene, pval) in enumerate(zip(logfc_arr, name_arr, pval_arr)):
# Check if the gene is in the list for the current subtype
if np.abs(logfc) < logfoldchanges_max:
if gene in gene_list:
# Find the coordinates in the result DataFrame
coords = result.loc[
(result[groupby].cat.codes == e) & (result["gene_name"] == gene),
groupby + "_codes",
].iloc[0]
x, y = coords, logfc
# Add a scatter point
c_ls = list(np.unique(_data[groupby]))
ax.scatter(
x,
y,
lw=0.3,
edgecolor="black",
c=palette[c_ls[e]],
s=group_frac_dict[c_ls[e]][gene] * 12,
)
# Annotate the gene name
all_text.append(
ax.annotate(
xy=(x, y),
text=gene,
size=8,
xytext=(1, 40),
textcoords="offset points",
ha="center",
arrowprops=dict(
arrowstyle="-",
mutation_scale=0.005,
color="black",
lw=0.5,
ls="-",
),
)
)
# Modify the x-axis tick labels to show the group names
x_tick_positions = range(len(c_ls))
# Set the x-tick positions and labels explicitly
ax.set_xticks(x_tick_positions)
ax.set_xticklabels(c_ls, rotation=45, ha="right")
# Display the plot
plt.show()
# Save the figure if save_path is provided
if save_path:
fig.savefig(save_path)
[docs]def plot_corrcoef_stsm_inall(
adata: AnnDataJointSMST,
row_cluster: bool = True,
col_cluster: bool = True,
figsize: tuple = (10, 10),
):
"""
Create a clustermap of the correlation coefficients between genes and metabolites.
:param adata: AnnDataJointSMST object containing the ST and SM data.
:param row_cluster: whether to cluster the rows, default is True.
:param col_cluster: whether to cluster the columns, default is True.
:param figsize: tuple specifying the size of the figure, default is (10, 10).
:return: None.
"""
gene_ls = list(np.unique(adata.uns["corrcoef_stsm_inall_top"]["gene"]))
metabolite_ls = list(np.unique(adata.uns["corrcoef_stsm_inall_top"]["metabolite"]))
corr_mat = np.zeros((len(gene_ls), len(metabolite_ls)))
pbar = tqdm.tqdm(total=len(gene_ls) * len(metabolite_ls))
for i in gene_ls:
for j in metabolite_ls:
r = adata.uns["corrcoef_stsm_inall"].loc[i, j]
if isinstance(r, np.float64):
corr_mat[gene_ls.index(i), metabolite_ls.index(j)] = r
elif len(r) > 0:
corr_mat[gene_ls.index(i), metabolite_ls.index(j)] = r[0]
pbar.update(1)
pbar.close()
palette = [
"#6b9080",
"#a4c3b2",
"#cce3de",
"#eaf4f4",
"#f6fff8",
"#ffe5d9",
"#ffcad4",
"#f4acb7",
"#9d8189",
]
sns.clustermap(
pd.DataFrame(corr_mat, index=gene_ls, columns=metabolite_ls),
cmap=make_colormap(palette),
fmt=".1f",
row_cluster=row_cluster,
col_cluster=col_cluster,
figsize=figsize,
)
[docs]def plot_corrcoef_stsm_ingroup(
adata: AnnDataJointSMST,
row_cluster: bool = True,
cluster: str = "cluster_0",
col_cluster: bool = True,
figsize: tuple = (10, 10),
):
"""
Create a clustermap of the correlation coefficients between genes and metabolites in a specific cluster.
:param adata: AnnDataJointSMST object containing the ST and SM data.
:param row_cluster: whether to cluster the rows, default is True.
:param cluster: cluster to consider, default is 'cluster_0'.
:param col_cluster: whether to cluster the columns, default is True.
:param figsize: tuple specifying the size of the figure, default is (10, 10).
:return: None.
"""
gene_ls = list(np.unique(adata.uns["corrcoef_stsm_ingroup_top"]["gene"]))
tmp_df = adata.uns["corrcoef_stsm_ingroup_top"]
metabolite_ls = list(
np.unique(tmp_df["metabolite"][tmp_df.cluster == cluster].values)
)
corr_mat = np.zeros((len(gene_ls), len(metabolite_ls)))
pbar = tqdm.tqdm(total=len(gene_ls) * len(metabolite_ls))
corrcoef_stsm_inall_mat = adata.uns["corrcoef_stsm_ingroup"][cluster]["data"]
index_ls = list(adata.uns["corrcoef_stsm_ingroup"][cluster]["gene"])
colnames_ls = list(adata.uns["corrcoef_stsm_ingroup"][cluster]["metabolite"])
corrcoef_stsm_inall_df = pd.DataFrame(
corrcoef_stsm_inall_mat, index=index_ls, columns=colnames_ls
)
for i in gene_ls:
for j in metabolite_ls:
r = corrcoef_stsm_inall_df.loc[i, j]
corr_mat[gene_ls.index(i), metabolite_ls.index(j)] = r
pbar.update(1)
pbar.close()
palette = [
"#6b9080",
"#a4c3b2",
"#cce3de",
"#eaf4f4",
"#f6fff8",
"#ffe5d9",
"#ffcad4",
"#f4acb7",
"#9d8189",
]
sns.clustermap(
pd.DataFrame(corr_mat, index=gene_ls, columns=metabolite_ls),
cmap=make_colormap(palette),
fmt=".1f",
row_cluster=row_cluster,
col_cluster=col_cluster,
figsize=figsize,
)
def _hex_to_rgb(value):
value = value.lstrip("#")
lv = len(value)
return tuple(int(value[i : i + lv // 3], 16) / 255 for i in range(0, lv, lv // 3))
def _get_valid_cell_type(adata, min_spots: int):
all_prediction_types = np.unique(adata.obs["celltype_predict"])
valid_spots = []
for i in all_prediction_types:
num_spots = np.sum(adata.obs["celltype_predict"] == i)
if num_spots >= min_spots:
valid_spots.append(i)
return valid_spots
[docs]def plot_spatial_deconvolution(
adata: sc.AnnData,
palettes: dict = None,
min_spots: int = 10,
img_alpha: float = 0.5,
show_color_bar: bool = True,
s: float = 10,
vmin: float = 0.5,
show: bool = False,
valid_cell_types: list = None,
ax=None,
marker="H",
use_pie_chart: bool = False,
key_celltype_predict_proportions: str = "celltype_predict_proportions",
key_celltype_predict: str = "celltype_predict",
):
"""
Create a spatial plot of the cell type predictions.
:param adata: AnnData object containing the data.
:param palettes: dictionary containing the colors to use for each group, default is None.
:param min_spots: minimum number of spots to consider, default is 10.
:param img_alpha: alpha value for the image, default is 0.5.
:param show_color_bar: whether to show the color bar, default is True.
:param s: size of the spots, default is 10.
:param vmin: minimum value for the spots, default is 0.5.
:param show: whether to show the plot, default is False.
:param valid_cell_types: list of valid cell types, default is None.
:param ax: axis to use for the plot, default is None.
:param marker: marker to use for the spots, default is 'H'.
:param use_pie_chart: whether to use a pie chart, default is False.
:param key_celltype_predict_proportions: key in adata.obs containing the cell type proportions, default is 'celltype_predict_proportions'.
:param key_celltype_predict: key in adata.obs containing the cell type predictions, default is 'celltype_predict'.
:return: None.
"""
spatial_coord = adata.obsm["spatial"].copy()
# fig, (ax1,ax2) = plt.subplots(1,2,figsize=(5,5))
# fig.set_size_inches(10, 5)
if palettes is None:
palettes = sc.pl.palettes.godsnot_102
if ax is None:
fig, ax1 = plt.subplots(figsize=(5, 5))
else:
ax1 = ax
img, scale = get_spatial_image(adata)
ax1.imshow(img, alpha=img_alpha)
if valid_cell_types is None:
valid_cell_types = _get_valid_cell_type(adata, min_spots)
if isinstance(list(palettes.values())[0], str):
palettes = {k: _hex_to_rgb(v) for k, v in palettes.items()}
if use_pie_chart:
c = list(adata.obs.columns).index(key_celltype_predict_proportions)
for e, (x, y) in enumerate(
zip(spatial_coord[:, 0] * scale, spatial_coord[:, 1] * scale)
):
ct = dict(
zip(
list(palettes.keys()),
adata.obs.iloc[e].loc[list(palettes.keys())].to_numpy().flatten(),
)
)
ax1.pie(
ct.values(),
radius=s * scale,
colors=list(palettes.values()),
center=(x, y),
)
else:
if len(valid_cell_types) > 1:
for i in valid_cell_types:
alphas = adata.obs.loc[
adata.obs[key_celltype_predict] == i,
key_celltype_predict_proportions,
]
alphas = (alphas - alphas.min()) / (
alphas.max() - alphas.min()
) * 0.5 + vmin
ax1.scatter(
spatial_coord[adata.obs[key_celltype_predict] == i, 0] * scale,
spatial_coord[adata.obs[key_celltype_predict] == i, 1] * scale,
c=list(map(lambda x: tuple(list(palettes[i]) + [x]), alphas)),
s=s * scale,
lw=0,
marker=marker,
)
else:
alphas = adata.obs.loc[
adata.obs[key_celltype_predict] == i,
key_celltype_predict_proportions,
]
plt.grid(False)
ax1.axis("off")
if show_color_bar:
sc.pl._tools.scatterplots._add_categorical_legend(
ax1,
pd.Categorical(valid_cell_types),
palette=palettes,
legend_loc="right margin",
legend_fontweight="bold",
legend_fontsize=None,
legend_fontoutline=None,
multi_panel=False,
na_color="F7F7F7",
na_in_legend=False,
scatter_array=False,
)
if show:
plt.show()
return fig, ax1
[docs]@ignore_warning(level="ignore")
def plot_gene_corrcoef_sm_ingroup(
adata: AnnDataJointSMST,
genename: str,
groupby: str = "leiden",
use_raw: bool = True,
ntop: int = 5,
):
"""
Create a dotplot of the correlation coefficients between genes and metabolites in a specific cluster.
:param adata: AnnDataJointSMST object containing the ST and SM data.
:param genename: name of the gene.
:param groupby: key in adata.obs to use for grouping the data, default is 'leiden'.
:param use_raw: whether to use the raw data, default is True.
:param ntop: number of top genes to consider, default is 5.
:return: None.
"""
if use_raw == True:
_adata = adata.raw.to_adata()
elif use_raw == False:
_adata = adata
results_df = pd.DataFrame()
l1 = adata.obs[groupby].unique()
l2 = list(_adata.var.index[_adata.var["type"] == "SM"])
corr_mat = np.zeros((len(l1), len(l2)))
corr_mat = pd.DataFrame(corr_mat, index=l1, columns=l2)
for cluster in _adata.obs[groupby].unique():
ST_tmp = _adata.X.toarray()[
_adata.obs[groupby] == cluster, _adata.var.index == genename
]
row_mask = _adata.obs[groupby] == cluster
SM_X = _adata.X.toarray()[row_mask][:, _adata.var["type"] == "SM"]
correlation_scores = []
for i in range(SM_X.shape[1]):
correlation_score = np.corrcoef(ST_tmp.squeeze(), SM_X[:, i])[0, 1]
corr_mat.loc[cluster, l2[i]] = correlation_score
correlation_scores.append(correlation_score)
metabolites = _adata.var.index[_adata.var["type"] == "SM"]
df = pd.DataFrame(
{"metabolite": metabolites, "correlation_score": correlation_scores}
)
df = df.sort_values(by="correlation_score", ascending=False)
results_tmp_df = df.head(ntop)
results_tmp_df["gene"] = genename
results_tmp_df["cluster"] = cluster
results_df = results_df.append(results_tmp_df)
metabolite_top_ls = list(results_df["metabolite"].unique())
results_out_df = corr_mat.loc[:, metabolite_top_ls]
results_out_df["cluster"] = results_out_df.index
results_out_long_df = pd.melt(
results_out_df,
id_vars=["cluster"],
var_name=["metabolite"],
value_name="correlation_score",
)
new_keys = {
"item_key": "cluster",
"group_key": "metabolite",
"sizes_key": "correlation_score",
}
dp = dotplot.DotPlot.parse_from_tidy_data(results_out_long_df, **new_keys)
sct = dp.plot(size_factor=100)
[docs]@ignore_warning(level="ignore")
def plot_volcano_corrcoef_gene(
adata: AnnDataJointSMST,
metabolite: str,
use_raw: bool = True,
threshold: float = 0.25,
nonmarker_size: int = 8,
marker_size: int = 20,
marker_alpha: float = 1,
color_threshold: float = 0.25,
color_above: str = "#D2649A",
color_below: str = "#40A578",
color_above_font: str = "#D2649A",
color_below_font: str = "#40A578",
fontsize: int = 8,
color_neutral: str = "grey",
figsize: tuple = (6, 5),
title: str = "Volcano Plot for Correlation Coefficients",
show: bool = True,
):
"""
Create a volcano plot of the correlation coefficients between genes and metabolites.
:param adata: AnnDataJointSMST object containing the ST and SM data.
:param metabolite: name of the metabolite.
:param use_raw: whether to use the raw data, default is True.
:param threshold: threshold for the correlation coefficient, default is 0.25.
:param nonmarker_size: size of the non-marker, default is 8.
:param marker_size: size of the marker, default is 20.
:param marker_alpha: alpha value for the marker, default is 1.
:param color_threshold: threshold for the color, default is 0.25.
:param color_above: color for values above the threshold, default is '#D2649A'.
:param color_below: color for values below the threshold, default is '#40A578'.
:param color_above_font: color for values above the threshold, default is '#D2649A'.
:param color_below_font: color for values below the threshold, default is '#40A578'.
:param fontsize: font size for the labels, default is 8.
:param color_neutral: color for neutral values, default is 'grey'.
:param figsize: tuple specifying the size of the figure, default is (6, 5).
:param title: title of the plot, default is 'Volcano Plot for Correlation Coefficients'.
:param show: whether to show the plot, default is True.
:return: None.
"""
if use_raw:
_adata = adata.raw.to_adata()
else:
_adata = adata
# Calculate correlation score and p-values
SM_X = _adata.X.toarray()[:, _adata.var.index == metabolite]
ST_X = _adata.X.toarray()[:, _adata.var["type"] == "ST"]
correlation_scores = []
p_values = []
for i in range(ST_X.shape[1]):
correlation_score, p_value = pearsonr(SM_X.squeeze(), ST_X[:, i])
correlation_scores.append(correlation_score)
p_values.append(p_value)
genes = _adata.var.index[_adata.var["type"] == "ST"]
df = pd.DataFrame(
{"gene": genes, "correlation_score": correlation_scores, "p_value": p_values}
)
df = df.sort_values(by="correlation_score", ascending=False)
results_df = df
# Plot volcano plot
plt.figure(figsize=figsize)
colors = np.where(
results_df["correlation_score"] > color_threshold,
color_above,
np.where(
results_df["correlation_score"] < -color_threshold,
color_below,
color_neutral,
),
)
sizes = np.where(
np.abs(results_df["correlation_score"]) > threshold, marker_size, nonmarker_size
) # Adjust marker sizes
plt.scatter(
results_df["correlation_score"],
-np.log10(results_df["p_value"]),
c=colors,
s=sizes,
alpha=marker_alpha,
)
plt.xlabel("Correlation coefficient (r)")
plt.ylabel("-log10(P-value)")
plt.title(title)
plt.axvline(x=threshold, color=color_above, linestyle="--")
plt.axvline(x=-threshold, color=color_below, linestyle="--")
# Add labels to significant points
for i, row in results_df.iterrows():
if np.abs(row["correlation_score"]) > threshold:
if row["correlation_score"] > threshold:
plt.text(
row["correlation_score"] + 0.02,
-np.log10(row["p_value"]),
row["gene"],
fontsize=fontsize,
color=color_above_font,
ha="left",
va="center",
)
else:
plt.text(
row["correlation_score"] - 0.02,
-np.log10(row["p_value"]),
row["gene"],
fontsize=fontsize,
color=color_below_font,
ha="right",
va="center",
)
if show == True:
plt.show()
[docs]def plot_features_trajectory(
adata: AnnData,
features: list,
bins: int=100,
palette: Union[dict, list]=None,
figsize: tuple=(10, 5),
scale: bool=False,
save_path: str=None,
):
"""
Create a line plot of the mean values of features along a trajectory.
:param adata: AnnData object containing the data.
:param features: list of features to plot.
:param bins: number of bins to use for grouping the data, default is 100.
:param palette: dictionary or list of colors to use for the features, default is None.
:param figsize: tuple specifying the size of the figure, default is (10, 5).
:param scale: whether to scale the values, default is False.
:param save_path: path to save the plot, default is None.
:return: None.
"""
# Define default palette as a list if none is provided
default_palette = [
"#1f77b4",
"#ff7f0e",
"#279e68",
"#d62728",
"#aa40fc",
"#8c564b",
"#e377c2",
"#b5bd61",
"#17becf",
"#aec7e8",
"#ffbb78",
"#98df8a",
"#ff9896",
"#c5b0d5",
"#c49c94",
"#f7b6d2",
"#dbdb8d",
"#9edae5",
"#ad494a",
"#8c6d31",
]
# Use the default palette if none is provided
if palette is None:
palette = default_palette
# Verify each feature exists in the adata object
for feature in features:
if feature not in adata.var.index:
raise ValueError(f"Feature '{feature}' not found in adata.")
plt.figure(figsize=figsize)
# Iterate over each feature and plot
for idx, feature in enumerate(features):
# Get feature data and locations
feature_index = list(adata.var.index).index(feature)
arr = (
adata.layers["normalized"][
adata.uns["trajectory"]["trajectory_1"]["indices"], feature_index
]
.toarray()
.flatten()
)
locations = np.array(adata.uns["trajectory"]["trajectory_1"]["locations"])
# Filter out zero values
locations = locations[arr != 0]
arr = arr[arr != 0]
# Apply 0-1 scaling if scale parameter is True
if scale:
arr = (arr - np.min(arr)) / (np.max(arr) - np.min(arr))
# Create a DataFrame with locations and values
df = pd.DataFrame({"locations": locations, "value": arr})
# Group data into bins and calculate mean and std
df["bin"] = pd.cut(df["locations"], bins)
grouped_data = df.groupby("bin")["value"].agg(["mean", "std"])
# Calculate midpoints of bins
x = [(interval.left + interval.right) / 2 for interval in grouped_data.index]
mean_values = grouped_data["mean"]
std_values = grouped_data["std"]
# Determine color from palette
if isinstance(palette, dict):
color = palette.get(feature, default_palette[idx % len(default_palette)])
else:
color = palette[idx % len(palette)]
# Plot line and shaded region
sns.lineplot(x=x, y=mean_values, label=feature, color=color)
plt.fill_between(
x,
mean_values - std_values,
mean_values + std_values,
color=color,
alpha=0.2,
)
# Set axis labels and show legend
plt.xlabel("Location")
plt.ylabel("Mean Value")
plt.legend()
if save_path:
plt.savefig(save_path)
# Show the plot
plt.show()
def _unique_preserve_order(arr):
unique_values, indices = np.unique(arr, return_index=True)
sorted_indices = np.sort(indices)
unique_elements_preserved_order = arr[sorted_indices]
return unique_elements_preserved_order
def _smooth_rows(X, window_size):
# Define the kernel for smoothing
kernel = np.ones((1, window_size)) / window_size
smoothed_X = convolve2d(X, kernel, mode="same", boundary="wrap")
return smoothed_X
[docs]def plot_clustermap_with_smoothing(
adata: AnnData,
window_size: int=200,
cmap: str ="vlag",
feature_top: int=10,
key: str="rank_genes_groups",
save_path: str=None,
figsize: tuple=(10, 10),
return_data: bool=False,
**kwargs,
):
"""
Create a clustermap of the top features with smoothing applied.
:param adata: AnnData object containing the data.
:param window_size: size of the window for smoothing, default is 200.
:param cmap: colormap to use for the clustermap, default is 'vlag'.
:param feature_top: number of top features to consider, default is 10.
:param key: key in adata.uns to use for the features, default is 'rank_genes_groups'.
:param save_path: path to save the plot, default is None.
:param figsize: tuple specifying the size of the figure, default is (10, 10).
:param return_data: whether to return the data, default is False.
:return: None.
"""
varnames = _unique_preserve_order(
pd.DataFrame(adata.uns[key]["names"]).head(feature_top).to_numpy().T.flatten()
)
X = adata[:, varnames].layers["normalized"].toarray()
X = _smooth_rows(X.T, window_size)
# Generate a clustermap with specified settings
g = sns.clustermap(
pd.DataFrame(X, index=varnames),
row_cluster=False,
col_cluster=False,
standard_scale="var", # Standardize by columns (1)
cmap=cmap,
**kwargs,
)
g.ax_heatmap.set_xticklabels([])
plt.gcf().set_size_inches(figsize)
if save_path:
plt.savefig(save_path)
plt.show()
# return dict format
if return_data:
data_dict = {"X": X, "varnames": varnames}
return data_dict
[docs]def plot_trajectory_with_arrows(
adata: AnnData,
path_key: str = "trajectory_1",
img_key: str = "scaledres",
color: str = "trajectory_1",
fig=None,
ax=None,
arrow_head_width: int = 15,
arrow_width: float = 0.05,
show: bool = False,
**kwargs,
):
"""
Create a spatial plot of the trajectory with arrows.
:param adata: AnnData object containing the data.
:param path_key: key in adata.uns containing the path, default is 'trajectory_1'.
:param img_key: key in adata.uns containing the image, default is 'scaledres'.
:param color: key in adata.obs containing the color, default is 'trajectory_1'.
:param fig: figure to use for the plot, default is None.
:param ax: axis to use for the plot, default is None.
:param arrow_head_width: width of the arrow head, default is 15.
:param arrow_width: width of the arrow, default is 0.05.
:param show: whether to show the plot, default is False.
:return: None.
"""
if fig is None or ax is None:
fig, ax = plt.subplots()
sc.pl.spatial(adata, color=color, ax=ax, show=show, img_key=img_key, **kwargs)
path = parse_path(adata.uns["trajectory"][path_key]["path"])
for i in range(len(path)):
start = path[i].start
end = path[i].end
ax.arrow(
start.real,
start.imag,
end.real - start.real,
end.imag - start.imag,
head_width=arrow_head_width,
width=arrow_width,
edgecolor="white",
)
[docs]def plot_network(
adata: AnnData,
groupby: str = "leiden",
spatial_key: str = "spatial",
use_raw: bool = False,
palette: dict = None,
edge_use_scale: bool = True,
node_use_scale: bool = True,
node_scale_factor: float = 10.0,
edge_weight_threshold: float = 0.1,
edge_scale_factor: float = 1.0,
seed: int = None,
top_n_neighbors: int = 5,
show_weight: bool = False,
show_labels: bool = True,
node_min_size: float = 1.0,
split_by: str = None,
return_data: bool = False,
show: bool = True,
iterations: int = 50,
):
"""
Create a network plot of the data. And the network plot is generated based on the spatial distance and spot size.
:param adata: AnnData object containing the data.
:param groupby: key in adata.obs to use for grouping the data, default is 'leiden'.
:param spatial_key: key in adata.obsm containing the spatial data, default is 'spatial'.
:param use_raw: whether to use the raw data, default is False.
:param palette: dictionary containing the colors to use for each group, default is None.
:param edge_use_scale: whether to use scale for the edge, default is True.
:param node_use_scale: whether to use scale for the node, default is True.
:param node_scale_factor: factor to scale the node, default is 10.0.
:param edge_weight_threshold: threshold for the edge weight, default is 0.1.
:param edge_scale_factor: factor to scale the edge, default is 1.0.
:param seed: seed to use for random number generation, default is None.
:param top_n_neighbors: number of top neighbors to consider, default is 5.
:param show_weight: whether to show the weight, default is False.
:param show_labels: whether to show the labels, default is True.
:param node_min_size: minimum size of the node, default is 1.0.
:param split_by: key in adata.obs to use for splitting the data, default is None.
:param return_data: whether to return the data, default is False.
:param show: whether to show the plot, default is True.
:param iterations: number of iterations to use for the layout, default is 50.
:return: None.
"""
np.random.seed(seed)
if split_by is None:
distance_df = spatial_distance_cluster(
adata, groupby=groupby, spatial_key=spatial_key, useraw=use_raw
)
dot_df = calculate_dot_df(
adata, groupby=groupby, spatial_key=spatial_key, use_raw=use_raw
)
else:
distance_total_df = pd.DataFrame()
dot_total_df = pd.DataFrame()
for i in adata.obs[split_by].unique():
adata_tmp = adata[adata.obs[split_by] == i]
distance_df_tmp = spatial_distance_cluster(
adata_tmp, groupby=groupby, spatial_key=spatial_key, use_raw=use_raw
)
dot_df_tmp = calculate_dot_df(
adata_tmp, groupby=groupby, spatial_key=spatial_key, use_raw=use_raw
)
distance_total_df = pd.concat([distance_total_df, distance_df_tmp])
dot_total_df = pd.concat([dot_total_df, dot_df_tmp])
distance_total_df["split_by"] = i
dot_total_df["split_by"] = i
# group by split_by and calculate mean
distance_total_df = distance_total_df[["from", "to", "distance"]]
dot_total_df = dot_total_df[["dot_name", "cell_number"]]
distance_df = distance_total_df.groupby(["from", "to"]).mean().reset_index()
dot_df = dot_total_df.groupby(["dot_name"]).mean().reset_index()
distance_df["weight"] = 1 / distance_df["distance"]
distance_df["scale_distance"] = (
distance_df["distance"] - distance_df["distance"].min()
) / (distance_df["distance"].max() - distance_df["distance"].min())
distance_df["scale_weight"] = (
distance_df["weight"] - distance_df["weight"].min()
) / (distance_df["weight"].max() - distance_df["weight"].min())
dot_df["dot_size"] = dot_df["cell_number"] / dot_df["cell_number"].sum()
dot_df["scale_dot_size"] = (dot_df["dot_size"] - dot_df["dot_size"].min()) / (
dot_df["dot_size"].max() - dot_df["dot_size"].min()
)
distance_df_unique = distance_df.copy()
distance_df_unique[["from", "to"]] = pd.DataFrame(
np.sort(distance_df_unique[["from", "to"]], axis=1),
index=distance_df_unique.index,
)
distance_df_unique.drop_duplicates(subset=["from", "to"], inplace=True)
G = nx.Graph()
for idx, row in distance_df_unique.iterrows():
if edge_use_scale:
weight = row["scale_weight"]
else:
weight = row["weight"]
if weight >= edge_weight_threshold:
G.add_edge(row["from"], row["to"], weight=weight)
node_df = dot_df.copy()
node_colors = {}
node_sizes = {}
if palette is None:
palette = sc.pl.palettes.default_102
if isinstance(palette, dict):
node_colors = {node: palette[node] for node in G.nodes}
else:
node_colors = dict(zip(G.nodes, palette))
if node_use_scale:
node_sizes = {
node: max(
node_df[node_df["dot_name"] == node]["scale_dot_size"].iloc[0]
* node_scale_factor,
node_min_size,
)
for node in G.nodes
}
else:
node_sizes = {
node: max(
node_df[node_df["dot_name"] == node]["dot_size"].iloc[0]
* node_scale_factor,
node_min_size,
)
for node in G.nodes
}
pos = nx.spring_layout(G, iterations=iterations)
# from fa2 import ForceAtlas2
# nx.to_scipy_sparse_matrix = nx.to_scipy_sparse_array
# forceatlas2 = ForceAtlas2()
# pos = forceatlas2.forceatlas2_networkx_layout(G, pos=None, iterations=50)
edge_widths = [G[u][v]["weight"] * edge_scale_factor for u, v in G.edges]
for node in G.nodes:
neighbors = list(G.neighbors(node))
if len(neighbors) > top_n_neighbors:
sorted_neighbors = sorted(neighbors, key=lambda x: G[node][x]["weight"])
for neighbor in sorted_neighbors[: len(sorted_neighbors) - top_n_neighbors]:
G.remove_edge(node, neighbor)
nx.draw(
G,
pos,
with_labels=show_labels,
node_color=list(node_colors.values()),
node_size=[node_sizes[n] for n in G.nodes],
width=edge_widths,
)
if show_weight:
labels = nx.get_edge_attributes(G, "weight")
nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)
if return_data:
return G, node_df, distance_df_unique
if show:
plt.show()