from __future__ import annotations
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from anndata import AnnData
from typing import List, Literal
from matplotlib.axes import Axes
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import pysankey
import scanpy
import seaborn as sns
from matplotlib import gridspec
from matplotlib.markers import MarkerStyle
[docs]
def remodeling_score(
    remodeling_score: np.ndarray,
    show: bool | None = None,
    save: bool | str | None = None,
) -> List[plt.Axes] | None:
    """Plot a histogram of *remodeling scores* above a slim box plot.

    Both panels are placed on a 2x1 grid (height ratio 5:1) in the current
    matplotlib figure. The layout mirrors the visualisation published in
    Hein *et al.* 2024 and gives a quick view of the score distribution and
    of potential outliers.

    Parameters
    ----------
    remodeling_score
        Scores to plot, typically generated by
        :func:`~grassp.tl.remodeling_score`. The values are handed unchanged
        to :func:`seaborn.histplot` and, as ``x``, to :func:`seaborn.boxplot`.
    show
        Whether to show the figure. ``None`` (default) falls back to
        ``scanpy.settings.autoshow``.
    save
        Forwarded to :func:`scanpy.pl._utils.savefig_or_show` together with
        the base name ``"remodeling_score"``.

    Returns
    -------
    ``None`` when the resolved ``show`` is truthy, otherwise the list
    ``[ax_hist, ax_box]`` for further customisation.
    """
    layout = gridspec.GridSpec(2, 1, height_ratios=[5, 1])

    # Upper panel: plain histogram (no KDE) with x label and tick labels
    # blanked, since the box plot below carries the axis annotation.
    hist_ax = plt.subplot(layout[0])
    sns.histplot(remodeling_score, ax=hist_ax, kde=False)
    hist_ax.set_xlabel("")
    hist_ax.set_xticklabels([])

    # Lower panel: horizontal box plot; outliers are hollow orange circles.
    outlier_style = {
        "marker": "o",
        "markeredgecolor": "orange",
        "markerfacecolor": "none",
        "markersize": 6,
    }
    box_ax = plt.subplot(layout[1])
    sns.boxplot(x=remodeling_score, ax=box_ax, flierprops=outlier_style)
    box_ax.set_xlabel("Remodeling score")

    if show is None:
        show = scanpy.settings.autoshow
    scanpy.pl._utils.savefig_or_show("remodeling_score", show=show, save=save)

    return None if show else [hist_ax, box_ax]
