grassp.tl.svm_train
- svm_train(data, gt_col, C_range=array([0.0625, 0.125, 0.25, 0.5, 1., 2., 4., 8., 16.]), gamma_range=array([1.e-03, 1.e-02, 1.e-01, 1.e+00, 1.e+01, 1.e+02]), class_weight='balanced', cv_splits=5, cv_repeats=20, n_jobs=-1, random_state=None, key_added='svm', inplace=True)
Train SVM classifier with hyperparameter tuning using marker proteins.
Performs a grid search over the C and gamma parameters using repeated stratified cross-validation. The best hyperparameters are stored in .uns for later use with svm_annotation().
- Parameters:
- data
AnnData anndata.AnnData object with proteins as observations.
- gt_col
str Observation column containing marker annotations. Proteins with NaN values are considered unknown and excluded from training.
- C_range
ndarray (default: array([0.0625, 0.125, 0.25, 0.5, 1., 2., 4., 8., 16.])) Array of C (regularization) values to search. Default: powers of two from 2^-4 to 2^4.
- gamma_range
ndarray (default: array([1.e-03, 1.e-02, 1.e-01, 1.e+00, 1.e+01, 1.e+02])) Array of gamma (kernel coefficient) values to search. Default: powers of ten from 10^-3 to 10^2.
- cv_splits
int (default: 5) Number of cross-validation folds.
- cv_repeats
int (default: 20) Number of CV repetitions. Total fits per parameter combination: cv_splits × cv_repeats.
- n_jobs
int (default: -1) Number of parallel jobs. -1 uses all available cores.
- random_state
int | None (default: None) Random seed for reproducibility.
- key_added
str (default: 'svm') Key prefix for storing results in .uns.
- inplace
bool (default: True) If True, store results in .uns; if False, return the grid search object and a dictionary with the best parameters and CV results. Returning them is useful for inspecting the grid search object or reusing the best parameters for other tasks.
- class_weight
None | dict | Literal['balanced'] (default: 'balanced')
- Return type:
None or dict
- Returns:
If inplace=False, returns a dictionary with the best parameters and CV results. Otherwise modifies data.uns[f"{key_added}.params"] in place.
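The size of the search implied by these parameters can be checked with simple arithmetic. The sketch below is illustrative only: it rebuilds the default C and gamma grids from the documented powers and counts cross-validation fits as candidates × cv_splits × cv_repeats. The helper name total_fits is hypothetical and not part of grassp.

```python
from itertools import product

# Default grids as documented: C = 2^-4 ... 2^4, gamma = 10^-3 ... 10^2.
C_range = [2.0 ** k for k in range(-4, 5)]
gamma_range = [10.0 ** k for k in range(-3, 3)]

# Every (C, gamma) pair is one candidate in the grid.
candidates = list(product(C_range, gamma_range))


def total_fits(n_candidates, cv_splits, cv_repeats):
    """Hypothetical helper: cross-validation fits for a repeated k-fold grid search."""
    return n_candidates * cv_splits * cv_repeats


print(len(candidates))                     # 54
print(total_fits(len(candidates), 5, 20))  # 5400 with the default cv_splits and cv_repeats
print(total_fits(len(candidates), 2, 2))   # 216
```

The last figure matches the 216 fits reported in the Examples output for cv_splits=2 and cv_repeats=2.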
Examples
>>> import grassp as gr
>>> adata = gr.ds.hein_2024(enrichment="enriched")
# When actually training, increase cv_repeats and cv_splits
# We recommend >20 repeats with 5 splits
>>> gr.tl.svm_train(adata, gt_col="hein2024_gt_component", cv_repeats=2, cv_splits=2, random_state=42)
Fitting 4 folds for each of 54 candidates, totalling 216 fits
>>> adata.uns["svm.params"]["best_params"]
{'C': 2.0, 'gamma': 0.01}
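For readers who want to see what a search of this shape looks like outside grassp, here is a minimal sketch using scikit-learn on synthetic data. It is not grassp's implementation. It assumes an RBF-kernel SVC (suggested by the gamma parameter) and scikit-learn's GridSearchCV with RepeatedStratifiedKFold (suggested by the log line in the example above). The synthetic blobs and the variable names are placeholders.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.model_selection import GridSearchCV, RepeatedStratifiedKFold
from sklearn.svm import SVC

# Synthetic stand-in for a protein-by-feature matrix with marker labels.
X, y = make_blobs(n_samples=120, centers=3, n_features=5, random_state=0)

# Mimic unknown proteins: NaN labels are excluded from training.
labels = y.astype(float)
labels[::5] = np.nan
known = ~np.isnan(labels)
X_known, y_known = X[known], labels[known].astype(int)

# Same grid shape as the documented defaults: 9 C values x 6 gamma values.
param_grid = {
    "C": 2.0 ** np.arange(-4, 5),
    "gamma": 10.0 ** np.arange(-3, 3),
}

# Small CV settings to keep the sketch fast: 2 splits x 2 repeats = 4 folds.
cv = RepeatedStratifiedKFold(n_splits=2, n_repeats=2, random_state=42)

# Assumption: RBF kernel with balanced class weights.
search = GridSearchCV(
    SVC(kernel="rbf", class_weight="balanced"),
    param_grid,
    cv=cv,
    n_jobs=1,
)
search.fit(X_known, y_known)

print(search.best_params_)
```

Stratified folds keep the class proportions of the marker set in every split, and repeating the split with different shuffles averages out the variance of any single partition.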