# Source code for nirs4all.operators.data.repetition

"""Repetition transformation operator configuration.

This module provides configuration dataclasses for transforming spectral
repetitions (multiple spectra per sample) into either separate sources
or additional preprocessings.

When samples have multiple repetitions (e.g., 4 spectra per leaf sample),
these operators reshape the dataset structure:

- `rep_to_sources`: Each repetition becomes a separate data source
  Input: 1 source × (120 samples, 1 pp, 500 features)
  Output: 4 sources × (30 samples, 1 pp, 500 features)

- `rep_to_pp`: Repetitions become additional preprocessing slots
  Input: 1 source × (120 samples, 1 pp, 500 features)
  Output: 1 source × (30 samples, 4 pp, 500 features)

Example:
    >>> # Transform 4 repetitions per sample into 4 sources
    >>> {"rep_to_sources": "Sample_ID"}
    >>>
    >>> # Transform repetitions into preprocessing dimension
    >>> {"rep_to_pp": "Sample_ID"}
    >>>
    >>> # Group by target value instead of metadata column
    >>> {"rep_to_sources": "y"}
    >>>
    >>> # Advanced configuration with options
    >>> {"rep_to_sources": {
    ...     "column": "Sample_ID",
    ...     "on_unequal": "drop",
    ...     "source_names": "rep_{i}"
    ... }}
"""

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union
import warnings


class UnequelRepsStrategy(Enum):
    """Strategy for handling samples with unequal repetition counts.

    When samples have different numbers of repetitions, this controls how
    the transformation handles the mismatch. Every member's value is the
    lowercase string form of the strategy ("error", "pad", "drop",
    "truncate"), so a member can be built from such a string with
    ``UnequelRepsStrategy("pad")``.

    Attributes:
        ERROR: Raise an error if repetition counts differ (default, strictest).
        PAD: Pad shorter groups with NaN/zeros to match the longest.
        DROP: Drop samples that don't have the expected repetition count.
        TRUNCATE: Truncate all groups to the minimum repetition count.
    """

    ERROR = "error"
    PAD = "pad"
    DROP = "drop"
    TRUNCATE = "truncate"