# Proxy legend handles (no data attached) that ``aligned_umap`` appends to its
# legend: a black star for highlighted proteins and a grey line for the
# trajectory connecting a protein's two positions.
remodeling_legend = [
    mlines.Line2D(
        xdata=[],
        ydata=[],
        color="black",
        marker="*",
        linestyle="None",
        label="remodeled_proteins",
    ),
    mlines.Line2D(
        xdata=[],
        ydata=[],
        color="grey",
        linestyle="-",
        linewidth=2,
        label="remodeling trajectory",
    ),
]
def _get_cluster_colors(data: AnnData, color_key: str = "leiden") -> np.ndarray[str, Any]:
"""Get cluster colors for a given color key.
Parameters
----------
data
Annotated data matrix
color_key
Key in data.obs to use for coloring
Returns
-------
np.ndarray
Array of colors for the given color key
"""
if f"{color_key}_colors" not in data.uns.keys():
scanpy.pl._utils._set_default_colors_for_categorical_obs(data, color_key)
return np.array(
[data.uns[f"{color_key}_colors"][x] for x in data.obs[color_key].cat.codes]
)
[docs]
def aligned_umap(
    data: AnnData,
    data2: AnnData,
    highlight_hits: List[str] | np.ndarray[bool, Any] | None = None,
    highlight_annotation_col: str | None = None,
    aligned_umap_key: str = "X_aligned_umap",
    data1_label: str = "data1",
    data2_label: str = "data2",
    color_by: Literal["perturbation", "cluster"] = "perturbation",
    data1_color: str = "#C7E8F9",
    data2_color: str = "#FFCCC2",
    figsize: tuple[float, float] = (8.25, 6),
    size: int = 80,
    alpha: float = 0.4,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
) -> plt.Axes | None:
    """Overlay two aligned UMAP embeddings in a single scatter plot.

    Given two datasets embedded into a common aligned UMAP space (see
    :func:`~grassp.tl.aligned_umap`), both embeddings are drawn on the same
    axes (``data`` with ``.`` markers, ``data2`` with ``+`` markers) and,
    optionally, **remodeling trajectories** are drawn for highlighted
    proteins.

    Parameters
    ----------
    data, data2
        Two :class:`~anndata.AnnData` objects with the same set/order of
        observations and a shared key in ``.obsm`` (*aligned_umap_key*).
    highlight_hits
        Proteins to emphasise with a star marker (at their position in
        ``data``) and a line connecting their positions in the two datasets.
        Either a list/array of observation names (looked up in
        ``data.obs_names``) or a boolean mask / integer positions over the
        observations.
    highlight_annotation_col
        Optional column in ``data.obs`` used for text labels next to the
        highlighted points.
    aligned_umap_key
        Key in ``.obsm`` containing the aligned coordinates (default:
        ``"X_aligned_umap"``).
    data1_label, data2_label
        Legend labels for the two datasets.
    color_by
        If ``"cluster"``, points are coloured per cluster. **In this mode
        ``data1_color`` and ``data2_color`` are interpreted as names of
        categorical ``.obs`` columns** of ``data`` and ``data2``
        respectively (not as colours), and a second legend listing the
        clusters of ``data`` is added. Any other value (default
        ``"perturbation"``) colours each dataset uniformly.
    data1_color, data2_color
        Uniform colours for the two datasets, or ``.obs`` column names when
        ``color_by="cluster"``.
    size, alpha, figsize, ax
        Standard matplotlib styling parameters. ``figsize`` is only used when
        ``ax`` is ``None`` and a new figure is created.
    show, save
        Forwarded to :func:`scanpy.pl._utils.savefig_or_show`; ``show=None``
        falls back to ``scanpy.settings.autoshow``.

    Returns
    -------
    Returns the Axes object if the resolved ``show`` is falsy, else ``None``.

    Raises
    ------
    KeyError
        If ``color_by="cluster"`` and ``data1_color``/``data2_color`` is not
        a column of the respective ``.obs``, or if a name in
        ``highlight_hits`` is not present in ``data.obs_names``.
    """
    by_cluster = color_by == "cluster"
    if by_cluster:
        # Fail early with an explanatory message: the default hex colours are
        # not valid column names, which is an easy mistake to make.
        for adata, obs_key in ((data, data1_color), (data2, data2_color)):
            if obs_key not in adata.obs:
                raise KeyError(
                    "color_by='cluster' interprets data1_color/data2_color as "
                    f".obs column names, but {obs_key!r} is not a column of .obs"
                )

    coords1 = np.asarray(data.obsm[aligned_umap_key])
    coords2 = np.asarray(data2.obsm[aligned_umap_key])

    # Normalise ``highlight_hits`` to something NumPy can index with.
    selector = None
    if highlight_hits is not None:
        selector = np.asarray(highlight_hits)
        if selector.size == 0:
            # np.asarray([]) is a float array, which is not a valid index.
            selector = np.empty(0, dtype=int)
        elif selector.dtype.kind in "USO":
            # Observation names -> positional indices.
            missing = [name for name in selector.tolist() if name not in data.obs_names]
            if missing:
                raise KeyError(f"highlight_hits not found in data.obs_names: {missing}")
            selector = data.obs_names.get_indexer_for(selector)

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    cluster_handles = []
    if by_cluster:
        data1_colors = _get_cluster_colors(data, data1_color)
        data2_colors = _get_cluster_colors(data2, data2_color)
        # One proxy handle per cluster of ``data`` for the cluster legend.
        cluster_handles = [
            mlines.Line2D(
                [],
                [],
                color=color,
                marker="o",
                linestyle="None",
                markersize=8,
                label=cluster,
            )
            for cluster, color in zip(
                data.obs[data1_color].cat.categories,
                data.uns[f"{data1_color}_colors"],
            )
        ]
    else:
        data1_colors = data1_color
        data2_colors = data2_color

    # Plot the two embeddings as scatter plots
    ax.scatter(
        coords1[:, 0],
        coords1[:, 1],
        c=data1_colors,
        s=size,
        alpha=alpha,
        label=data1_label,
        marker=MarkerStyle("."),
        linewidths=0,
        edgecolor=None,
    )
    ax.scatter(
        coords2[:, 0],
        coords2[:, 1],
        c=data2_colors,
        s=size,
        alpha=alpha,
        label=data2_label,
        marker=MarkerStyle("+"),
        linewidths=1,
        edgecolor=None,
    )

    if selector is not None:
        starts = coords1[selector]
        ends = coords2[selector]

        # Trajectory line from the position in ``data`` to the one in ``data2``
        for start, end in zip(starts, ends):
            ax.plot(
                [start[0], end[0]],
                [start[1], end[1]],
                color="grey",
                linewidth=0.7,
                alpha=0.5,
            )

        # Star marker at the start point (position in ``data``)
        if len(starts):
            ax.scatter(
                starts[:, 0],
                starts[:, 1],
                color="black",
                s=30,
                marker=MarkerStyle("*"),
                edgecolor=None,
            )

        if highlight_annotation_col is not None:
            # Look the labels up once instead of re-slicing .obs per point.
            labels = data.obs[highlight_annotation_col].to_numpy()[selector]
            for start, label in zip(starts, labels):
                ax.annotate(
                    str(label),
                    (start[0], start[1]),
                    color="black",
                    fontsize=5,
                    bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, pad=0.5),
                    ha="right",
                    va="bottom",
                    xytext=(5, 5),
                    textcoords="offset points",
                )

    # Combine scatter plot legend with remodeling legend, placed to the right
    handles, _ = ax.get_legend_handles_labels()
    dataset_legend = ax.legend(
        handles=handles + remodeling_legend,
        bbox_to_anchor=(1.15, 1),
        loc="upper left",
    )
    if by_cluster:
        # A second ax.legend() call replaces the Axes' current legend, so the
        # first one must be re-attached as a plain artist to stay visible.
        ax.add_artist(dataset_legend)
        ax.legend(
            handles=cluster_handles,
            bbox_to_anchor=(1.15, 0.5),
            loc="center left",
            title="Clusters",
        )

    # Adjust layout to prevent legend cutoff
    plt.tight_layout()

    axis_name = aligned_umap_key.replace("X_", "")
    ax.set_xlabel(f"{axis_name}1")
    ax.set_ylabel(f"{axis_name}2")
    ax.set_xticks([])
    ax.set_yticks([])

    show = scanpy.settings.autoshow if show is None else show
    scanpy.pl._utils.savefig_or_show("aligned_umap", show=show, save=save)
    if show:
        return None
    return ax
