grassp.tl.svm_train

svm_train(data, gt_col, C_range=array([0.0625, 0.125, 0.25, 0.5, 1., 2., 4., 8., 16.]), gamma_range=array([1.e-03, 1.e-02, 1.e-01, 1.e+00, 1.e+01, 1.e+02]), class_weight='balanced', cv_splits=5, cv_repeats=20, n_jobs=-1, random_state=None, key_added='svm', inplace=True)

Train an SVM classifier with hyperparameter tuning using marker proteins.

Performs a grid search over the C and gamma parameters using repeated stratified cross-validation. The best hyperparameters are stored in .uns for later use with svm_annotation().
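The procedure described above can be sketched with scikit-learn. This is a minimal illustration under stated assumptions, not grassp's implementation: it assumes an RBF-kernel SVC (gamma is described as a kernel coefficient) tuned with GridSearchCV over RepeatedStratifiedKFold, and it uses synthetic data from make_classification in place of an AnnData object. The compartment labels ER, Mito and Nucleus are hypothetical.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, RepeatedStratifiedKFold
from sklearn.svm import SVC

# Synthetic stand-in for a proteins x features matrix (assumption, not grassp data).
X, y_int = make_classification(
    n_samples=120, n_features=10, n_informative=5, n_classes=3, random_state=0
)
# Hypothetical marker labels; every fourth protein is left unannotated (NaN).
labels = pd.Series(np.array(["ER", "Mito", "Nucleus"])[y_int], dtype=object)
labels.iloc[::4] = np.nan

# Train only on annotated proteins, mirroring how NaN entries in gt_col are excluded.
known = labels.notna().to_numpy()
X_train = X[known]
y_train = labels[known].to_numpy()

# Same grids as the documented defaults.
param_grid = {
    "C": 2.0 ** np.arange(-4, 5),       # 2^-4 ... 2^4
    "gamma": 10.0 ** np.arange(-3, 3),  # 10^-3 ... 10^2
}
# n_repeats is kept small here for speed; svm_train defaults to cv_repeats=20.
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=2, random_state=42)
search = GridSearchCV(
    SVC(kernel="rbf", class_weight="balanced"),  # assumed RBF kernel
    param_grid=param_grid,
    cv=cv,
    n_jobs=1,  # set to -1 to use all cores
)
search.fit(X_train, y_train)
best_params = search.best_params_
print(best_params)
```

The selected C and gamma depend on the data; in grassp the equivalent values end up under .uns rather than in a local variable.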

Parameters:
data AnnData

anndata.AnnData object with proteins as observations.

gt_col str

Observation column containing marker annotations. Proteins with NaN values are considered unknown and excluded from training.

C_range ndarray (default: array([0.0625, 0.125, 0.25, 0.5, 1., 2., 4., 8., 16.]))

Array of C (regularization) values to search. Default: 2^-4 to 2^4.

gamma_range ndarray (default: array([1.e-03, 1.e-02, 1.e-01, 1.e+00, 1.e+01, 1.e+02]))

Array of gamma (kernel coefficient) values to search. Default: 10^-3 to 10^2.

cv_splits int (default: 5)

Number of cross-validation folds (default 5).

cv_repeats int (default: 20)

Number of CV repetitions (default 20). Total fits per parameter combination: cv_splits × cv_repeats.

n_jobs int (default: -1)

Number of parallel jobs. -1 uses all available cores.

random_state int | None (default: None)

Random seed for reproducibility.

key_added str (default: 'svm')

Key prefix for storing results in .uns (default "svm").

inplace bool (default: True)

If True, store results in .uns; if False, return the grid search object and a dictionary with the best parameters and CV results. This is useful for inspecting the grid search object or reusing the best parameters for other tasks.

class_weight None | dict | Literal['balanced'] (default: 'balanced')

Class weights used when fitting the classifier (default "balanced").
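The defaults in the signature can be checked numerically. This sketch rebuilds the C and gamma grids, counts the fits a default run performs (every C and gamma pair, each fit cv_splits × cv_repeats times), and shows what "balanced" means under scikit-learn's definition, n_samples / (n_classes × class count). That svm_train forwards class_weight to scikit-learn unchanged is an assumption; the labels ER and Mito are hypothetical.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Rebuild the default search grids from the signature.
C_range = 2.0 ** np.arange(-4, 5)       # 0.0625 ... 16 (9 values)
gamma_range = 10.0 ** np.arange(-3, 3)  # 0.001 ... 100 (6 values)

# One candidate per (C, gamma) pair; each is fit cv_splits * cv_repeats times.
cv_splits, cv_repeats = 5, 20
n_candidates = C_range.size * gamma_range.size
total_fits = n_candidates * cv_splits * cv_repeats

# scikit-learn's "balanced" weights: n_samples / (n_classes * count per class).
y = np.array(["ER", "ER", "ER", "Mito"])  # hypothetical, imbalanced markers
weights = compute_class_weight(class_weight="balanced", classes=np.unique(y), y=y)
print(n_candidates, total_fits, weights)
```

With the defaults this amounts to 54 candidates and 5400 fits, which is why n_jobs=-1 is the default. In the toy labels the rare class receives three times the weight of the common one.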

Return type:

dict | None

Returns:

None or dict: If inplace=False, returns a dictionary with the best parameters and CV results. Otherwise, stores the results in data.uns[f"{key_added}.params"] and returns None.

Examples

>>> import grassp as gr
>>> adata = gr.ds.hein_2024(enrichment="enriched")
>>> # When actually training, increase cv_repeats and cv_splits.
>>> # We recommend >20 repeats with 5 splits.
>>> gr.tl.svm_train(adata, gt_col="hein2024_gt_component", cv_repeats=2, cv_splits=2, random_state=42)
Fitting 4 folds for each of 54 candidates, totalling 216 fits
>>> adata.uns["svm.params"]["best_params"]
{'C': 2.0, 'gamma': 0.01}
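The "Fitting 4 folds for each of 54 candidates, totalling 216 fits" line in the example follows directly from the settings used. The sketch below reproduces the count, assuming (as the log format suggests) that the folds come from scikit-learn's RepeatedStratifiedKFold, whose get_n_splits() returns n_splits × n_repeats.

```python
from sklearn.model_selection import RepeatedStratifiedKFold

# cv_splits=2 and cv_repeats=2, as in the example call.
cv = RepeatedStratifiedKFold(n_splits=2, n_repeats=2, random_state=42)
n_folds = cv.get_n_splits()  # 2 splits x 2 repeats = 4 "folds"
n_candidates = 9 * 6         # len(C_range) x len(gamma_range) with the defaults
total_fits = n_candidates * n_folds
print(f"Fitting {n_folds} folds for each of {n_candidates} candidates, totalling {total_fits} fits")
```

Raising cv_repeats to the default of 20 with 5 splits multiplies the work by 25, which is the trade-off behind the recommendation in the example comment.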