from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from anndata import AnnData
import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy
import seaborn as sns
import sklearn.metrics
from .clustering import knn_annotation
def class_balance(
    data: AnnData, label_key: str, min_class_size: int = 10, seed: int = 42
) -> AnnData:
    """Return a balanced subset with equally sized clusters.
    Samples the same number of observations from each category in
    ``data.obs[label_key]`` (the size is determined by the smallest class).
    Parameters
    ----------
    data
        Input :class:`~anndata.AnnData`.
    label_key
        Observation column with cluster or class labels.
    min_class_size
        Raise an error if the smallest class contains fewer than this
        number of observations (default ``10``).
    seed
        Random seed for reproducible sampling.
    Returns
    -------
    An AnnData *view* containing the balanced subset.
    """
    # Check if label_key is in adata.obs
    if label_key not in data.obs.columns:
        raise ValueError(f"Label key {label_key} not found in adata.obs")
    # Remove all samples with missing labels
    data_sub = data[data.obs[label_key].notna()]
    # Check if smallest class has at least min_class_size samples
    min_class_s = data_sub.obs[label_key].value_counts().min()
    min_class = data_sub.obs[label_key].value_counts().idxmin()
    if min_class_s < min_class_size:
        raise ValueError(
            f"Smallest class ({min_class}) has less than {min_class_size} samples."
        )
    if min_class_s < 10:
        warnings.warn(
            f"Smallest class ({min_class}) has less than 10 samples, this might not yield a stable score."
        )
    obs_names = []
    for label in data_sub.obs[label_key].unique():
        obs_names.extend(
            data_sub.obs[data_sub.obs[label_key] == label]
            .sample(min_class_s, replace=False, random_state=seed)
            .index.values
        )
    data_sub = data_sub[obs_names, :]
    return data_sub
[docs]
def silhouette_score(
    data, gt_col, use_rep="X_umap", key_added="silhouette", inplace=True
) -> None | np.ndarray:
    """Per-group silhouette scores.
    Computes the silhouette score for each group in ``data.obs[gt_col]``.
    Parameters
    ----------
    data
        AnnData object containing an embedding in ``.obsm``.
    gt_col
        Column in ``data.obs`` with cluster labels.
    use_rep
        Key of the embedding to evaluate (default ``"X_umap"``).
    key_added
        Base key under which results are stored (default ``"silhouette"``).
    inplace
        If ``True`` (default) store return ``None``, if ``False`` return the silhouette scores.
    Returns
    -------
    If ``inplace`` is ``True``:
        ``data.obs[key_added]``
            Vector of silhouette scores.
        ``data.uns[key_added]['mean_silhouette_score']``
            Global mean.
        ``data.uns[key_added]['cluster_mean_silhouette']``
            Mapping of cluster → mean score.
    If ``inplace`` is ``False``:
        Vector of silhouette scores.
    """
    mask = data.obs[gt_col].notna()
    data_sub = data[mask]
    sub_obs = data_sub.obs.copy()
    ss = sklearn.metrics.silhouette_samples(data_sub.obsm[use_rep], sub_obs[gt_col])
    if inplace:
        sub_obs[key_added] = ss
        cluster_mean_ss = sub_obs.groupby(gt_col)[key_added].mean()
        data.uns[key_added] = {
            "mean_silhouette_score": ss.mean(),
            "cluster_mean_silhouette": cluster_mean_ss.to_dict(),
            "cluster_balanced_silhouette_score": cluster_mean_ss.mean(),
        }
        data.obs.loc[mask, key_added] = ss
    else:
        return ss 