[docs]
def remodeling_sankey(
    data: AnnData,
    data2: AnnData,
    cluster_key: str = "leiden",
    ax: plt.Axes | None = None,
    aspect: int = 20,
    fontsize: int = 12,
    figsize: tuple[float, float] = (10, 11),
    show: bool | None = None,
    save: bool | str | None = None,
) -> plt.Axes | None:
    """Sankey diagram of cluster re-assignment between two conditions.

    Parameters
    ----------
    data, data2
        Two :class:`~anndata.AnnData` objects with identical observation
        names (same order) and a *cluster_key* column in ``.obs``.
        ``data`` provides the left side of the diagram, ``data2`` the right.
    cluster_key
        Column name that encodes cluster memberships.
    aspect, fontsize, figsize, ax
        Styling options passed to :func:`pysankey.sankey` or matplotlib.
        ``figsize`` is only used when ``ax`` is ``None``.
    show, save
        Forwarded to :func:`scanpy.pl._utils.savefig_or_show`; ``show=None``
        falls back to ``scanpy.settings.autoshow``.

    Returns
    -------
    Axes containing the Sankey plot if the resolved ``show`` is falsy,
    otherwise ``None``.

    Raises
    ------
    ValueError
        If ``data.obs_names`` and ``data2.obs_names`` differ.
    """
    # Validate with an explicit exception (``assert`` is stripped under -O)
    # and before creating a figure, so a failed call leaves nothing behind.
    # ``Index.equals`` also covers differing lengths.
    if not data.obs_names.equals(data2.obs_names):
        raise ValueError(
            "data and data2 must have identical obs_names in the same order"
        )

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # No explicit colorDict / leftLabels / rightLabels are passed.
    pysankey.sankey(
        left=data.obs[cluster_key],
        right=data2.obs[cluster_key],
        aspect=aspect,
        fontsize=fontsize,
        color_gradient=False,
        ax=ax,
    )

    show = scanpy.settings.autoshow if show is None else show
    scanpy.pl._utils.savefig_or_show("remodeling_sankey", show=show, save=save)
    if show:
        return None
    return ax
