Source code for grassp.plotting.integration

from __future__ import annotations
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from anndata import AnnData
    from typing import List, Literal
    from matplotlib.axes import Axes


import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import pysankey
import scanpy
import seaborn as sns

from matplotlib import gridspec
from matplotlib.markers import MarkerStyle


def remodeling_score(
    remodeling_score: np.ndarray,
    show: bool | None = None,
    save: bool | str | None = None,
) -> List[plt.Axes] | None:
    """Plot the distribution of protein *remodeling scores*.

    Two vertically stacked panels are drawn into the current matplotlib
    figure: a histogram on top and a compact box plot underneath. The layout
    mirrors the visualisation published in Hein *et al.* 2024 and gives a
    quick overview of the score distribution and potential outliers.

    Parameters
    ----------
    remodeling_score
        One-dimensional array of remodeling scores, typically generated by
        :func:`~grassp.tl.remodeling_score`.
    show
        Whether to show the plot. ``None`` (default) falls back to
        ``scanpy.settings.autoshow``. When the resolved value is truthy the
        function returns ``None``.
    save
        Forwarded to :func:`scanpy.pl._utils.savefig_or_show` together with
        the write key ``"remodeling_score"``.

    Returns
    -------
    The list ``[ax_hist, ax_box]`` when ``show`` resolves to a falsy value
    (to enable further customisation), otherwise ``None``.
    """
    # The histogram panel is five times as tall as the box plot panel.
    layout = gridspec.GridSpec(2, 1, height_ratios=[5, 1])

    # Top panel: histogram without x label / x tick labels, since the box
    # plot below carries the axis annotation.
    hist_ax = plt.subplot(layout[0])
    sns.histplot(remodeling_score, ax=hist_ax, kde=False)
    hist_ax.set_xlabel("")
    hist_ax.set_xticklabels([])

    # Bottom panel: box plot with hollow orange outlier markers.
    outlier_style = {
        "marker": "o",
        "markeredgecolor": "orange",
        "markerfacecolor": "none",
        "markersize": 6,
    }
    box_ax = plt.subplot(layout[1])
    sns.boxplot(x=remodeling_score, ax=box_ax, flierprops=outlier_style)
    box_ax.set_xlabel("Remodeling score")

    if show is None:
        show = scanpy.settings.autoshow
    scanpy.pl._utils.savefig_or_show("remodeling_score", show=show, save=save)

    return None if show else [hist_ax, box_ax]
# Static legend handles describing the highlight markers: a black star for
# remodeled proteins and a grey line for their trajectory.
remodeling_legend = [
    mlines.Line2D(
        [],
        [],
        color="black",
        marker="*",
        linestyle="None",
        label="remodeled_proteins",
    ),
    mlines.Line2D(
        [],
        [],
        color="grey",
        linestyle="-",
        linewidth=2,
        label="remodeling trajectory",
    ),
]


def _get_cluster_colors(data: AnnData, color_key: str = "leiden") -> np.ndarray[str, Any]:
    """Return one color per observation for a categorical ``.obs`` column.

    Parameters
    ----------
    data
        Annotated data matrix.
    color_key
        Key in ``data.obs`` to use for coloring. The column is accessed via
        its ``.cat`` accessor.

    Returns
    -------
    np.ndarray
        Array with one entry of ``data.uns[f"{color_key}_colors"]`` per
        observation, selected by the observation's category code.
    """
    palette_key = f"{color_key}_colors"
    if palette_key not in data.uns:
        # No palette stored yet: let scanpy assign its default colors.
        scanpy.pl._utils._set_default_colors_for_categorical_obs(data, color_key)

    palette = data.uns[palette_key]
    category_codes = data.obs[color_key].cat.codes
    return np.array([palette[code] for code in category_codes])
def _resolve_highlight_positions(
    data: AnnData, highlight_hits: List[str] | np.ndarray[bool, Any]
) -> np.ndarray:
    """Translate *highlight_hits* into integer row positions of ``data``.

    Accepts a boolean mask over the observations, integer positions, or
    observation names (looked up in ``data.obs_names``).

    Raises
    ------
    IndexError
        If a boolean mask does not have one entry per observation.
    KeyError
        If an observation name is not present in ``data.obs_names``.
    """
    hits = np.asarray(highlight_hits)
    if hits.size == 0:
        return np.empty(0, dtype=int)

    if hits.dtype == bool:
        n_obs = len(data.obs_names)
        if hits.shape != (n_obs,):
            raise IndexError(
                f"Boolean highlight_hits mask has shape {hits.shape}, "
                f"expected ({n_obs},)"
            )
        return np.flatnonzero(hits)

    if np.issubdtype(hits.dtype, np.integer):
        return hits

    # Anything else is treated as observation names.
    positions = data.obs_names.get_indexer(hits)
    missing = hits[positions < 0]
    if missing.size:
        raise KeyError(f"highlight_hits not found in data.obs_names: {list(missing)}")
    return positions