# NOTE: "Unequel" is a misspelling, but it is the established public name and
# must stay importable. The correctly spelled alias below refers to the very
# same class, so both names can be used interchangeably.
UnequalRepsStrategy = UnequelRepsStrategy
@dataclass
class RepetitionConfig:
    """Configuration for repetition transformation operations.

    This dataclass provides configuration for the `rep_to_sources` and
    `rep_to_pp` keywords, which reshape datasets based on sample repetitions.

    Repetitions are identified by a metadata column (e.g., "Sample_ID") that
    groups multiple spectra belonging to the same physical sample.

    Attributes:
        column: Metadata column identifying sample groups, or special values:
            - None (default): Use dataset's `aggregate` column from DatasetConfigs
            - "y": Group by target values
            - str: Explicit metadata column name
        on_unequal: Strategy when samples have different repetition counts.
            An enum member whose value is one of these strings is also
            accepted and is stored as its string value.
            - "error" (default): Raise error if counts differ
            - "pad": Pad shorter groups with NaN to match longest
            - "drop": Drop samples without expected repetition count
            - "truncate": Use minimum count across all samples
        expected_reps: Expected number of repetitions per sample.
            If None (default), inferred from data (mode of group sizes).
            If specified, validates all groups match this count.
        source_names: Naming template for new sources (rep_to_sources only).
            - None (default): Uses "rep_0", "rep_1", etc.
            - str with {i}: Template like "rep_{i}" or "spectrum_{i}"
            - List[str]: Explicit names for each repetition
        pp_names: Naming template for new preprocessings (rep_to_pp only).
            - None (default): Uses "{original}_rep{i}" format
            - str with {i} and {pp}: Template like "{pp}_r{i}"
            - List[str]: Explicit names (length = n_reps * n_existing_pp)
        preserve_order: Whether to preserve sample order within groups.
            If True (default), repetitions are ordered by their row position.
            If False, order within groups is undefined.
        aggregate_metadata: How to handle metadata after grouping.
            - "first" (default): Keep metadata from first repetition
            - "validate": Ensure all reps have identical metadata, error if not
            - "drop": Remove metadata columns that differ across repetitions

    Example:
        >>> # Use dataset's aggregate column (simplest)
        >>> RepetitionConfig()
        >>>
        >>> # Simple column-based grouping
        >>> RepetitionConfig(column="Sample_ID")
        >>>
        >>> # Group by target value with padding
        >>> RepetitionConfig(column="y", on_unequal="pad")
        >>>
        >>> # Explicit repetition count validation
        >>> RepetitionConfig(
        ...     column="Leaf_ID",
        ...     expected_reps=4,
        ...     on_unequal="error"
        ... )
        >>>
        >>> # Custom source naming
        >>> RepetitionConfig(
        ...     column="Sample_ID",
        ...     source_names="measurement_{i}"
        ... )
    """

    column: Optional[str] = None  # None = use dataset.aggregate
    on_unequal: str = "error"
    expected_reps: Optional[int] = None
    source_names: Optional[Union[str, List[str]]] = None
    pp_names: Optional[Union[str, List[str]]] = None
    preserve_order: bool = True
    aggregate_metadata: str = "first"

    def __post_init__(self):
        """Validate configuration after initialization.

        Raises:
            ValueError: If `on_unequal` or `aggregate_metadata` is not one of
                the supported values, or if `expected_reps` is given but is
                not a positive integer.

        Warns:
            UserWarning: If a string template in `source_names` / `pp_names`
                has no placeholder, so every generated name would be equal.
        """
        # Note: column=None is valid (means use dataset.aggregate at runtime)

        # Accept an enum member (e.g. the strategy enum) and keep only its
        # plain string value so the rest of the class deals with strings.
        if isinstance(self.on_unequal, Enum):
            self.on_unequal = self.on_unequal.value

        # Validate on_unequal
        valid_strategies = ("error", "pad", "drop", "truncate")
        if self.on_unequal not in valid_strategies:
            raise ValueError(
                f"on_unequal must be one of {valid_strategies}, got '{self.on_unequal}'"
            )

        # Validate expected_reps. bool is a subclass of int, so it has to be
        # excluded explicitly or `expected_reps=True` would pass as 1.
        if self.expected_reps is not None:
            is_plain_int = (
                isinstance(self.expected_reps, int)
                and not isinstance(self.expected_reps, bool)
            )
            if not is_plain_int or self.expected_reps < 1:
                raise ValueError(
                    f"expected_reps must be a positive integer, got {self.expected_reps}"
                )

        # Validate aggregate_metadata
        valid_agg = ("first", "validate", "drop")
        if self.aggregate_metadata not in valid_agg:
            raise ValueError(
                f"aggregate_metadata must be one of {valid_agg}, got '{self.aggregate_metadata}'"
            )

        # stacklevel=3 skips this method and the dataclass-generated __init__
        # so the warning is attributed to the code that built the config.

        # Validate source_names format
        if isinstance(self.source_names, str) and "{i}" not in self.source_names:
            warnings.warn(
                f"source_names template '{self.source_names}' does not contain '{{i}}'. "
                # Plain (non-f) literal: braces must NOT be doubled here.
                "All sources will have the same name. Consider using 'rep_{i}' format.",
                UserWarning,
                stacklevel=3,
            )

        # Validate pp_names format
        if (
            isinstance(self.pp_names, str)
            and "{i}" not in self.pp_names
            and "{pp}" not in self.pp_names
        ):
            warnings.warn(
                f"pp_names template '{self.pp_names}' does not contain '{{i}}' or '{{pp}}'. "
                "All preprocessings will have the same name.",
                UserWarning,
                stacklevel=3,
            )

    @property
    def is_y_grouping(self) -> bool:
        """Check if grouping by target values.

        Returns:
            True if column is "y" (case-insensitive); False for None or any
            non-string column value.
        """
        return isinstance(self.column, str) and self.column.lower() == "y"

    @property
    def uses_dataset_aggregate(self) -> bool:
        """Check if using dataset's aggregate column.

        Returns:
            True if column is None (will use dataset.aggregate at runtime).
        """
        return self.column is None

    def resolve_column(self, dataset_aggregate: Optional[str]) -> str:
        """Resolve the actual column to use for grouping.

        An explicitly configured `column` always wins; the dataset's
        aggregate setting is only the fallback.

        Args:
            dataset_aggregate: The aggregate value from dataset
                (column name, "y", or None).

        Returns:
            The resolved column name to use.

        Raises:
            ValueError: If no column specified and dataset has no aggregate setting.
        """
        if self.column is not None:
            return self.column

        # Use dataset's aggregate setting
        if dataset_aggregate is None:
            raise ValueError(
                "No column specified for repetition grouping and dataset has no "
                "aggregate setting. Either specify a column in the step "
                "(e.g., {'rep_to_sources': 'Sample_ID'}) or set aggregate in "
                "DatasetConfigs. [Error: REP-E001]"
            )
        return dataset_aggregate

    def get_unequal_strategy(self) -> "UnequelRepsStrategy":
        """Get the unequal handling strategy as an enum.

        Returns:
            UnequelRepsStrategy enum value matching `on_unequal`.
        """
        return UnequelRepsStrategy(self.on_unequal)

    def get_source_name(self, rep_index: int) -> str:
        """Generate source name for a given repetition index.

        Args:
            rep_index: Zero-based repetition index.

        Returns:
            Source name string: the formatted template, the explicit name at
            `rep_index`, or the default "rep_{rep_index}" when no name is
            configured or the index is outside the explicit list.
        """
        if isinstance(self.source_names, str):
            return self.source_names.format(i=rep_index)
        # The lower bound stops a negative index from silently wrapping
        # around to a name that belongs to another repetition.
        if isinstance(self.source_names, list) and 0 <= rep_index < len(self.source_names):
            return self.source_names[rep_index]
        return f"rep_{rep_index}"

    def get_pp_name(self, rep_index: int, original_pp: str) -> str:
        """Generate preprocessing name for a given repetition and original processing.

        Args:
            rep_index: Zero-based repetition index.
            original_pp: Original preprocessing name (e.g., "raw", "snv").

        Returns:
            New preprocessing name string.
        """
        if isinstance(self.pp_names, str):
            return self.pp_names.format(i=rep_index, pp=original_pp)
        # A list of explicit names cannot be resolved here: the flat index
        # depends on the number of existing preprocessings, which this method
        # does not know. Lists are handled by the caller, so both the list and
        # the None case use the default naming.
        return f"{original_pp}_rep{rep_index}"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize configuration to dictionary.

        Optional settings (`expected_reps`, `source_names`, `pp_names`) are
        only included when they are set.

        Returns:
            Dictionary representation for manifest storage.
        """
        result: Dict[str, Any] = {
            "column": self.column,
            "on_unequal": self.on_unequal,
            "preserve_order": self.preserve_order,
            "aggregate_metadata": self.aggregate_metadata,
        }
        for key in ("expected_reps", "source_names", "pp_names"):
            value = getattr(self, key)
            if value is not None:
                result[key] = value
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "RepetitionConfig":
        """Create config from dictionary.

        Keys missing from `data` fall back to the dataclass defaults; keys
        other than the configuration fields are ignored.

        Args:
            data: Dictionary representation. If 'column' is missing, uses None (aggregate).

        Returns:
            RepetitionConfig instance.
        """
        return cls(
            column=data.get("column"),  # None if not present = use aggregate
            on_unequal=data.get("on_unequal", "error"),
            expected_reps=data.get("expected_reps"),
            source_names=data.get("source_names"),
            pp_names=data.get("pp_names"),
            preserve_order=data.get("preserve_order", True),
            aggregate_metadata=data.get("aggregate_metadata", "first"),
        )

    @classmethod
    def from_step_value(
        cls, value: Union[str, bool, Dict[str, Any], "RepetitionConfig", None]
    ) -> "RepetitionConfig":
        """Create config from step value (string, bool, dict, or config).

        Handles multiple syntax styles:
        - None or True: Use dataset's aggregate column
        - str: Explicit column name (or "y" for target grouping)
        - dict: Full configuration with options
        - RepetitionConfig: Returned unchanged

        Args:
            value: Step value - column name, True/None for aggregate, config
                dict, or an already-built config instance.

        Returns:
            RepetitionConfig instance.

        Raises:
            ValueError: If `value` is of any other type (including False).

        Example:
            >>> # Use dataset aggregate (simplest)
            >>> RepetitionConfig.from_step_value(True)
            >>> RepetitionConfig.from_step_value(None)
            >>>
            >>> # Explicit column
            >>> RepetitionConfig.from_step_value("Sample_ID")
            >>>
            >>> # Advanced syntax
            >>> RepetitionConfig.from_step_value({
            ...     "column": "Sample_ID",
            ...     "on_unequal": "drop"
            ... })
        """
        if value is None or value is True:
            return cls(column=None)  # Use dataset aggregate
        if isinstance(value, cls):
            return value
        if isinstance(value, str):
            return cls(column=value)
        if isinstance(value, dict):
            return cls.from_dict(value)
        raise ValueError(
            f"Invalid repetition config type: {type(value).__name__}. "
            "Expected string (column name), True (use aggregate), or dict with configuration."
        )
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    "RepetitionConfig",
    "UnequelRepsStrategy",
]