Source code for grassp.tools.scoring

from __future__ import annotations
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from anndata import AnnData

import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy
import seaborn as sns
import sklearn.metrics

from .clustering import knn_annotation


def class_balance(
    data: AnnData, label_key: str, min_class_size: int = 10, seed: int = 42
) -> AnnData:
    """Return a balanced subset with equally sized clusters.

    Samples the same number of observations from each category in
    ``data.obs[label_key]`` (the size is determined by the smallest class).

    Parameters
    ----------
    data
        Input :class:`~anndata.AnnData`.
    label_key
        Observation column with cluster or class labels.
    min_class_size
        Raise an error if the smallest class contains fewer than this
        number of observations (default ``10``).
    seed
        Random seed for reproducible sampling.

    Returns
    -------
    An AnnData *view* containing the balanced subset.
    """
    # Check if label_key is in adata.obs
    if label_key not in data.obs.columns:
        raise ValueError(f"Label key {label_key} not found in adata.obs")
    # Remove all samples with missing labels
    data_sub = data[data.obs[label_key].notna()]
    # Check if smallest class has at least min_class_size samples
    min_class_s = data_sub.obs[label_key].value_counts().min()
    min_class = data_sub.obs[label_key].value_counts().idxmin()
    if min_class_s < min_class_size:
        raise ValueError(
            f"Smallest class ({min_class}) has less than {min_class_size} samples."
        )
    if min_class_s < 10:
        warnings.warn(
            f"Smallest class ({min_class}) has less than 10 samples, this might not yield a stable score."
        )

    obs_names = []
    for label in data_sub.obs[label_key].unique():
        obs_names.extend(
            data_sub.obs[data_sub.obs[label_key] == label]
            .sample(min_class_s, replace=False, random_state=seed)
            .index.values
        )
    data_sub = data_sub[obs_names, :]
    return data_sub


