nirs4all.operators.splitters.grouped_wrapper module
Grouped Splitter Wrapper for universal group support.
This module provides a wrapper that lets any sklearn-compatible splitter work with grouped samples. It aggregates the samples of each group into a single “virtual sample”, passes these virtual samples to the inner splitter, and expands the resulting fold indices back to the original dataset.
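The three steps above (aggregate, split, expand) can be sketched with plain NumPy and scikit-learn. This is a minimal illustration under stated assumptions, not the nirs4all implementation. The function name `grouped_split` is hypothetical. Both X and y are aggregated with the group mean, and the other aggregation options are left out. When no groups are given, the sketch simply delegates to the inner splitter.

```python
import numpy as np
from sklearn.model_selection import KFold


def grouped_split(splitter, X, y=None, groups=None):
    """Hypothetical sketch of the aggregate/split/expand idea (not nirs4all code).

    X and y are both aggregated with the group mean; the other aggregation
    options are omitted for brevity.
    """
    X = np.asarray(X)
    if groups is None:
        # No groups: delegate directly to the inner splitter.
        yield from splitter.split(X, y)
        return
    # inverse[i] is the position of sample i's group in unique_groups.
    unique_groups, inverse = np.unique(np.asarray(groups), return_inverse=True)
    n_groups = len(unique_groups)
    # 1. Aggregate: one "virtual sample" (the group centroid) per group.
    X_virtual = np.vstack([X[inverse == g].mean(axis=0) for g in range(n_groups)])
    y_virtual = None
    if y is not None:
        y = np.asarray(y, dtype=float)
        y_virtual = np.array([y[inverse == g].mean() for g in range(n_groups)])
    # 2. Split the virtual samples with the unmodified inner splitter.
    for train_g, test_g in splitter.split(X_virtual, y_virtual):
        # 3. Expand: each sample goes wherever its group's virtual sample went.
        yield (np.flatnonzero(np.isin(inverse, train_g)),
               np.flatnonzero(np.isin(inverse, test_g)))


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 10))
y = rng.normal(size=100)
groups = np.repeat(np.arange(20), 5)  # 20 groups, 5 samples each

for train_idx, test_idx in grouped_split(KFold(n_splits=5), X, y, groups):
    # No group appears on both sides of any split.
    assert not set(groups[train_idx]) & set(groups[test_idx])
```

Because the inner splitter only ever sees one row per group, it needs no knowledge of groups at all. This is what makes the approach work for splitters that have no `groups` argument of their own.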
- class nirs4all.operators.splitters.grouped_wrapper.GroupedSplitterWrapper(splitter, aggregation='mean', y_aggregation=None)
Bases: BaseCrossValidator
Wraps any sklearn-compatible splitter to add group-awareness.
This wrapper aggregates samples by group into “virtual samples”, passes them to the inner splitter, and expands the fold indices back to the original sample space. As a result, in every split all samples from a given group land on the same side (all in train or all in test), which prevents data leakage between related samples.
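To make the leakage point concrete, the sketch below counts, for each fold, how many groups appear in both the train and the test indices. It does not call the wrapper itself. As an illustrative assumption, it contrasts a shuffled, sample-level `KFold`, which ignores groups, with scikit-learn's native `GroupKFold` (listed under See also). The helper name `count_shared_groups` is hypothetical.

```python
import numpy as np
from sklearn.model_selection import GroupKFold, KFold


def count_shared_groups(splitter, X, groups):
    """Hypothetical helper: per fold, the number of groups found in both train and test."""
    shared = []
    for train_idx, test_idx in splitter.split(X, groups=groups):
        shared.append(len(set(groups[train_idx]) & set(groups[test_idx])))
    return shared


X = np.random.default_rng(0).normal(size=(100, 10))
groups = np.repeat(np.arange(20), 5)  # 20 groups, 5 samples each

# Sample-level split: groups are ignored (KFold accepts and ignores `groups`).
leaky = count_shared_groups(KFold(n_splits=5, shuffle=True, random_state=0), X, groups)
# Native group-aware split: no group is ever shared between train and test.
safe = count_shared_groups(GroupKFold(n_splits=5), X, groups)
print("KFold:", leaky)
print("GroupKFold:", safe)
```

Shuffling matters here. Without it, `KFold` on these contiguous, equally sized groups happens to cut exactly at group boundaries and hides the problem. With shuffling, samples from the same group end up on both sides, whereas `GroupKFold` keeps every count at zero. That zero-overlap property is the one the wrapper is meant to provide for arbitrary splitters.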
- Parameters:
splitter (BaseCrossValidator) – Any sklearn-compatible cross-validator (e.g., KFold, ShuffleSplit, StratifiedKFold).
aggregation (str, default="mean") – Method for aggregating X features within groups:
  - “mean”: use the group centroid (average of all samples in the group)
  - “median”: use the group median (robust to outliers)
  - “first”: use the first sample in each group (fast, no aggregation)
y_aggregation (str or None, default=None) – Method for aggregating y values within groups. If None, it is inferred from the splitter type:
  - “mean”: for regression (continuous y)
  - “mode”: for classification (categorical y)
  - “first”: use the first y value in each group
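The aggregation options above can be illustrated with a small helper that builds one virtual sample per group. This is a hypothetical sketch, not the library's code. The name `aggregate_groups` is illustrative, and the rule nirs4all uses to infer `y_aggregation` when it is None is not reproduced. Breaking “mode” ties in favor of the smallest label is also an assumption.

```python
import numpy as np


def aggregate_groups(X, y, groups, aggregation="mean", y_aggregation="mean"):
    """Hypothetical helper: one virtual sample (row of X, value of y) per group.

    Groups are returned in np.unique order (sorted).
    """
    X, y = np.asarray(X), np.asarray(y)
    unique_groups, inverse = np.unique(groups, return_inverse=True)
    X_rows, y_vals = [], []
    for g in range(len(unique_groups)):
        members = inverse == g  # boolean mask; keeps original sample order
        Xg, yg = X[members], y[members]
        if aggregation == "mean":
            X_rows.append(Xg.mean(axis=0))  # group centroid
        elif aggregation == "median":
            X_rows.append(np.median(Xg, axis=0))  # robust to outliers
        elif aggregation == "first":
            X_rows.append(Xg[0])  # first sample, no averaging
        else:
            raise ValueError(f"unknown aggregation: {aggregation!r}")
        if y_aggregation == "mean":
            y_vals.append(yg.mean())  # continuous targets
        elif y_aggregation == "mode":
            labels, counts = np.unique(yg, return_counts=True)
            y_vals.append(labels[np.argmax(counts)])  # ties: smallest label wins
        elif y_aggregation == "first":
            y_vals.append(yg[0])
        else:
            raise ValueError(f"unknown y_aggregation: {y_aggregation!r}")
    return unique_groups, np.vstack(X_rows), np.asarray(y_vals)


X = np.array([[1.0, 2.0], [3.0, 4.0], [100.0, 0.0], [5.0, 6.0]])
y = np.array([0, 1, 1, 2])
groups = np.array(["a", "a", "a", "b"])

_, X_med, y_mode = aggregate_groups(X, y, groups, "median", "mode")
print(X_med)   # [[3. 2.] [5. 6.]]  (the outlier 100.0 does not move the median)
print(y_mode)  # [1 2]
```

The example shows why “median” is the more robust choice: the single outlying row shifts the group-"a" mean to about 34.7 in the first feature, while the median stays at 3.0.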
Examples
>>> from sklearn.model_selection import KFold, ShuffleSplit, StratifiedKFold
>>> import numpy as np
>>> from nirs4all.operators.splitters.grouped_wrapper import GroupedSplitterWrapper
>>>
>>> # Basic usage with KFold
>>> X = np.random.randn(100, 10)
>>> y = np.random.randn(100)
>>> groups = np.repeat(np.arange(20), 5)  # 20 groups, 5 samples each
>>>
>>> wrapper = GroupedSplitterWrapper(KFold(n_splits=5))
>>> for train_idx, test_idx in wrapper.split(X, y, groups=groups):
...     # train_idx and test_idx are original sample indices
...     # All samples from the same group are on the same side of the split
...     train_groups = set(groups[train_idx])
...     test_groups = set(groups[test_idx])
...     assert len(train_groups & test_groups) == 0  # No overlap
>>>
>>> # Usage with ShuffleSplit
>>> wrapper = GroupedSplitterWrapper(ShuffleSplit(n_splits=1, test_size=0.2))
>>> for train_idx, test_idx in wrapper.split(X, y, groups=groups):
...     pass  # Groups are respected
>>>
>>> # Usage with StratifiedKFold (stratifies on aggregated y)
>>> y_class = np.random.randint(0, 3, 100)
>>> wrapper = GroupedSplitterWrapper(
...     StratifiedKFold(n_splits=3),
...     y_aggregation="mode"
... )
>>> for train_idx, test_idx in wrapper.split(X, y_class, groups=groups):
...     pass  # Groups are respected, stratification on group mode
Notes
The wrapper is transparent when no groups are provided: it delegates directly to the inner splitter without any aggregation.
See also
sklearn.model_selection.GroupKFold: Native group-aware K-fold splitter.
sklearn.model_selection.GroupShuffleSplit: Native group-aware shuffle split.
nirs4all.operators.splitters.SPXYGFold: SPXY-based group-aware splitter.