from __future__ import annotations
import inspect
import warnings
from typing import Any, Dict, Tuple, TYPE_CHECKING, List, Union
import copy
from nirs4all.controllers.controller import OperatorController
from nirs4all.controllers.registry import register_controller
from nirs4all.core.logging import get_logger
from nirs4all.pipeline.config.context import ExecutionContext, RuntimeContext
from nirs4all.operators.splitters import GroupedSplitterWrapper
logger = get_logger(__name__)
if TYPE_CHECKING: # pragma: no cover
from nirs4all.data.dataset import SpectroDataset
# Native group-aware splitter class names (sklearn and nirs4all)
# These splitters have built-in group support and don't need force_group
_NATIVE_GROUP_SPLITTERS = frozenset({
"GroupKFold",
"GroupShuffleSplit",
"LeaveOneGroupOut",
"LeavePGroupsOut",
"StratifiedGroupKFold",
"SPXYGFold",
"BinnedStratifiedGroupKFold",
})
def _is_native_group_splitter(splitter: Any) -> bool:
"""Check if splitter has native group support.
Returns True if the splitter is a known group-aware splitter that
properly handles the 'groups' parameter without needing force_group.
"""
return splitter.__class__.__name__ in _NATIVE_GROUP_SPLITTERS
def _needs(splitter: Any) -> Tuple[bool, bool]:
"""Return booleans *(needs_y, needs_groups)* for the given splitter.
Introspects the signature of ``split`` *plus* estimator tags (when
available) so it works for *any* class respecting the sklearn contract.
"""
split_fn = getattr(splitter, "split", None)
if not callable(split_fn):
# No split method → cannot be a valid splitter
return False, False
sig = inspect.signature(split_fn)
params = sig.parameters
needs_y = "y" in params # and params["y"].default is inspect._empty
# Check if 'groups' parameter exists - sklearn group splitters have groups=None default
# but still require the parameter to be provided for proper operation
needs_g = "groups" in params
# Honour estimator tags (sklearn >=1.3)
if hasattr(splitter, "_get_tags"):
tags = splitter._get_tags()
needs_y = needs_y or tags.get("requires_y", False)
return needs_y, needs_g
[docs]
@register_controller
class CrossValidatorController(OperatorController):
"""Controller for **any** sklearn‑compatible splitter (native or custom)."""
priority = 10 # processed early but after mandatory pre‑processing steps
[docs]
@classmethod
def matches(cls, step: Any, operator: Any, keyword: str) -> bool: # noqa: D401
"""Return *True* if *operator* behaves like a splitter.
**Criteria** – must expose a callable ``split`` whose first positional
argument is named *X*. Optional presence of ``get_n_splits`` is a plus
but not mandatory, so user‑defined simple splitters are still accepted.
Also matches on the 'split' keyword for group-aware splitting syntax.
"""
# Priority 1: Match on 'split' keyword (explicit workflow operator)
if keyword == "split":
return True
# Priority 2: Match dict with 'split' key
if isinstance(step, dict) and "split" in step:
return True
# Priority 3: Match objects with split() method (existing behavior)
if operator is None:
return False
split_fn = getattr(operator, "split", None)
if not callable(split_fn):
return False
try:
sig = inspect.signature(split_fn)
except (TypeError, ValueError): # edge‑cases: C‑extensions or cythonised
return True # accept – we can still attempt runtime call
params: List[inspect.Parameter] = [
p for p in sig.parameters.values()
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
]
return bool(params) and params[0].name == "X"
[docs]
@classmethod
def use_multi_source(cls) -> bool: # noqa: D401
"""Cross‑validators themselves are single‑source operators."""
return False
[docs]
@classmethod
def supports_prediction_mode(cls) -> bool:
"""Cross-validators should not execute during prediction mode."""
return True
[docs]
def execute( # type: ignore[override]
self,
step_info: 'ParsedStep',
dataset: "SpectroDataset",
context: ExecutionContext,
runtime_context: "RuntimeContext",
source: int = -1,
mode: str = "train",
loaded_binaries: Any = None,
prediction_store: Any = None
):
"""Run ``operator.split`` and store the resulting folds on *dataset*.
* Smartly supplies ``y`` / ``groups`` only if required.
* Extracts groups from metadata if specified.
* Supports ``force_group`` parameter to wrap any splitter with group-awareness.
* Maps local indices back to the global index space.
* Stores the list of folds into the dataset for subsequent steps.
Parameters
----------
step_info : ParsedStep
Parsed step containing the operator and original step configuration.
dataset : SpectroDataset
The dataset to split.
context : ExecutionContext
Current execution context.
runtime_context : RuntimeContext
Runtime context with global settings.
source : int
Source index (-1 for combined sources).
mode : str
Execution mode ("train", "predict", or "explain").
loaded_binaries : Any
Pre-loaded binary data (not used).
prediction_store : Any
Store for predictions (not used).
Notes
-----
The ``force_group`` parameter enables any sklearn-compatible splitter
to work with grouped samples by wrapping it with ``GroupedSplitterWrapper``.
This aggregates samples by group, passes "virtual samples" to the splitter,
and expands fold indices back to the original dataset.
Example usage::
{"split": KFold(n_splits=5), "force_group": "Sample_ID"}
{"split": ShuffleSplit(test_size=0.2), "force_group": "ID", "aggregation": "median"}
"""
from nirs4all.pipeline.execution.result import StepOutput
op = step_info.operator
# Extract force_group and aggregation parameters from step dict
force_group = None
aggregation = "mean"
y_aggregation = None
if isinstance(step_info.original_step, dict):
force_group = step_info.original_step.get("force_group")
aggregation = step_info.original_step.get("aggregation", "mean")
y_aggregation = step_info.original_step.get("y_aggregation")
# In predict/explain mode, skip fold splitting entirely
if mode == "predict" or mode == "explain":
# Don't filter by partition - prediction data may be in "test" partition
local_context = context.with_partition(None)
needs_y, needs_g = _needs(op)
X = dataset.x(local_context, layout="2d", concat_source=True)
n_samples = X.shape[0]
# Build minimal kwargs for get_n_splits
kwargs: Dict[str, Any] = {}
if needs_y:
y = dataset.y(local_context)
if y is not None:
kwargs["y"] = y
n_folds = op.get_n_splits(**kwargs) if hasattr(op, "get_n_splits") else 1
dataset.set_folds([(list(range(n_samples)), [])] * n_folds)
return context, StepOutput()
# Extract group column specification from step dict (train mode only)
group_column = None
if isinstance(step_info.original_step, dict) and "group" in step_info.original_step:
group_column = step_info.original_step["group"]
if not isinstance(group_column, str):
raise TypeError(
f"Group column must be a string, got {type(group_column).__name__}"
)
# Warn if 'group' is used with a non-native-group splitter
# These splitters will silently ignore the groups parameter
# Suggest using 'force_group' instead for universal group support
if not _is_native_group_splitter(op):
splitter_name = op.__class__.__name__
warnings.warn(
f"⚠️ 'group' parameter specified with {splitter_name}, which does not "
f"natively support groups. The 'group' parameter will be ignored.\n"
f"💡 Use 'force_group' instead to enable group-aware splitting with any splitter:\n"
f" {{'split': {splitter_name}(...), 'force_group': '{group_column}'}}\n"
f"This will ensure all samples from the same group stay together in train/test.",
UserWarning,
stacklevel=2
)
# Handle force_group: wrap the splitter with GroupedSplitterWrapper
# This enables any sklearn-compatible splitter to work with groups
force_group_column = None
force_group_is_y = False # Track if force_group uses y-binning
n_bins = 5 # Default bins for y-binning
if force_group is not None:
if not isinstance(force_group, str):
raise TypeError(
f"force_group must be a string column name or 'y', got {type(force_group).__name__}"
)
# Check if force_group is "y" (special case: use binned y values as groups)
if force_group.lower() == "y":
force_group_is_y = True
# Extract n_bins from step dict if provided
if isinstance(step_info.original_step, dict):
n_bins = step_info.original_step.get("n_bins", 5)
if not isinstance(n_bins, int) or n_bins < 2:
raise ValueError(
f"n_bins must be an integer >= 2, got {n_bins}"
)
else:
force_group_column = force_group
# Validate aggregation parameter
valid_aggregations = ("mean", "median", "first")
if aggregation not in valid_aggregations:
raise ValueError(
f"aggregation must be one of {valid_aggregations}, got '{aggregation}'"
)
# Validate y_aggregation parameter if provided
valid_y_aggregations = ("mean", "mode", "first", None)
if y_aggregation not in valid_y_aggregations:
raise ValueError(
f"y_aggregation must be one of {valid_y_aggregations}, got '{y_aggregation}'"
)
# Wrap the splitter with GroupedSplitterWrapper
op = GroupedSplitterWrapper(
splitter=op,
aggregation=aggregation,
y_aggregation=y_aggregation
)
local_context = context.with_partition("train")
needs_y, needs_g = _needs(op)
# IMPORTANT: Only split on base samples (exclude augmented) to prevent data leakage
X = dataset.x(local_context, layout="2d", concat_source=True, include_augmented=False)
# Get the actual sample IDs from the indexer - these will be used to store folds
# with absolute sample IDs instead of positional indices, so folds remain valid
# even if samples are excluded later by sample_filter
base_sample_ids = dataset._indexer.x_indices( # noqa: SLF001
local_context.selector, include_augmented=False, include_excluded=False
)
y = None
if needs_y or force_group_column is not None or force_group_is_y:
# Get y for splitters that need it, or for force_group (wrapper may need it)
y = dataset.y(local_context, include_augmented=False)
# Get groups from metadata if available
# For force_group: always extract groups (wrapper requires them)
# For native group splitters: extract if needs_g is True
groups = None
effective_group_column = force_group_column or group_column
if force_group_is_y:
# force_group: "y" - use binned y values as groups
# This enables stratification on continuous targets
if y is None:
raise ValueError(
"force_group='y' specified but dataset.y returned None"
)
# Bin y values into n_bins quantile bins for group-aware splitting
# Each bin becomes a "group" that won't be split across train/test
groups = self._bin_y_for_groups(y, n_bins)
elif force_group_column is not None:
# force_group: always extract groups for the wrapper
if not hasattr(dataset, 'metadata_columns') or not dataset.metadata_columns:
raise ValueError(
f"force_group='{force_group_column}' specified but dataset has no metadata columns."
)
if force_group_column not in dataset.metadata_columns:
raise ValueError(
f"force_group column '{force_group_column}' not found in metadata.\n"
f"Available columns: {dataset.metadata_columns}"
)
try:
groups = dataset.metadata_column(force_group_column, local_context, include_augmented=False)
if len(groups) != X.shape[0]:
raise ValueError(
f"Group array length ({len(groups)}) doesn't match X rows ({X.shape[0]})"
)
except Exception as e:
raise ValueError(
f"Failed to extract groups from force_group column '{force_group_column}': {e}"
) from e
elif needs_g and (group_column is not None or _is_native_group_splitter(op)):
# Only extract groups if:
# 1. Explicit group column specified (user requested grouping), OR
# 2. Splitter is a native group splitter (GroupKFold, etc.) that requires groups
# Note: Many sklearn splitters (KFold, ShuffleSplit, etc.) have 'groups' parameter
# for API compatibility, but don't require it. We should NOT auto-assign groups for those.
if group_column is not None:
# Explicit group column specified - validate and extract
if not hasattr(dataset, 'metadata_columns') or not dataset.metadata_columns:
raise ValueError(
f"Group column '{group_column}' specified but dataset has no metadata columns."
)
if group_column not in dataset.metadata_columns:
raise ValueError(
f"Group column '{group_column}' not found in metadata.\n"
f"Available columns: {dataset.metadata_columns}"
)
# Extract groups from specified column (base samples only)
try:
groups = dataset.metadata_column(group_column, local_context, include_augmented=False)
if len(groups) != X.shape[0]:
raise ValueError(
f"Group array length ({len(groups)}) doesn't match X rows ({X.shape[0]})"
)
except Exception as e:
raise ValueError(
f"Failed to extract groups from metadata column '{group_column}': {e}"
) from e
elif hasattr(dataset, 'metadata_columns') and dataset.metadata_columns:
# No explicit group column, but metadata available - use first column as default
group_column = dataset.metadata_columns[0]
logger.warning(
f"{op.__class__.__name__} has 'groups' parameter but no 'group' specified. "
f"Using default: '{group_column}'"
)
try:
groups = dataset.metadata_column(group_column, local_context, include_augmented=False)
if len(groups) != X.shape[0]:
raise ValueError(
f"Group array length ({len(groups)}) doesn't match X rows ({X.shape[0]})"
)
except Exception as e:
raise ValueError(
f"Failed to extract groups from metadata column '{group_column}': {e}"
) from e
# else: No group column specified and no metadata available
# Leave groups=None and let the splitter handle it
# (will work for splitters that don't require groups, will fail for those that do)
n_samples = X.shape[0]
# Build kwargs for split()
kwargs: Dict[str, Any] = {}
if needs_y or force_group_column is not None or force_group_is_y:
# Provide y for splitters that need it, or for force_group wrapper
if needs_y and y is None:
raise ValueError(
f"{op.__class__.__name__} requires y but dataset.y returned None"
)
if y is not None:
# Special case: force_group="y" with StratifiedKFold
# Pass bin labels (groups) as y for stratification on continuous targets
if force_group_is_y and "Stratified" in step_info.operator.__class__.__name__:
# Use bin labels for stratification instead of continuous y
kwargs["y"] = groups.astype(int)
else:
kwargs["y"] = y
if groups is not None:
# Provide groups for:
# 1. Native group splitters (needs_g is True)
# 2. force_group wrapped splitters (wrapper needs groups)
kwargs["groups"] = groups
# Train mode: perform actual fold splitting
folds = list(op.split(X, **kwargs)) # Convert to list to avoid iterator consumption
# Convert positional indices to absolute sample IDs
# This ensures folds remain valid even if samples are excluded later by sample_filter
sample_id_folds = [
(base_sample_ids[train_idx].tolist(), base_sample_ids[val_idx].tolist())
for train_idx, val_idx in folds
]
# If no test partition exists and this is a single-fold split,
# use the validation set as test partition (not as fold)
# This is expected behavior for single-fold splitters (e.g., SPXYGFold with n_splits=1)
# which are designed to create train/test splits, not cross-validation folds
if dataset.x({"partition": "test"}).shape[0] == 0 and len(sample_id_folds) == 1:
fold_1 = sample_id_folds[0]
if len(fold_1[1]) > 0: # Only if there are validation samples
# Move validation samples to test partition using sample IDs
dataset._indexer.update_by_indices(
fold_1[1], {"partition": "test"}
)
# Keep train sample IDs, clear validation (they're now in test partition)
sample_id_folds = [(fold_1[0], [])]
# Store the folds in the dataset (using sample IDs, not positional indices)
dataset.set_folds(sample_id_folds)
# Generate binary output with fold information (using sample IDs)
headers = [f"fold_{i}" for i in range(len(sample_id_folds))]
binary = ",".join(headers).encode("utf-8") + b"\n"
max_train_samples = max(len(train_idx) for train_idx, _ in sample_id_folds)
for row_idx in range(max_train_samples):
row_values = []
for fold_idx, (train_idx, val_idx) in enumerate(sample_id_folds):
if row_idx < len(train_idx):
row_values.append(str(train_idx[row_idx]))
else:
row_values.append("") # Empty cell if this fold has fewer samples
binary += ",".join(row_values).encode("utf-8") + b"\n"
# Filename includes group column if used
# For force_group, use the inner splitter's name
if force_group_column is not None or force_group_is_y:
inner_splitter = op.splitter # GroupedSplitterWrapper stores inner splitter
folds_name = f"folds_{inner_splitter.__class__.__name__}"
if force_group_is_y:
folds_name += f"_force_group-y_bins{n_bins}"
else:
folds_name += f"_force_group-{force_group_column}"
if aggregation != "mean":
folds_name += f"_{aggregation}"
if hasattr(inner_splitter, "random_state"):
seed = getattr(inner_splitter, "random_state")
if seed is not None:
folds_name += f"_seed{seed}"
else:
folds_name = f"folds_{op.__class__.__name__}"
if group_column:
folds_name += f"_group-{group_column}"
if hasattr(op, "random_state"):
seed = getattr(op, "random_state")
if seed is not None:
folds_name += f"_seed{seed}"
# folds_name += ".csv" # Extension handled by StepOutput tuple
# print(f"Generated {len(folds)} folds.")
# Create StepOutput with the CSV
step_output = StepOutput(
outputs=[(binary, folds_name, "csv")]
)
return context, step_output
# else:
# n_folds = operator.get_n_splits(**kwargs) if hasattr(operator, "get_n_splits") else 1
# dataset.set_folds([(list(range(n_samples)), [])] * n_folds)
# return context, []
def _bin_y_for_groups(self, y, n_bins: int = 5):
"""Bin continuous y values into quantile-based groups for stratified splitting.
This method enables stratification on continuous targets by binning y values
into quantiles. Each bin becomes a "pseudo-group" that ensures samples with
similar y values are kept together during splitting, enabling balanced
distribution of target values across folds.
Parameters
----------
y : array-like of shape (n_samples,)
Continuous target values to bin.
n_bins : int, default=5
Number of quantile bins to create. More bins = finer stratification
but may fail with small datasets. Recommended: 3-10 bins.
Returns
-------
groups : ndarray of shape (n_samples,)
Integer bin labels (0 to n_bins-1) for each sample.
Notes
-----
Uses quantile-based binning (pd.qcut equivalent) to ensure approximately
equal number of samples per bin, regardless of y value distribution.
If there are fewer unique y values than n_bins, reduces to unique value
binning (each unique value = one bin).
Examples
--------
>>> y = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
>>> groups = self._bin_y_for_groups(y, n_bins=5)
>>> # groups will be [0, 0, 1, 1, 2, 2, 3, 3, 4, 4] (approximately)
"""
import numpy as np
y = np.asarray(y).ravel()
n_samples = len(y)
# Handle edge cases
n_unique = len(np.unique(y))
if n_unique <= n_bins:
# Fewer unique values than bins - use unique values as groups
_, groups = np.unique(y, return_inverse=True)
return groups
# Quantile-based binning for balanced bin sizes
# This ensures each bin has approximately equal number of samples
try:
# Compute quantile edges
quantiles = np.linspace(0, 1, n_bins + 1)
bin_edges = np.quantile(y, quantiles)
# Make edges unique to avoid empty bins
bin_edges = np.unique(bin_edges)
# If we lost bins due to non-unique edges, adjust
actual_bins = len(bin_edges) - 1
if actual_bins < 2:
# Fall back to unique value binning
_, groups = np.unique(y, return_inverse=True)
return groups
# Digitize y values into bins (1-indexed, so subtract 1)
# np.digitize returns indices 1 to n_bins, we want 0 to n_bins-1
groups = np.digitize(y, bin_edges[1:-1], right=False)
return groups
except Exception:
# Fallback: equal-width binning
y_min, y_max = y.min(), y.max()
if y_min == y_max:
# All values identical
return np.zeros(n_samples, dtype=int)
bin_width = (y_max - y_min) / n_bins
groups = ((y - y_min) / bin_width).astype(int)
groups = np.clip(groups, 0, n_bins - 1) # Handle edge case of y == y_max
return groups