def aligned_umap(
    data: AnnData,
    data2: AnnData,
    highlight_hits: List[str] | np.ndarray[bool, Any] | None = None,
    highlight_annotation_col: str | None = None,
    aligned_umap_key: str = "X_aligned_umap",
    data1_label: str = "data1",
    data2_label: str = "data2",
    color_by: Literal["perturbation", "cluster"] = "perturbation",
    data1_color: str = "#C7E8F9",
    data2_color: str = "#FFCCC2",
    figsize: tuple[float, float] = (8.25, 6),
    size: int = 80,
    alpha: float = 0.4,
    ax: plt.Axes | None = None,
    show: bool | None = None,
    save: bool | str | None = None,
) -> plt.Axes | None:
    """Side-by-side visualisation of aligned UMAP embeddings.

    Given two datasets embedded into a common aligned UMAP space (see
    :func:`~grassp.tl.aligned_umap`), the function plots both embeddings in a
    single scatter plot and optionally draws **remodeling trajectories** for
    highlighted proteins.

    Parameters
    ----------
    data, data2
        Two :class:`~anndata.AnnData` objects with the same set/order of
        observations and a shared key in ``.obsm`` (*aligned_umap_key*).
    highlight_hits
        Proteins to emphasise with a star marker (drawn at their position in
        ``data``) and a line connecting their positions between the two
        datasets. Can be a list of observation names, a boolean mask over the
        observations, or integer row positions. The same rows are used for
        both datasets.
    highlight_annotation_col
        Optional column in ``data.obs`` used for text labels next to the
        highlighted points.
    aligned_umap_key
        Key in ``.obsm`` containing the aligned coordinates (default:
        ``"X_aligned_umap"``).
    data1_label, data2_label
        Legend labels for the two datasets.
    color_by
        If ``"cluster"``, ``data1_color`` and ``data2_color`` are interpreted
        as names of categorical columns in ``data.obs`` / ``data2.obs`` and
        points are colored per cluster; a separate cluster legend is built
        from the categories of ``data``. Any other value (default
        ``"perturbation"``) colors each dataset uniformly.
    data1_color, data2_color
        Uniform colors for the two datasets, or ``.obs`` column names when
        ``color_by="cluster"``.
    figsize
        Figure size; only used when ``ax`` is ``None`` and a new figure is
        created.
    size, alpha
        Marker size and opacity of the two embedding scatter plots.
    ax
        Existing Axes to draw into. A new figure is created when ``None``.
    show, save
        Forwarded to :func:`scanpy.pl._utils.savefig_or_show` with the write
        key ``"aligned_umap"``. ``show=None`` falls back to
        ``scanpy.settings.autoshow``.

    Returns
    -------
    The Axes object if ``show`` resolves to a falsy value, otherwise ``None``.

    Raises
    ------
    KeyError
        If an observation name in ``highlight_hits`` is not in
        ``data.obs_names``.
    IndexError
        If a boolean ``highlight_hits`` mask does not match the number of
        observations.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    cluster_handles = []
    if color_by == "cluster":
        # In cluster mode the *_color arguments name categorical obs columns.
        data1_colors = _get_cluster_colors(data, data1_color)
        data2_colors = _get_cluster_colors(data2, data2_color)

        # One legend handle per cluster of the first dataset.
        unique_clusters = data.obs[data1_color].cat.categories
        cluster_colors = data.uns[f"{data1_color}_colors"]
        cluster_handles = [
            mlines.Line2D(
                [],
                [],
                color=color,
                marker="o",
                linestyle="None",
                markersize=8,
                label=cluster,
            )
            for cluster, color in zip(unique_clusters, cluster_colors)
        ]
    else:
        data1_colors = data1_color
        data2_colors = data2_color

    embedding1 = np.asarray(data.obsm[aligned_umap_key])
    embedding2 = np.asarray(data2.obsm[aligned_umap_key])

    # Plot the two embeddings as scatter plots
    ax.scatter(
        embedding1[:, 0],
        embedding1[:, 1],
        c=data1_colors,
        s=size,
        alpha=alpha,
        label=data1_label,
        marker=MarkerStyle("."),
        linewidths=0,
        edgecolor=None,
    )
    ax.scatter(
        embedding2[:, 0],
        embedding2[:, 1],
        c=data2_colors,
        s=size,
        alpha=alpha,
        label=data2_label,
        marker=MarkerStyle("+"),
        linewidths=1,
        edgecolor=None,
    )

    if highlight_hits is not None:
        # Resolve names / masks to row positions once, so that both the
        # NumPy embeddings and the obs annotations can be indexed
        # consistently.
        hit_positions = _resolve_highlight_positions(data, highlight_hits)
        starts = embedding1[hit_positions]
        ends = embedding2[hit_positions]

        hit_labels = None
        if highlight_annotation_col is not None:
            hit_labels = data.obs[highlight_annotation_col].iloc[hit_positions]

        # Star markers at the start point (position in the first dataset).
        ax.scatter(
            starts[:, 0],
            starts[:, 1],
            color="black",
            s=30,
            marker=MarkerStyle("*"),
            edgecolor=None,
        )

        for i, (start, end) in enumerate(zip(starts, ends)):
            # Trajectory line from the first to the second embedding.
            ax.plot(
                [start[0], end[0]],
                [start[1], end[1]],
                color="grey",
                linewidth=0.7,
                alpha=0.5,
            )
            if hit_labels is not None:
                ax.annotate(
                    str(hit_labels.iloc[i]),
                    (start[0], start[1]),
                    color="black",
                    fontsize=5,
                    bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, pad=0.5),
                    ha="right",
                    va="bottom",
                    xytext=(5, 5),
                    textcoords="offset points",
                )

    # Combine scatter plot legend with remodeling legend
    handles, _ = ax.get_legend_handles_labels()
    combined_handles = handles + remodeling_legend

    # Legend for dataset and remodeling markers, to the right of the plot.
    dataset_legend = ax.legend(
        handles=combined_handles, bbox_to_anchor=(1.15, 1), loc="upper left"
    )
    if color_by == "cluster":
        # A second ax.legend() call replaces the Axes' current legend, so the
        # first one has to be re-attached as a plain artist to stay visible.
        ax.add_artist(dataset_legend)
        ax.legend(
            handles=cluster_handles,
            bbox_to_anchor=(1.15, 0.5),
            loc="center left",
            title="Clusters",
        )

    ax.set_xlabel(f"{aligned_umap_key.replace('X_', '')}1")
    ax.set_ylabel(f"{aligned_umap_key.replace('X_', '')}2")
    ax.set_xticks([])
    ax.set_yticks([])

    # Adjust layout to prevent legend cutoff; done after the axis labels are
    # set so that they are taken into account.
    plt.tight_layout()

    show = scanpy.settings.autoshow if show is None else show
    scanpy.pl._utils.savefig_or_show("aligned_umap", show=show, save=save)
    if show:
        return None
    return ax
def remodeling_sankey(
    data: AnnData,
    data2: AnnData,
    cluster_key: str = "leiden",
    ax: plt.Axes | None = None,
    aspect: int = 20,
    fontsize: int = 12,
    figsize: tuple[float, float] = (10, 11),
    show: bool | None = None,
    save: bool | str | None = None,
) -> plt.Axes | None:
    """Sankey diagram of cluster re-assignment between two conditions.

    Parameters
    ----------
    data, data2
        Two :class:`~anndata.AnnData` objects with identical observation
        indices and a *cluster_key* column in ``.obs``. ``data`` provides the
        left side of the diagram, ``data2`` the right side.
    cluster_key
        Column name that encodes cluster memberships.
    aspect, fontsize
        Styling options passed to :func:`pysankey.sankey`.
    figsize
        Figure size; only used when ``ax`` is ``None`` and a new figure is
        created.
    ax
        Existing Axes to draw into. A new figure is created when ``None``.
    show, save
        Forwarded to :func:`scanpy.pl._utils.savefig_or_show` with the write
        key ``"remodeling_sankey"``. ``show=None`` falls back to
        ``scanpy.settings.autoshow``.

    Returns
    -------
    Axes containing the Sankey plot if ``show`` resolves to a falsy value,
    otherwise ``None``.

    Raises
    ------
    ValueError
        If ``data.obs_names`` and ``data2.obs_names`` differ in content or
        order.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # The diagram pairs observations row by row, so both objects must be
    # aligned. An explicit check is used instead of `assert`, which would be
    # stripped when Python runs with -O.
    if not data.obs_names.equals(data2.obs_names):
        raise ValueError(
            "`data` and `data2` must have identical obs_names in the same order."
        )

    pysankey.sankey(
        left=data.obs[cluster_key],
        right=data2.obs[cluster_key],
        aspect=aspect,
        fontsize=fontsize,
        color_gradient=False,
        ax=ax,
    )

    show = scanpy.settings.autoshow if show is None else show
    scanpy.pl._utils.savefig_or_show("remodeling_sankey", show=show, save=save)
    if show:
        return None
    return ax