[docs]
def calinski_habarasz_score(
    data,
    gt_col,
    use_rep="X_umap",
    key_added="ch_score",
    class_balance=False,
    inplace=True,
    seed=42,
) -> None | float:
    """Calinski–Harabasz score of cluster compactness vs separation.
    Parameters
    ----------
    data
        AnnData with an embedding under ``.obsm[use_rep]``.
    gt_col
        Observation column containing cluster assignments.
    use_rep
        Name of embedding to use (default ``"X_umap"``).
    key_added
        Key under which to store the score when ``inplace`` is ``True``.
    class_balance
        If ``True`` subsample each cluster to equal size before computing the
        score (calls ``class_balance`` internally).
    inplace, seed
        Standard behaviour flags.
    Returns
    -------
    If ``inplace`` is ``True``:
        ``data.uns[key_added]``
            Score.
    If ``inplace`` is ``False``:
        Score.
    """
    mask = data.obs[gt_col].notna()
    data_sub = data[mask]
    if class_balance:
        min_class_size = data_sub.obs[gt_col].value_counts().min()
        if min_class_size < 10:
            warnings.warn(
                "Smallest class has less than 10 samples, this might not yield a stable score."
            )
        obs_names = []
        for label in data_sub.obs[gt_col].unique():
            obs_names.extend(
                data_sub.obs[data_sub.obs[gt_col] == label]
                .sample(min_class_size, replace=False, random_state=seed)
                .index.values
            )
        data_sub = data_sub[obs_names, :]
    ch = sklearn.metrics.calinski_harabasz_score(data_sub.obsm[use_rep], data_sub.obs[gt_col])
    if inplace:
        data.uns[key_added] = ch
    else:
        return ch 
def qsep_score(
    data: AnnData,
    gt_col: str,
    use_rep: str = "X",
    distance_key: str = "full_distances",
    inplace: bool = True,
) -> None | np.ndarray:
    """QSep cluster-separation metric for spatial proteomics.
    Implements the *QSep* statistic from Gatto *et al.* (2014) which
    measures within- vs between-cluster distances.
    Parameters
    ----------
    data
        AnnData object.
    gt_col
        Observation column with ground-truth cluster labels.
    use_rep
        Representation used for distance computation – ``"X"`` or a key in
        ``data.obsm``.
    distance_key
        Column name to store per-protein mean distances (only when
        ``inplace`` is ``True``).
    inplace
        Control write-back vs return behaviour.
    Returns
    -------
    If ``inplace`` is ``True``:
        ``None``
    If ``inplace`` is ``False``:
        ``cluster_distances``
    """
    # Get data matrix
    if use_rep == "X":
        X = data.X
    else:
        X = data.obsm[use_rep]
    # Calculate pairwise distances between all points
    full_distances = sklearn.metrics.pairwise_distances(X)
    # Get valid clusters (non-NA)
    mask = data.obs[gt_col].notna()
    valid_clusters = data.obs[gt_col][mask].unique()
    # Calculate cluster distances
    cluster_distances = np.zeros((len(valid_clusters), len(valid_clusters)))
    cluster_indices = {
        cluster: np.where(data.obs[gt_col] == cluster)[0] for cluster in valid_clusters
    }
    for i, cluster1 in enumerate(valid_clusters):
        for j in range(i, len(valid_clusters)):
            # for j, cluster2 in enumerate(valid_clusters[i + 1 :]):
            cluster2 = valid_clusters[j]
            idx1 = cluster_indices[cluster1]
            idx2 = cluster_indices[cluster2]
            # Get submatrix of distances between clusters
            submatrix = full_distances[np.ix_(idx1, idx2)]
            cluster_distances[i, j] = np.mean(submatrix)
            cluster_distances[j, i] = np.mean(submatrix)
    if inplace:
        # Store full distances
        data.obs[distance_key] = pd.Series(
            np.mean(full_distances, axis=1), index=data.obs.index
        )
        # Store cluster distances and metadata
        data.uns["cluster_distances"] = {
            "distances": cluster_distances,
            "clusters": valid_clusters.tolist(),
        }
    else:
        return cluster_distances