[docs] def silhouette_score( data, gt_col, use_rep="X_umap", key_added="silhouette", inplace=True ) -> None | np.ndarray: """Per-group silhouette scores. Computes the silhouette score for each group in ``data.obs[gt_col]``. Parameters ---------- data AnnData object containing an embedding in ``.obsm``. gt_col Column in ``data.obs`` with cluster labels. use_rep Key of the embedding to evaluate (default ``"X_umap"``). key_added Base key under which results are stored (default ``"silhouette"``). inplace If ``True`` (default) store return ``None``, if ``False`` return the silhouette scores. Returns ------- If ``inplace`` is ``True``: ``data.obs[key_added]`` Vector of silhouette scores. ``data.uns[key_added]['mean_silhouette_score']`` Global mean. ``data.uns[key_added]['cluster_mean_silhouette']`` Mapping of cluster → mean score. If ``inplace`` is ``False``: Vector of silhouette scores. """ mask = data.obs[gt_col].notna() data_sub = data[mask] sub_obs = data_sub.obs.copy() ss = sklearn.metrics.silhouette_samples(data_sub.obsm[use_rep], sub_obs[gt_col]) if inplace: sub_obs[key_added] = ss cluster_mean_ss = sub_obs.groupby(gt_col)[key_added].mean() data.uns[key_added] = { "mean_silhouette_score": ss.mean(), "cluster_mean_silhouette": cluster_mean_ss.to_dict(), "cluster_balanced_silhouette_score": cluster_mean_ss.mean(), } data.obs.loc[mask, key_added] = ss else: return ss
[docs] def calinski_habarasz_score( data, gt_col, use_rep="X_umap", key_added="ch_score", class_balance=False, inplace=True, seed=42, ) -> None | float: """Calinski–Harabasz score of cluster compactness vs separation. Parameters ---------- data AnnData with an embedding under ``.obsm[use_rep]``. gt_col Observation column containing cluster assignments. use_rep Name of embedding to use (default ``"X_umap"``). key_added Key under which to store the score when ``inplace`` is ``True``. class_balance If ``True`` subsample each cluster to equal size before computing the score (calls ``class_balance`` internally). inplace, seed Standard behaviour flags. Returns ------- If ``inplace`` is ``True``: ``data.uns[key_added]`` Score. If ``inplace`` is ``False``: Score. """ mask = data.obs[gt_col].notna() data_sub = data[mask] if class_balance: min_class_size = data_sub.obs[gt_col].value_counts().min() if min_class_size < 10: warnings.warn( "Smallest class has less than 10 samples, this might not yield a stable score." ) obs_names = [] for label in data_sub.obs[gt_col].unique(): obs_names.extend( data_sub.obs[data_sub.obs[gt_col] == label] .sample(min_class_size, replace=False, random_state=seed) .index.values ) data_sub = data_sub[obs_names, :] ch = sklearn.metrics.calinski_harabasz_score(data_sub.obsm[use_rep], data_sub.obs[gt_col]) if inplace: data.uns[key_added] = ch else: return ch
def qsep_score( data: AnnData, gt_col: str, use_rep: str = "X", distance_key: str = "full_distances", inplace: bool = True, ) -> None | np.ndarray: """QSep cluster-separation metric for spatial proteomics. Implements the *QSep* statistic from Gatto *et al.* (2014) which measures within- vs between-cluster distances. Parameters ---------- data AnnData object. gt_col Observation column with ground-truth cluster labels. use_rep Representation used for distance computation – ``"X"`` or a key in ``data.obsm``. distance_key Column name to store per-protein mean distances (only when ``inplace`` is ``True``). inplace Control write-back vs return behaviour. Returns ------- If ``inplace`` is ``True``: ``None`` If ``inplace`` is ``False``: ``cluster_distances`` """ # Get data matrix if use_rep == "X": X = data.X else: X = data.obsm[use_rep] # Calculate pairwise distances between all points full_distances = sklearn.metrics.pairwise_distances(X) # Get valid clusters (non-NA) mask = data.obs[gt_col].notna() valid_clusters = data.obs[gt_col][mask].unique() # Calculate cluster distances cluster_distances = np.zeros((len(valid_clusters), len(valid_clusters))) cluster_indices = { cluster: np.where(data.obs[gt_col] == cluster)[0] for cluster in valid_clusters } for i, cluster1 in enumerate(valid_clusters): for j in range(i, len(valid_clusters)): # for j, cluster2 in enumerate(valid_clusters[i + 1 :]): cluster2 = valid_clusters[j] idx1 = cluster_indices[cluster1] idx2 = cluster_indices[cluster2] # Get submatrix of distances between clusters submatrix = full_distances[np.ix_(idx1, idx2)] cluster_distances[i, j] = np.mean(submatrix) cluster_distances[j, i] = np.mean(submatrix) if inplace: # Store full distances data.obs[distance_key] = pd.Series( np.mean(full_distances, axis=1), index=data.obs.index ) # Store cluster distances and metadata data.uns["cluster_distances"] = { "distances": cluster_distances, "clusters": valid_clusters.tolist(), } else: return cluster_distances def knn_f1_score(data, gt_col, pred_col=None, weights=None, average="macro"): """F1 score. Parameters ---------- data AnnData object. gt_col Observation column with ground-truth labels. pred_col Observation column with predicted labels. weights Weights for the F1 score. average Average method for the F1 score. If ``None`` the score for each label is returned. Returns ------- F1 score. """ if pred_col is None: knnres = knn_annotation(data, gt_col, inplace=False, min_probability=0) pred = knnres["labels"][knnres["probabilities"].argmax(axis=1)] else: pred = data.obs[pred_col] gt = data.obs[gt_col] mask = gt.notna() & pred.notna() y_true_raw = gt[mask] y_pred_raw = pred[mask] if weights is not None: labels = list(weights.keys()) cats = ( pd.Index(labels) .union(pd.Index(y_true_raw.unique())) .union(pd.Index(y_pred_raw.unique())) ) y_true = pd.Categorical(y_true_raw, categories=cats).codes y_pred = pd.Categorical(y_pred_raw, categories=cats).codes label_idx = [cats.get_loc(label) for label in labels] f1_arr = sklearn.metrics.f1_score(y_true, y_pred, average=None, labels=label_idx) w = np.asarray([weights[label] for label in labels], float) w /= w.sum() if w.sum() > 0 else 1.0 if average is None: return f1_arr * w return float((f1_arr * w).sum()) else: cats = pd.Index(y_true_raw.unique()).union(pd.Index(y_pred_raw.unique())) y_true = pd.Categorical(y_true_raw, categories=cats).codes y_pred = pd.Categorical(y_pred_raw, categories=cats).codes return sklearn.metrics.f1_score(y_true, y_pred, average=average) def knn_confusion_matrix(data, gt_col, pred_col=None, soft=False, cluster=False, plot=True): """Plot the confusion matrix of KNN-predicted versus ground-truth labels. Parameters ---------- data AnnData object. gt_col Observation column with ground-truth labels. pred_col Observation column with predicted labels. If None, KNN annotation is computed. soft If True, use probabilistic (soft) confusion matrix instead of hard assignments. plot=True If True, plot the heatmap of the confusion matrix, otherwise return the confusion matrix. cluster If True, reorder matrix rows and columns via hierarchical clustering for visualization. Returns ------- None. Displays a heatmap of the confusion matrix if plot is True, otherwise returns the confusion matrix. """ # Notation: n observations, g ground truth label classes if pred_col is None: knnres = knn_annotation(data, gt_col, inplace=False, min_probability=0) else: knnres = { "probabilities": data.obsm[f"{pred_col}_probabilities"], "labels": data.obs[pred_col], "one_hot_labels": data.obsm[f"{pred_col}_one_hot_labels"], } if soft: M = knnres["one_hot_labels"].T @ knnres["probabilities"] # shape (g, g) # Row-normalize to get fractions (each row sums to 1) cm = M / M.sum(axis=1, keepdims=True) cm = np.nan_to_num(cm, nan=0.0) labels = list(knnres["labels"]) # consistent label order with matrix axes else: gt = data.obs[gt_col] pred = knnres["labels"][knnres["probabilities"].argmax(axis=1)] mask = gt.notna() & pred.notna() y_true_raw = gt[mask] y_pred_raw = pred[mask] # Ensure we know the label order used in the confusion matrix labels = list(pd.Index(y_true_raw.unique()).union(pd.Index(y_pred_raw.unique()))) cm = sklearn.metrics.confusion_matrix(y_true_raw, y_pred_raw, labels=labels) cm = cm / cm.sum(axis=1, keepdims=True) if not plot: return cm plt.figure(figsize=(10, 10)) # Derive a single ordering for both axes so labels align and high values # lie near the diagonal. Use hierarchical clustering on a symmetrized matrix. if cluster: sym = (cm + cm.T) / 2.0 order = scipy.cluster.hierarchy.leaves_list( scipy.cluster.hierarchy.linkage(1 - sym, method="ward") ) cm_ordered = cm[np.ix_(order, order)] ordered_labels = [labels[i] for i in order] else: ordered_labels = labels cm_ordered = cm sns.heatmap( cm_ordered, annot=True, fmt=".2f", cmap="rocket_r", cbar=True, vmax=1, vmin=0, linewidths=0.5, square=True, xticklabels=ordered_labels, yticklabels=ordered_labels, ) plt.xlabel("Predicted label") plt.ylabel("True label") plt.title(f"{'Soft' if soft else 'Hard'} confusion matrix from knn annotation") plt.tight_layout() plt.show()