def mr_plot(
    data: AnnData,
    mr_key: str = "mr_scores",
    ax: plt.Axes | None = None,
    m_cutoffs: list[float] = [2, 3, 4],
    r_cutoffs: list[float] = [0.68, 0.81, 0.93],
    highlight_hits: bool = True,
    show: bool | None = None,
    save: bool | str | None = None,
    **kwargs,
) -> Axes | None:
    """MR-plot for simultaneous visualisation of *M* and *R* scores.

    *M* corresponds to ``-log10(q-value)`` and reflects statistical
    significance; *R* is a measure of consistency of the re-localisation
    between replicates.

    Parameters
    ----------
    data
        AnnData object after :func:`grassp.tl.mr_score`.
    mr_key
        Prefix of the keys written by ``mr_score`` (defaults to
        ``"mr_scores"`` resulting in ``mr_scores_M`` and ``mr_scores_R`` in
        ``.obs``).
    m_cutoffs, r_cutoffs
        Guideline positions (lenient → stringent): ``m_cutoffs`` are drawn as
        vertical lines on the M axis, ``r_cutoffs`` as horizontal lines on the
        R axis. The sequences are only read, never modified.
    highlight_hits
        If ``True`` mark proteins with ``M >= m_cutoffs[0]`` and
        ``R >= r_cutoffs[0]`` (the *lenient* thresholds) in red. Skipped when
        either cutoff sequence is empty.
    ax
        Existing Axes to draw into. A new 10x6 figure is created when
        ``None``.
    show, save
        Forwarded to :func:`scanpy.pl._utils.savefig_or_show` with the write
        key ``"mr_plot"``. ``show=None`` falls back to
        ``scanpy.settings.autoshow``.
    **kwargs
        Forwarded to :meth:`matplotlib.axes.Axes.scatter` for the base
        scatter of all proteins; they override the built-in defaults
        (``alpha=0.5``, ``s=10``, ``color="black"``, ``marker="."``).

    Returns
    -------
    Returns the Axes object if ``show`` resolves to a falsy value, otherwise
    ``None``.

    Raises
    ------
    ValueError
        If ``data.obs`` does not contain ``{mr_key}_M`` / ``{mr_key}_R``.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    try:
        m_scores = data.obs[f"{mr_key}_M"]
        r_scores = data.obs[f"{mr_key}_R"]
    except KeyError as err:
        raise ValueError(
            f"MR scores not found in data.obs['{mr_key}_M/R'], run gr.tl.mr_score first"
        ) from err

    # Built-in styling acts as defaults only, so callers can override it via
    # **kwargs without triggering a duplicate-keyword TypeError.
    scatter_kwargs = {"alpha": 0.5, "s": 10, "marker": ".", **kwargs}
    if "c" not in scatter_kwargs:
        # `c` and `color` are mutually exclusive in Axes.scatter.
        scatter_kwargs.setdefault("color", "black")
    ax.scatter(m_scores, r_scores, **scatter_kwargs)

    if highlight_hits and len(m_cutoffs) > 0 and len(r_cutoffs) > 0:
        hits = (m_scores >= m_cutoffs[0]) & (r_scores >= r_cutoffs[0])
        ax.scatter(m_scores[hits], r_scores[hits], color="red", s=20, marker=".")

    # Add cutoff lines: vertical for M, horizontal for R. Cutoffs beyond the
    # number of predefined colors reuse the last (lightest) color instead of
    # being silently dropped.
    guide_colors = ["gray", "darkgray", "lightgray"]
    for draw_line, cutoffs in ((ax.axvline, m_cutoffs), (ax.axhline, r_cutoffs)):
        for i, cutoff in enumerate(cutoffs):
            color = guide_colors[min(i, len(guide_colors) - 1)]
            draw_line(cutoff, color=color, linestyle="--", alpha=0.5)

    # Set x-axis limits
    ax.set_xlim(0, np.max(m_scores) + 5)

    # Set labels
    ax.set_xlabel("M score (-log10 Q-value)")
    ax.set_ylabel("R score (minimum correlation)")
    ax.set_title("MR Plot")

    show = scanpy.settings.autoshow if show is None else show
    # Use this function's own write key so saved figures are not confused with
    # (or overwritten by) the output of `remodeling_score`.
    scanpy.pl._utils.savefig_or_show("mr_plot", show=show, save=save)
    if show:
        return None
    return ax