def knn_f1_score(data, gt_col, pred_col=None, weights=None, average="macro"):
    """F1 score.
    Parameters
    ----------
    data
        AnnData object.
    gt_col
        Observation column with ground-truth labels.
    pred_col
        Observation column with predicted labels.
    weights
        Weights for the F1 score.
    average
        Average method for the F1 score. If ``None`` the score for each label is returned.
    Returns
    -------
    F1 score.
    """
    if pred_col is None:
        knnres = knn_annotation(data, gt_col, inplace=False, min_probability=0)
        pred = knnres["labels"][knnres["probabilities"].argmax(axis=1)]
    else:
        pred = data.obs[pred_col]
    gt = data.obs[gt_col]
    mask = gt.notna() & pred.notna()
    y_true_raw = gt[mask]
    y_pred_raw = pred[mask]
    if weights is not None:
        labels = list(weights.keys())
        cats = (
            pd.Index(labels)
            .union(pd.Index(y_true_raw.unique()))
            .union(pd.Index(y_pred_raw.unique()))
        )
        y_true = pd.Categorical(y_true_raw, categories=cats).codes
        y_pred = pd.Categorical(y_pred_raw, categories=cats).codes
        label_idx = [cats.get_loc(label) for label in labels]
        f1_arr = sklearn.metrics.f1_score(y_true, y_pred, average=None, labels=label_idx)
        w = np.asarray([weights[label] for label in labels], float)
        w /= w.sum() if w.sum() > 0 else 1.0
        if average is None:
            return f1_arr * w
        return float((f1_arr * w).sum())
    else:
        cats = pd.Index(y_true_raw.unique()).union(pd.Index(y_pred_raw.unique()))
        y_true = pd.Categorical(y_true_raw, categories=cats).codes
        y_pred = pd.Categorical(y_pred_raw, categories=cats).codes
        return sklearn.metrics.f1_score(y_true, y_pred, average=average)
def knn_confusion_matrix(data, gt_col, pred_col=None, soft=False, cluster=False, plot=True):
    """Plot the confusion matrix of KNN-predicted versus ground-truth labels.
    Parameters
    ----------
    data
        AnnData object.
    gt_col
        Observation column with ground-truth labels.
    pred_col
        Observation column with predicted labels. If None, KNN annotation is computed.
    soft
        If True, use probabilistic (soft) confusion matrix instead of hard assignments.
    plot=True
        If True, plot the heatmap of the confusion matrix, otherwise return the confusion matrix.
    cluster
        If True, reorder matrix rows and columns via hierarchical clustering for visualization.
    Returns
    -------
    None. Displays a heatmap of the confusion matrix if plot is True, otherwise returns the confusion matrix.
    """
    # Notation: n observations, g ground truth label classes
    if pred_col is None:
        knnres = knn_annotation(data, gt_col, inplace=False, min_probability=0)
    else:
        knnres = {
            "probabilities": data.obsm[f"{pred_col}_probabilities"],
            "labels": data.obs[pred_col],
            "one_hot_labels": data.obsm[f"{pred_col}_one_hot_labels"],
        }
    if soft:
        M = knnres["one_hot_labels"].T @ knnres["probabilities"]  # shape (g, g)
        # Row-normalize to get fractions (each row sums to 1)
        cm = M / M.sum(axis=1, keepdims=True)
        cm = np.nan_to_num(cm, nan=0.0)
        labels = list(knnres["labels"])  # consistent label order with matrix axes
    else:
        gt = data.obs[gt_col]
        pred = knnres["labels"][knnres["probabilities"].argmax(axis=1)]
        mask = gt.notna() & pred.notna()
        y_true_raw = gt[mask]
        y_pred_raw = pred[mask]
        # Ensure we know the label order used in the confusion matrix
        labels = list(pd.Index(y_true_raw.unique()).union(pd.Index(y_pred_raw.unique())))
        cm = sklearn.metrics.confusion_matrix(y_true_raw, y_pred_raw, labels=labels)
        cm = cm / cm.sum(axis=1, keepdims=True)
    if not plot:
        return cm
    plt.figure(figsize=(10, 10))
    # Derive a single ordering for both axes so labels align and high values
    # lie near the diagonal. Use hierarchical clustering on a symmetrized matrix.
    if cluster:
        sym = (cm + cm.T) / 2.0
        order = scipy.cluster.hierarchy.leaves_list(
            scipy.cluster.hierarchy.linkage(1 - sym, method="ward")
        )
        cm_ordered = cm[np.ix_(order, order)]
        ordered_labels = [labels[i] for i in order]
    else:
        ordered_labels = labels
        cm_ordered = cm
    sns.heatmap(
        cm_ordered,
        annot=True,
        fmt=".2f",
        cmap="rocket_r",
        cbar=True,
        vmax=1,
        vmin=0,
        linewidths=0.5,
        square=True,
        xticklabels=ordered_labels,
        yticklabels=ordered_labels,
    )
    plt.xlabel("Predicted label")
    plt.ylabel("True label")
    plt.title(f"{'Soft' if soft else 'Hard'} confusion matrix from knn annotation")
    plt.tight_layout()
    plt.show()