[docs]
def mr_plot(
    data: AnnData,
    mr_key: str = "mr_scores",
    ax: plt.Axes | None = None,
    m_cutoffs: list[float] | tuple[float, ...] = (2, 3, 4),
    r_cutoffs: list[float] | tuple[float, ...] = (0.68, 0.81, 0.93),
    highlight_hits: bool = True,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwargs,
) -> Axes | None:
    """MR-plot for simultaneous visualisation of *M* and *R* scores.

    *M* corresponds to ``-log10(q-value)`` and reflects statistical
    significance; *R* is a measure of consistency of the re-localisation
    between replicates. *M* is drawn on the x-axis, *R* on the y-axis.

    Parameters
    ----------
    data
        AnnData object after :func:`grassp.tl.mr_score`.
    mr_key
        Prefix of the keys written by ``mr_score`` (defaults to
        ``"mr_scores"`` resulting in ``mr_scores_M`` and ``mr_scores_R`` in
        ``.obs``).
    m_cutoffs, r_cutoffs
        Guideline positions (lenient → stringent): ``m_cutoffs`` are drawn as
        vertical lines, ``r_cutoffs`` as horizontal lines. The first three
        use progressively lighter greys; further cutoffs reuse the lightest.
    highlight_hits
        If ``True`` mark proteins passing the *lenient* thresholds
        (``m_cutoffs[0]`` and ``r_cutoffs[0]``) in red. Skipped when either
        cutoff sequence is empty.
    ax
        Axes to draw into; a new 10x6 figure is created when ``None``.
    show, save
        Forwarded to :func:`scanpy.pl._utils.savefig_or_show` with the base
        name ``"mr_plot"``; ``show=None`` falls back to
        ``scanpy.settings.autoshow``.
    **kwargs
        Extra keyword arguments for the main :meth:`~matplotlib.axes.Axes.scatter`
        call. They override the defaults ``alpha=0.5, s=10, color="black",
        marker="."``.

    Returns
    -------
    Returns the Axes object if the resolved ``show`` is falsy, else ``None``.

    Raises
    ------
    ValueError
        If ``{mr_key}_M`` or ``{mr_key}_R`` is missing from ``data.obs``.
    """
    # Look the scores up first so a missing key does not leave an empty figure.
    try:
        m_scores = data.obs[f"{mr_key}_M"]
        r_scores = data.obs[f"{mr_key}_R"]
    except KeyError as err:
        raise ValueError(
            f"MR scores not found in data.obs['{mr_key}_M/R'], run gr.tl.mr_score first"
        ) from err

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    # Defaults first, caller overrides second; passing both as keywords would
    # raise "got multiple values for keyword argument".
    scatter_kwargs = {"alpha": 0.5, "s": 10, "color": "black", "marker": "."}
    if "c" in kwargs:
        # matplotlib rejects ``c`` combined with ``color``.
        del scatter_kwargs["color"]
    scatter_kwargs.update(kwargs)
    ax.scatter(m_scores, r_scores, **scatter_kwargs)

    if highlight_hits and len(m_cutoffs) > 0 and len(r_cutoffs) > 0:
        hits = (m_scores >= m_cutoffs[0]) & (r_scores >= r_cutoffs[0])
        ax.scatter(m_scores[hits], r_scores[hits], color="red", s=20, marker=".")

    # Add cutoff lines; cutoffs beyond the palette reuse its last colour
    # instead of being silently dropped.
    guide_colors = ["gray", "darkgray", "lightgray"]
    last = len(guide_colors) - 1
    for i, m_cut in enumerate(m_cutoffs):
        ax.axvline(m_cut, color=guide_colors[min(i, last)], linestyle="--", alpha=0.5)
    for i, r_cut in enumerate(r_cutoffs):
        ax.axhline(r_cut, color=guide_colors[min(i, last)], linestyle="--", alpha=0.5)

    # Set x-axis limits from the largest finite M score: set_xlim rejects
    # NaN/inf, and M is infinite when a q-value is exactly 0.
    m_values = np.asarray(m_scores, dtype=float)
    finite_m = m_values[np.isfinite(m_values)]
    if finite_m.size:
        ax.set_xlim(0, finite_m.max() + 5)

    # Set labels
    ax.set_xlabel("M score (-log10 Q-value)")
    ax.set_ylabel("R score (minimum correlation)")
    ax.set_title("MR Plot")

    show = scanpy.settings.autoshow if show is None else show
    # Use this function's own base name; "remodeling_score" would overwrite
    # the file written by the remodeling_score plot.
    scanpy.pl._utils.savefig_or_show("mr_plot", show=show, save=save)
    if show:
        return None
    return ax