"""
Controller for loading pre-computed fold indices from files.
This module provides the FoldFileLoaderController which loads fold definitions
from previously saved fold files (generated by splitters like KFold, ShuffleSplit, etc.)
or from user-provided fold files.
Supported file formats:
- CSV: nirs4all standard format (fold_0, fold_1, ... columns listing each fold's train sample IDs)
- CSV: Assignment format (a fold column, optionally with a sample_id column)
- JSON: List of fold objects with train and val (or test) keys
- YAML: Same structure as JSON
- TXT: Alternating train/val lines of comma-separated indices
Example pipeline usage::
pipeline = [
MinMaxScaler(),
{"split": "workspace/runs/my_run/folds_KFold_seed42.csv"},
{"model": PLSRegression()}
]
"""
from __future__ import annotations
import json
import csv
import numpy as np
from pathlib import Path
from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Optional, Union
from nirs4all.controllers.controller import OperatorController
from nirs4all.controllers.registry import register_controller
from nirs4all.core.logging import get_logger
from nirs4all.pipeline.config.context import ExecutionContext, RuntimeContext
logger = get_logger(__name__)
if TYPE_CHECKING:
from nirs4all.data.dataset import SpectroDataset
from nirs4all.pipeline.steps.parser import ParsedStep
class FoldFileParser:
    """Utility class for parsing fold files in various formats.

    Supported fold file formats:
        - nirs4all CSV: columns ``fold_0``, ``fold_1``, ... where each column
          lists the train sample IDs of one fold.
        - Assignment CSV: a ``fold`` column (optionally with a ``sample_id``
          column) assigning each sample to the fold in which it is validated.
        - JSON: list of dicts with ``train`` and ``val`` (or ``test``) keys.
        - YAML: same structure as JSON.
        - TXT: alternating train / val lines of comma-separated indices.

    Examples:
        >>> parser = FoldFileParser()
        >>> folds = parser.parse("folds_KFold.csv")
        >>> # Returns: [(train_ids, val_ids), (train_ids, val_ids), ...]
    """

    SUPPORTED_EXTENSIONS = {'.csv', '.json', '.yaml', '.yml', '.txt'}

    def parse(
        self,
        file_path: Union[str, Path],
        format: Optional[str] = None
    ) -> List[Tuple[List[int], List[int]]]:
        """Parse a fold file and return fold definitions.

        Args:
            file_path: Path to the fold file.
            format: Optional format hint ('csv', 'json', 'yaml', 'yml', 'txt').
                The hint is matched case-insensitively and a leading dot is
                ignored. If None, the format is detected from the extension.

        Returns:
            List of (train_indices, val_indices) tuples.

        Raises:
            FileNotFoundError: If file doesn't exist.
            ValueError: If file format is unsupported or content is invalid.
        """
        path = Path(file_path)
        if not path.exists():
            raise FileNotFoundError(f"Fold file not found: {path}")

        if format is None:
            fmt = self._detect_format(path)
        elif isinstance(format, str):
            # Accept hints such as "CSV" or ".csv" as well as the bare name.
            fmt = format.strip().lower().lstrip('.')
        else:
            fmt = format

        handlers = {
            'csv': self._parse_csv,
            'json': self._parse_json,
            'yaml': self._parse_yaml,
            'yml': self._parse_yaml,
            'txt': self._parse_txt,
        }
        handler = handlers.get(fmt) if isinstance(fmt, str) else None
        if handler is None:
            raise ValueError(f"Unsupported fold file format: {format}")
        return handler(path)

    def _detect_format(self, path: Path) -> str:
        """Detect the file format from the (case-insensitive) extension.

        Raises:
            ValueError: If the extension is not one of SUPPORTED_EXTENSIONS.
        """
        suffix = path.suffix.lower()
        if suffix == '.csv':
            return 'csv'
        if suffix == '.json':
            return 'json'
        if suffix in ('.yaml', '.yml'):
            return 'yaml'
        if suffix == '.txt':
            return 'txt'
        # sorted() keeps the message stable; set ordering is arbitrary.
        raise ValueError(
            f"Cannot detect fold file format for extension: {suffix}. "
            f"Supported: {sorted(self.SUPPORTED_EXTENSIONS)}"
        )

    def _parse_csv(self, path: Path) -> List[Tuple[List[int], List[int]]]:
        """Parse a CSV fold file.

        Two layouts are supported:
            1. nirs4all format: ``fold_0, fold_1, ...`` columns holding train
               sample IDs; validation indices are computed as the complement.
            2. Assignment format: a ``fold`` column, optionally with a
               sample-id column.

        Any header row that does not contain a ``fold`` column is treated as
        the nirs4all layout.

        Args:
            path: Path to CSV file.

        Returns:
            List of (train_indices, val_indices) tuples.

        Raises:
            ValueError: If the file has no header row or its rows are invalid.
        """
        # newline='' is what the csv module expects; utf-8-sig drops a leading
        # BOM (if any) so it cannot corrupt the first header name.
        with open(path, 'r', encoding='utf-8-sig', newline='') as f:
            reader = csv.reader(f)
            try:
                raw_headers = next(reader)
            except StopIteration:
                raise ValueError(f"CSV fold file is empty: {path}") from None

            headers = [h.strip() for h in raw_headers]
            if not headers:
                raise ValueError(f"CSV fold file has no header row: {path}")

            if 'fold' in (h.lower() for h in headers):
                return self._parse_csv_assignment_format(headers, list(reader))
            return self._parse_csv_nirs4all_format(headers, reader)

    def _parse_csv_nirs4all_format(
        self,
        headers: List[str],
        reader: Any
    ) -> List[Tuple[List[int], List[int]]]:
        """Parse the nirs4all CSV layout.

        Each column is one fold and its cells are that fold's train sample
        IDs. The validation IDs of a fold are every ID seen anywhere in the
        file that is not in that fold's train column.

        Example:
            fold_0,fold_1,fold_2
            0,3,6
            1,4,7
            2,5,8

        Args:
            headers: Header row; only its length (number of folds) is used.
            reader: Iterable yielding the remaining rows as lists of strings.

        Returns:
            List of (sorted train_ids, sorted val_ids) tuples, one per column.

        Raises:
            ValueError: If a row has a non-empty cell beyond the last header.
        """
        n_folds = len(headers)
        fold_train_indices: List[List[int]] = [[] for _ in range(n_folds)]

        for row_number, row in enumerate(reader, start=1):
            for fold_idx, value in enumerate(row):
                cell = value.strip()
                if not cell:
                    # Columns of unequal length are padded with empty cells.
                    continue
                if fold_idx >= n_folds:
                    raise ValueError(
                        f"CSV data row {row_number} has more values than "
                        f"the {n_folds} fold column(s) declared in the header"
                    )
                try:
                    fold_train_indices[fold_idx].append(int(cell))
                except ValueError:
                    # Deliberate best-effort: non-integer cells are ignored.
                    continue

        all_sample_ids = set().union(*fold_train_indices)

        folds = []
        for train_ids in fold_train_indices:
            # Set difference keeps the complement linear in the sample count.
            val_ids = sorted(all_sample_ids.difference(train_ids))
            folds.append((sorted(train_ids), val_ids))
        return folds

    def _parse_csv_assignment_format(
        self,
        headers: List[str],
        rows: List[List[str]]
    ) -> List[Tuple[List[int], List[int]]]:
        """Parse a CSV that assigns each sample to a fold.

        Each distinct value of the ``fold`` column becomes one fold whose
        validation set is the samples carrying that value; the train set is
        every other sample. The sample ID is read from a ``sample_id``,
        ``id``, ``index`` or ``sample`` column when present, otherwise the
        zero-based position of the data row is used. Blank rows are skipped
        and do not consume a position.

        Example:
            sample_id,fold
            0,0
            1,1
            2,0
            3,1

        Args:
            headers: Header row of the CSV.
            rows: Data rows following the header.

        Returns:
            List of (sorted train_ids, sorted val_ids) tuples ordered by fold
            value.

        Raises:
            ValueError: If there is no ``fold`` column, a row is too short,
                or a fold / sample value is not an integer.
        """
        fold_col = None
        sample_col = None
        for idx, name in enumerate(h.strip().lower() for h in headers):
            if name == 'fold':
                fold_col = idx
            elif name in ('sample_id', 'id', 'index', 'sample'):
                sample_col = idx

        if fold_col is None:
            raise ValueError("CSV assignment format requires 'fold' column")

        use_row_index = sample_col is None
        last_needed_col = fold_col if use_row_index else max(fold_col, sample_col)

        # csv.reader yields [] for blank lines (e.g. a trailing newline).
        data_rows = (row for row in rows if any(cell.strip() for cell in row))

        fold_to_samples: Dict[int, List[int]] = {}
        for row_idx, row in enumerate(data_rows):
            if len(row) <= last_needed_col:
                raise ValueError(
                    f"CSV assignment data row {row_idx} has {len(row)} "
                    f"column(s); expected at least {last_needed_col + 1}"
                )
            fold_value = int(row[fold_col].strip())
            sample_id = row_idx if use_row_index else int(row[sample_col].strip())
            fold_to_samples.setdefault(fold_value, []).append(sample_id)

        all_samples = set().union(*fold_to_samples.values())

        folds = []
        for fold_value in sorted(fold_to_samples):
            val_ids = fold_to_samples[fold_value]
            train_ids = sorted(all_samples.difference(val_ids))
            folds.append((train_ids, sorted(val_ids)))
        return folds

    @staticmethod
    def _as_index_list(values: Any, key: str) -> List[int]:
        """Return *values* as a list, treating None as an empty fold side.

        Raises:
            ValueError: If *values* is neither None, a list nor a tuple (a
                string, for instance, would otherwise be split into
                characters).
        """
        if values is None:
            return []
        if isinstance(values, (list, tuple)):
            return list(values)
        raise ValueError(
            f"Fold '{key}' entry must be a list of indices, "
            f"got {type(values).__name__}"
        )

    @classmethod
    def _folds_from_objects(
        cls,
        data: Any,
        label: str
    ) -> List[Tuple[List[int], List[int]]]:
        """Convert decoded JSON / YAML content into (train, val) tuples.

        A missing ``train`` key yields an empty train list; the validation
        list is taken from ``val`` and falls back to ``test`` when ``val`` is
        absent.

        Args:
            data: Decoded file content; must be a list of dicts.
            label: Format name used in error messages ("JSON" or "YAML").

        Raises:
            ValueError: If the content is not a list of dicts or a fold side
                is not a list.
        """
        if not isinstance(data, list):
            raise ValueError(f"{label} fold file must contain a list of fold objects")

        folds = []
        for fold_obj in data:
            if not isinstance(fold_obj, dict):
                raise ValueError("Each fold must be a dict with 'train' and 'val' keys")
            train = fold_obj.get('train', [])
            val = fold_obj.get('val', fold_obj.get('test', []))
            folds.append(
                (cls._as_index_list(train, 'train'), cls._as_index_list(val, 'val'))
            )
        return folds

    def _parse_json(self, path: Path) -> List[Tuple[List[int], List[int]]]:
        """Parse a JSON fold file.

        Expected format:
            [
                {"train": [0, 1, 2], "val": [3, 4, 5]},
                {"train": [3, 4, 5], "val": [0, 1, 2]}
            ]

        Raises:
            ValueError: If the content is not a list of fold dicts.
        """
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return self._folds_from_objects(data, "JSON")

    def _parse_yaml(self, path: Path) -> List[Tuple[List[int], List[int]]]:
        """Parse a YAML fold file (same structure as the JSON format).

        Raises:
            ImportError: If PyYAML is not installed.
            ValueError: If the content is not a list of fold dicts.
        """
        try:
            import yaml
        except ImportError as exc:
            raise ImportError(
                "PyYAML is required for parsing YAML fold files. "
                "Install with: pip install pyyaml"
            ) from exc

        with open(path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        return self._folds_from_objects(data, "YAML")

    def _parse_txt(self, path: Path) -> List[Tuple[List[int], List[int]]]:
        """Parse a TXT fold file.

        Blank lines are ignored. The remaining lines are read in pairs: the
        first line of each pair holds the comma-separated train indices, the
        second the validation indices.

        Example:
            0,1,2,3,4
            5,6,7,8,9
            5,6,7,8,9
            0,1,2,3,4

        Raises:
            ValueError: If the number of non-blank lines is odd or a value is
                not an integer.
        """
        with open(path, 'r', encoding='utf-8') as f:
            lines = [line.strip() for line in f if line.strip()]

        if len(lines) % 2 != 0:
            raise ValueError(
                "TXT fold file must have even number of lines "
                "(alternating train/val)"
            )

        def to_indices(line: str) -> List[int]:
            return [int(item.strip()) for item in line.split(',') if item.strip()]

        # Pair up lines 0/1, 2/3, ... as (train, val).
        return [
            (to_indices(train_line), to_indices(val_line))
            for train_line, val_line in zip(lines[0::2], lines[1::2])
        ]
@register_controller
class FoldFileLoaderController(OperatorController):
    """Controller for loading pre-computed fold indices from files.

    This controller matches pipeline steps where the 'split' keyword is used
    with a file path (string ending in a supported extension) instead of
    a splitter object.

    Examples:
        >>> # In pipeline
        >>> {"split": "path/to/folds.csv"}
        >>> {"split": "workspace/runs/my_run/folds_KFold_seed42.csv"}
    """

    priority = 9  # Higher priority than CrossValidatorController (10)

    # Largest fraction of the fold file's distinct sample IDs that may be
    # absent from the dataset before loading is aborted instead of filtered.
    MAX_MISSING_ID_FRACTION = 0.1

    @classmethod
    def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
        """Match steps with 'split' keyword and file path value.

        Returns True if:
            - keyword is 'split', AND
            - operator is a string (file path), AND
            - path has a supported extension (.csv, .json, .yaml, .yml, .txt)
        """
        if keyword != "split" or not isinstance(operator, str):
            return False
        # Only the extension is inspected; the file need not exist yet.
        return Path(operator).suffix.lower() in FoldFileParser.SUPPORTED_EXTENSIONS

    @classmethod
    def use_multi_source(cls) -> bool:
        """Fold loading is a single-source operation."""
        return False

    @classmethod
    def supports_prediction_mode(cls) -> bool:
        """Fold files should be loaded in prediction mode to set up fold structure."""
        return True

    def execute(
        self,
        step_info: 'ParsedStep',
        dataset: "SpectroDataset",
        context: ExecutionContext,
        runtime_context: "RuntimeContext",
        source: int = -1,
        mode: str = "train",
        loaded_binaries: Any = None,
        prediction_store: Any = None
    ) -> Tuple[ExecutionContext, Any]:
        """Load folds from file and set them on the dataset.

        The file named by ``step_info.operator`` is parsed, its sample IDs are
        checked against the dataset's base train samples, and the resulting
        folds are stored with ``dataset.set_folds``. When the file defines a
        single fold with validation samples and the dataset has no test rows,
        those validation samples are moved to the test partition and the fold
        is stored with an empty validation list.

        Args:
            step_info: Parsed step containing the file path.
            dataset: Dataset to set folds on.
            context: Current execution context.
            runtime_context: Runtime context with global settings (unused).
            source: Source index (unused).
            mode: Execution mode ("train" or "predict"); not consulted here.
            loaded_binaries: Pre-loaded binaries (unused).
            prediction_store: Prediction store (unused).

        Returns:
            Tuple of (context, StepOutput); the context is returned unchanged.

        Raises:
            ValueError: If the file cannot be parsed, contains no folds, or
                more than MAX_MISSING_ID_FRACTION of its sample IDs are not in
                the dataset.
        """
        from nirs4all.pipeline.execution.result import StepOutput

        file_path = step_info.operator
        logger.info(f"Loading folds from file: {file_path}")

        try:
            folds = FoldFileParser().parse(file_path)
        except Exception as e:
            # Callers get one exception type regardless of the parser failure.
            raise ValueError(f"Failed to parse fold file '{file_path}': {e}") from e

        if not folds:
            raise ValueError(f"No folds found in file: {file_path}")

        logger.info(f"Loaded {len(folds)} folds from {file_path}")

        # Base (non-augmented, non-excluded) train sample IDs of the dataset.
        local_context = context.with_partition("train")
        base_sample_ids = dataset._indexer.x_indices(
            local_context.selector, include_augmented=False, include_excluded=False
        )
        base_sample_ids_set = set(base_sample_ids.tolist())

        folds = self._drop_unknown_ids(folds, base_sample_ids_set)

        # Single-fold case: turn the validation samples into a test partition.
        # The test partition is only read when this branch can actually apply.
        if len(folds) == 1:
            train_ids, val_ids = folds[0]
            if len(val_ids) > 0 and self._row_count(dataset.x({"partition": "test"})) == 0:
                dataset._indexer.update_by_indices(
                    val_ids, {"partition": "test"}
                )
                logger.info(
                    f"Single fold detected: moved {len(val_ids)} samples to test partition"
                )
                # Validation samples now live in the test partition.
                folds = [(train_ids, [])]

        dataset.set_folds(folds)

        for i, (train_ids, val_ids) in enumerate(folds):
            logger.debug(f"  Fold {i}: train={len(train_ids)}, val={len(val_ids)}")

        step_output = StepOutput(
            metadata={
                "fold_file": str(file_path),
                "n_folds": len(folds),
                "fold_sizes": [(len(t), len(v)) for t, v in folds]
            }
        )
        return context, step_output

    def _drop_unknown_ids(
        self,
        folds: List[Tuple[List[int], List[int]]],
        known_ids: set
    ) -> List[Tuple[List[int], List[int]]]:
        """Check fold sample IDs against *known_ids* and filter unknown ones.

        Returns *folds* untouched when every ID is known. A small share of
        unknown IDs is logged and removed; a share above
        MAX_MISSING_ID_FRACTION raises ValueError.
        """
        all_fold_ids = set()
        for train_ids, val_ids in folds:
            all_fold_ids.update(train_ids)
            all_fold_ids.update(val_ids)

        missing_ids = all_fold_ids - known_ids
        if not missing_ids:
            return folds

        if len(missing_ids) > len(all_fold_ids) * self.MAX_MISSING_ID_FRACTION:
            raise ValueError(
                f"Fold file contains {len(missing_ids)} sample IDs not in dataset. "
                f"Sample IDs in dataset: {len(known_ids)}. "
                f"Missing IDs (first 10): {sorted(missing_ids)[:10]}"
            )

        logger.warning(
            f"Fold file contains {len(missing_ids)} sample IDs not in current dataset. "
            f"These will be filtered out."
        )
        return [
            (
                [i for i in train_ids if i in known_ids],
                [i for i in val_ids if i in known_ids]
            )
            for train_ids, val_ids in folds
        ]

    @staticmethod
    def _row_count(data: Any) -> int:
        """Return the number of rows in *data*.

        *data* is either a single array-like exposing ``shape`` or a list of
        them, in which case the first dimensions are summed (0 for an empty
        list).
        """
        if isinstance(data, list):
            return sum(arr.shape[0] for arr in data)
        return data.shape[0]