# Source code for nirs4all.controllers.splitters.fold_file_loader

"""
Controller for loading pre-computed fold indices from files.

This module provides the FoldFileLoaderController which loads fold definitions
from previously saved fold files (generated by splitters like KFold, ShuffleSplit, etc.)
or from user-provided fold files.

Supported file formats:
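- CSV: nirs4all standard format (fold_0, fold_1, ... columns, each listing that fold's training sample IDs)
- CSV: Assignment format (a ``fold`` column, optionally with a sample_id column, one row per sample)
- JSON: List of fold objects with train and val (or test) keys
- YAML: Same structure as JSON
- TXT: Alternating lines of comma-separated train and validation indices (two lines per fold)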

Example pipeline usage::

    pipeline = [
        MinMaxScaler(),
        {"split": "workspace/runs/my_run/folds_KFold_seed42.csv"},
        {"model": PLSRegression()}
    ]
"""

from __future__ import annotations

import json
import csv
import numpy as np
from pathlib import Path
from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Optional, Union

from nirs4all.controllers.controller import OperatorController
from nirs4all.controllers.registry import register_controller
from nirs4all.core.logging import get_logger
from nirs4all.pipeline.config.context import ExecutionContext, RuntimeContext

# Module-level logger used by the parser and the controller below.
logger = get_logger(__name__)

# These imports are only evaluated by static type checkers (TYPE_CHECKING is
# False at runtime); they exist solely for the string annotations below.
if TYPE_CHECKING:
    from nirs4all.data.dataset import SpectroDataset
    from nirs4all.pipeline.steps.parser import ParsedStep


class FoldFileParser:
    """Utility class for parsing fold files in various formats.

    Supported formats:

    - nirs4all CSV: one column per fold (``fold_0``, ``fold_1``, ...), each
      listing that fold's training sample IDs. The validation IDs of a fold
      are the complement of its training IDs within the union of all listed
      IDs. CSV files whose headers match neither CSV layout are also read
      this way.
    - Assignment CSV: a ``fold`` column (plus an optional ``sample_id`` /
      ``id`` / ``index`` / ``sample`` column; otherwise the data-row position
      is used as the sample ID). Each distinct fold value becomes one
      validation fold.
    - JSON: list of dicts with ``train`` and ``val`` (or ``test``) keys.
    - YAML: same structure as JSON.
    - TXT: alternating lines of comma-separated train and validation indices.

    Examples:
        >>> parser = FoldFileParser()
        >>> folds = parser.parse("folds_KFold.csv")  # doctest: +SKIP
        >>> # Returns: [(train_ids, val_ids), (train_ids, val_ids), ...]
    """

    SUPPORTED_EXTENSIONS = {'.csv', '.json', '.yaml', '.yml', '.txt'}

    # Maps a lowercased file extension to the parser format name.
    _EXTENSION_FORMATS = {
        '.csv': 'csv',
        '.json': 'json',
        '.yaml': 'yaml',
        '.yml': 'yaml',
        '.txt': 'txt',
    }

    # Header names accepted as the sample-ID column of the assignment CSV format.
    _SAMPLE_ID_HEADERS = ('sample_id', 'id', 'index', 'sample')

    # 'utf-8-sig' decodes plain UTF-8 unchanged but also strips a leading BOM
    # (as written by e.g. Excel), which would otherwise be glued to the first
    # header / token and break format detection or integer parsing.
    _ENCODING = 'utf-8-sig'

    def parse(
        self,
        file_path: Union[str, Path],
        format: Optional[str] = None,
    ) -> List[Tuple[List[int], List[int]]]:
        """Parse a fold file and return fold definitions.

        Args:
            file_path: Path to the fold file.
            format: Optional format hint ('csv', 'json', 'yaml', 'yml', 'txt').
                Case-insensitive; a leading dot ('.CSV') is tolerated.
                If None, the format is auto-detected from the file extension.

        Returns:
            List of (train_indices, val_indices) tuples.

        Raises:
            FileNotFoundError: If the file doesn't exist.
            ValueError: If the format is unsupported or the content is invalid.
            ImportError: If a YAML file is parsed and PyYAML is not installed.
        """
        path = Path(file_path)
        if not path.exists():
            raise FileNotFoundError(f"Fold file not found: {path}")

        if format is None:
            fmt = self._detect_format(path)
        else:
            fmt = str(format).strip().lower().lstrip('.')

        if fmt == 'csv':
            return self._parse_csv(path)
        if fmt == 'json':
            return self._parse_json(path)
        if fmt in ('yaml', 'yml'):
            return self._parse_yaml(path)
        if fmt == 'txt':
            return self._parse_txt(path)
        # Only reachable with an explicit hint: _detect_format never returns
        # an unsupported name.
        raise ValueError(f"Unsupported fold file format: {format}")

    def _detect_format(self, path: Path) -> str:
        """Return the format name for ``path``'s extension (case-insensitive).

        Raises:
            ValueError: If the extension is not supported.
        """
        suffix = path.suffix.lower()
        try:
            return self._EXTENSION_FORMATS[suffix]
        except KeyError:
            raise ValueError(
                f"Cannot detect fold file format for extension: {suffix}. "
                f"Supported: {self.SUPPORTED_EXTENSIONS}"
            ) from None

    def _parse_csv(self, path: Path) -> List[Tuple[List[int], List[int]]]:
        """Parse a CSV fold file.

        Supports two layouts, chosen from the (whitespace-stripped) headers:

        1. nirs4all format: ``fold_0``, ``fold_1``, ... columns with training
           sample IDs; validation indices are computed as the complement.
        2. Assignment format: a ``fold`` column (plus optional sample-ID column).

        Headers matching neither layout are parsed as the nirs4all format.

        Args:
            path: Path to CSV file.

        Returns:
            List of (train_indices, val_indices) tuples.

        Raises:
            ValueError: If the file has no header row or its content is invalid.
        """
        # newline='' is what the csv module expects so quoted fields
        # containing newlines are handled correctly.
        with open(path, 'r', encoding=self._ENCODING, newline='') as f:
            reader = csv.reader(f)
            headers = next(reader, None)
            if not headers:
                raise ValueError(f"CSV fold file has no header row: {path}")
            headers = [h.strip() for h in headers]

            if all(h.startswith('fold_') for h in headers):
                return self._parse_csv_nirs4all_format(headers, reader)
            if 'fold' in (h.lower() for h in headers):
                return self._parse_csv_assignment_format(headers, list(reader))
            # Unknown headers: fall back to the nirs4all layout.
            return self._parse_csv_nirs4all_format(headers, reader)

    def _parse_csv_nirs4all_format(
        self, headers: List[str], reader
    ) -> List[Tuple[List[int], List[int]]]:
        """Parse the nirs4all CSV layout.

        Each column is a fold and its cells are that fold's training sample
        IDs. A fold's validation IDs are all IDs listed anywhere in the file
        that are not in its training column. Empty cells (including ones
        produced by trailing commas) are ignored, and non-integer cells are
        skipped (best-effort).

        Example::

            fold_0,fold_1,fold_2
            0,3,6
            1,4,7
            2,5,8

        Raises:
            ValueError: If a non-empty cell lies beyond the last header column.
        """
        n_folds = len(headers)
        fold_train_indices: List[List[int]] = [[] for _ in range(n_folds)]

        for row_no, row in enumerate(reader, start=1):
            for fold_idx, raw_value in enumerate(row):
                value = raw_value.strip()
                if not value:
                    continue
                if fold_idx >= n_folds:
                    raise ValueError(
                        f"Data row {row_no} has a value in column {fold_idx + 1}, "
                        f"but the header defines only {n_folds} fold column(s)"
                    )
                try:
                    fold_train_indices[fold_idx].append(int(value))
                except ValueError:
                    # Best-effort: skip non-integer cells.
                    continue

        all_sample_ids = set().union(*fold_train_indices)

        # Set difference keeps this linear instead of list-membership O(n^2).
        folds = []
        for train_ids in fold_train_indices:
            val_ids = all_sample_ids.difference(train_ids)
            folds.append((sorted(train_ids), sorted(val_ids)))
        return folds

    def _parse_csv_assignment_format(
        self, headers: List[str], rows: List[List[str]]
    ) -> List[Tuple[List[int], List[int]]]:
        """Parse a CSV with one fold assignment per sample.

        Each distinct value of the ``fold`` column becomes one validation
        fold; its training set is every other sample in the file. If there is
        no sample-ID column, the 0-based position among non-blank data rows is
        used as the sample ID.

        Example::

            sample_id,fold
            0,0
            1,1
            2,0
            3,1

        Raises:
            ValueError: If there is no ``fold`` column, or a row is too short or
                contains a non-integer fold / sample ID.
        """
        fold_col: Optional[int] = None
        sample_col: Optional[int] = None
        for idx, header in enumerate(headers):
            name = header.strip().lower()
            if name == 'fold':
                fold_col = idx
            elif name in self._SAMPLE_ID_HEADERS:
                sample_col = idx

        if fold_col is None:
            raise ValueError("CSV assignment format requires 'fold' column")

        # Blank rows (e.g. a trailing newline) carry no assignment; skip them
        # so they neither crash parsing nor consume a positional sample ID.
        data_rows = [row for row in rows if any(cell.strip() for cell in row)]

        fold_to_samples: Dict[int, List[int]] = {}
        for row_idx, row in enumerate(data_rows):
            try:
                fold_value = int(row[fold_col].strip())
                sample_id = (
                    row_idx if sample_col is None else int(row[sample_col].strip())
                )
            except (IndexError, ValueError) as exc:
                raise ValueError(
                    f"Invalid fold assignment in data row {row_idx + 1}: {row}"
                ) from exc
            fold_to_samples.setdefault(fold_value, []).append(sample_id)

        all_samples = set().union(*fold_to_samples.values())

        # Each fold: val = samples assigned to it, train = all other samples.
        folds = []
        for fold_value in sorted(fold_to_samples):
            val_ids = fold_to_samples[fold_value]
            train_ids = all_samples.difference(val_ids)
            folds.append((sorted(train_ids), sorted(val_ids)))
        return folds

    def _parse_json(self, path: Path) -> List[Tuple[List[int], List[int]]]:
        """Parse a JSON fold file.

        Expected format::

            [
                {"train": [0, 1, 2], "val": [3, 4, 5]},
                {"train": [3, 4, 5], "val": [0, 1, 2]}
            ]

        ``test`` is accepted in place of ``val``.
        """
        with open(path, 'r', encoding=self._ENCODING) as f:
            data = json.load(f)
        return self._folds_from_objects(data, 'JSON')

    def _parse_yaml(self, path: Path) -> List[Tuple[List[int], List[int]]]:
        """Parse a YAML fold file (same structure as the JSON format).

        Raises:
            ImportError: If PyYAML is not installed.
        """
        try:
            import yaml
        except ImportError as exc:
            raise ImportError(
                "PyYAML is required for parsing YAML fold files. "
                "Install with: pip install pyyaml"
            ) from exc

        with open(path, 'r', encoding=self._ENCODING) as f:
            # safe_load never constructs arbitrary Python objects from the file.
            data = yaml.safe_load(f)
        return self._folds_from_objects(data, 'YAML')

    def _folds_from_objects(
        self, data: Any, kind: str
    ) -> List[Tuple[List[int], List[int]]]:
        """Convert a decoded JSON/YAML document into (train, val) tuples.

        Args:
            data: Decoded document; must be a list of dicts.
            kind: Format name used in error messages ('JSON' or 'YAML').

        Raises:
            ValueError: If the structure is not a list of dicts or an ID list
                contains non-integer values.
        """
        if not isinstance(data, list):
            raise ValueError(f"{kind} fold file must contain a list of fold objects")

        folds = []
        for fold_idx, fold_obj in enumerate(data):
            if not isinstance(fold_obj, dict):
                raise ValueError("Each fold must be a dict with 'train' and 'val' keys")
            train = fold_obj.get('train')
            val = fold_obj.get('val')
            if val is None:
                # 'test' is accepted as an alias for the validation set.
                val = fold_obj.get('test')
            folds.append((
                self._as_index_list(train, 'train', fold_idx),
                self._as_index_list(val, 'val', fold_idx),
            ))
        return folds

    @staticmethod
    def _as_index_list(values: Any, key: str, fold_idx: int) -> List[int]:
        """Coerce a fold's ``train``/``val`` entry to a list of ints.

        ``None`` (missing key or YAML null) becomes an empty list. Integer-like
        strings and integral floats are converted, so IDs compare equal to the
        integer sample IDs used by the dataset.

        Raises:
            ValueError: If ``values`` is not a list-like of integer IDs.
        """
        if values is None:
            return []
        if isinstance(values, (str, bytes, dict)):
            raise ValueError(
                f"Fold {fold_idx}: '{key}' must be a list of sample IDs, "
                f"got {type(values).__name__}"
            )
        try:
            items = list(values)
        except TypeError as exc:
            raise ValueError(
                f"Fold {fold_idx}: '{key}' must be a list of sample IDs, "
                f"got {type(values).__name__}"
            ) from exc

        indices = []
        for item in items:
            # Refuse to silently truncate e.g. 1.5 to 1.
            if isinstance(item, float) and not item.is_integer():
                raise ValueError(
                    f"Fold {fold_idx}: '{key}' contains non-integer sample ID {item!r}"
                )
            try:
                indices.append(int(item))
            except (TypeError, ValueError) as exc:
                raise ValueError(
                    f"Fold {fold_idx}: '{key}' contains non-integer sample ID {item!r}"
                ) from exc
        return indices

    def _parse_txt(self, path: Path) -> List[Tuple[List[int], List[int]]]:
        """Parse a TXT fold file.

        Each fold takes two consecutive non-blank lines of comma-separated
        indices: first the training indices, then the validation indices.
        Blank lines are ignored, so an empty index set cannot be expressed.

        Example::

            0,1,2,3,4
            5,6,7,8,9
            5,6,7,8,9
            0,1,2,3,4

        Raises:
            ValueError: On an odd number of non-blank lines or a non-integer index.
        """
        with open(path, 'r', encoding=self._ENCODING) as f:
            lines = [line.strip() for line in f if line.strip()]

        if len(lines) % 2 != 0:
            raise ValueError(
                "TXT fold file must have even number of lines "
                "(alternating train/val)"
            )

        def to_indices(line: str) -> List[int]:
            return [int(token) for token in line.split(',') if token.strip()]

        return [
            (to_indices(train_line), to_indices(val_line))
            for train_line, val_line in zip(lines[0::2], lines[1::2])
        ]
@register_controller
class FoldFileLoaderController(OperatorController):
    """Controller for loading pre-computed fold indices from files.

    Matches pipeline steps where the ``split`` keyword is used with a file
    path (a string ending in a supported extension) instead of a splitter
    object. The file is parsed with ``FoldFileParser``, checked against the
    dataset's train-partition sample IDs, and installed with
    ``dataset.set_folds``.

    Examples:
        >>> # In pipeline
        >>> {"split": "path/to/folds.csv"}
        >>> {"split": "workspace/runs/my_run/folds_KFold_seed42.csv"}
    """

    priority = 9  # Higher priority than CrossValidatorController (10)

    # Maximum fraction of the fold file's sample IDs that may be unknown to
    # the dataset. Above it loading fails; at or below it the unknown IDs are
    # dropped with a warning.
    max_missing_id_fraction: float = 0.1

    @classmethod
    def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
        """Match steps with 'split' keyword and file path value.

        Returns True if:
            - keyword is 'split', AND
            - operator is a string (file path), AND
            - path has a supported extension (.csv, .json, .yaml, .yml, .txt)
        """
        if keyword != "split" or not isinstance(operator, str):
            return False
        # Only the extension is inspected here; the file is opened in execute().
        return Path(operator).suffix.lower() in FoldFileParser.SUPPORTED_EXTENSIONS

    @classmethod
    def use_multi_source(cls) -> bool:
        """Fold loading is a single-source operation."""
        return False

    @classmethod
    def supports_prediction_mode(cls) -> bool:
        """Fold files should be loaded in prediction mode to set up fold structure."""
        return True

    def execute(
        self,
        step_info: 'ParsedStep',
        dataset: "SpectroDataset",
        context: ExecutionContext,
        runtime_context: "RuntimeContext",
        source: int = -1,
        mode: str = "train",
        loaded_binaries: Any = None,
        prediction_store: Any = None
    ) -> Tuple[ExecutionContext, Any]:
        """Load folds from file and set them on the dataset.

        Args:
            step_info: Parsed step containing the file path.
            dataset: Dataset to set folds on.
            context: Current execution context.
            runtime_context: Runtime context with global settings (unused).
            source: Source index (unused).
            mode: Execution mode ("train" or "predict"); not consulted here.
            loaded_binaries: Pre-loaded binaries (unused).
            prediction_store: Prediction store (unused).

        Returns:
            Tuple of (context, StepOutput). The context is returned unchanged.

        Raises:
            ValueError: If the file cannot be parsed (the parser's exception,
                including FileNotFoundError, is chained), contains no folds, or
                references too many sample IDs unknown to the dataset.
        """
        from nirs4all.pipeline.execution.result import StepOutput

        file_path = step_info.operator
        logger.info(f"Loading folds from file: {file_path}")

        try:
            folds = FoldFileParser().parse(file_path)
        except Exception as e:
            raise ValueError(f"Failed to parse fold file '{file_path}': {e}") from e

        if not folds:
            raise ValueError(f"No folds found in file: {file_path}")

        logger.info(f"Loaded {len(folds)} folds from {file_path}")

        folds = self._filter_folds_to_dataset(folds, dataset, context)
        folds = self._move_single_fold_val_to_test(folds, dataset)

        dataset.set_folds(folds)

        for i, (train_ids, val_ids) in enumerate(folds):
            logger.debug(f"  Fold {i}: train={len(train_ids)}, val={len(val_ids)}")

        step_output = StepOutput(
            metadata={
                "fold_file": str(file_path),
                "n_folds": len(folds),
                "fold_sizes": [(len(t), len(v)) for t, v in folds]
            }
        )
        return context, step_output

    def _filter_folds_to_dataset(
        self,
        folds: List[Tuple[List[int], List[int]]],
        dataset: "SpectroDataset",
        context: ExecutionContext,
    ) -> List[Tuple[List[int], List[int]]]:
        """Check fold sample IDs against the dataset and drop unknown ones.

        IDs are compared with the train-partition sample IDs, excluding
        augmented and excluded samples.

        Raises:
            ValueError: If more than ``max_missing_id_fraction`` of the fold
                file's distinct IDs are not in the dataset.
        """
        local_context = context.with_partition("train")
        base_sample_ids = dataset._indexer.x_indices(
            local_context.selector, include_augmented=False, include_excluded=False
        )
        base_sample_ids_set = set(base_sample_ids.tolist())

        all_fold_ids = set()
        for train_ids, val_ids in folds:
            all_fold_ids.update(train_ids)
            all_fold_ids.update(val_ids)

        missing_ids = all_fold_ids - base_sample_ids_set
        if not missing_ids:
            return folds

        if len(missing_ids) > len(all_fold_ids) * self.max_missing_id_fraction:
            raise ValueError(
                f"Fold file contains {len(missing_ids)} sample IDs not in dataset. "
                f"Sample IDs in dataset: {len(base_sample_ids_set)}. "
                f"Missing IDs (first 10): {sorted(missing_ids)[:10]}"
            )

        logger.warning(
            f"Fold file contains {len(missing_ids)} sample IDs not in current dataset. "
            f"These will be filtered out."
        )
        filtered = [
            (
                [i for i in train_ids if i in base_sample_ids_set],
                [i for i in val_ids if i in base_sample_ids_set],
            )
            for train_ids, val_ids in folds
        ]

        # Filtering can empty a fold's training set; surface it here rather
        # than letting a downstream model fit fail obscurely.
        for i, (train_ids, _) in enumerate(filtered):
            if not train_ids:
                logger.warning(
                    f"Fold {i} has no training samples left after filtering "
                    f"unknown sample IDs."
                )
        return filtered

    @staticmethod
    def _test_partition_size(dataset: "SpectroDataset") -> int:
        """Return the number of rows in the dataset's test partition.

        ``dataset.x`` may return a single array or a list of arrays (presumably
        one per source — TODO confirm); rows are summed over the list.
        """
        test_data = dataset.x({"partition": "test"})
        if isinstance(test_data, list):
            return sum(arr.shape[0] for arr in test_data)
        return test_data.shape[0]

    def _move_single_fold_val_to_test(
        self,
        folds: List[Tuple[List[int], List[int]]],
        dataset: "SpectroDataset",
    ) -> List[Tuple[List[int], List[int]]]:
        """Turn a single fold's validation set into the test partition.

        Applies only when there is exactly one fold, the dataset's test
        partition is empty, and the fold has validation samples. Those samples
        are re-labelled ``partition="test"`` and the fold's validation list is
        emptied. Otherwise ``folds`` is returned unchanged.
        """
        # Check the cheap condition first so the test partition is only
        # queried when it can matter.
        if len(folds) != 1 or self._test_partition_size(dataset) != 0:
            return folds

        train_ids, val_ids = folds[0]
        if not val_ids:
            return folds

        dataset._indexer.update_by_indices(val_ids, {"partition": "test"})
        logger.info(
            f"Single fold detected: moved {len(val_ids)} samples to test partition"
        )
        # Validation samples now live in the test partition.
        return [(train_ids, [])]