Source code for nirs4all.controllers.data.sample_augmentation

from typing import Any, Dict, List, Tuple, Optional, TYPE_CHECKING
from collections import Counter
import numpy as np  # noqa: F401

from nirs4all.controllers.controller import OperatorController
from nirs4all.controllers.registry import register_controller
from nirs4all.core.logging import get_logger
from nirs4all.controllers.data.balancing import BalancingCalculator
from nirs4all.data.binning import BinningCalculator  # noqa: F401 - used in _execute_balanced
from nirs4all.pipeline.config.component_serialization import deserialize_component

logger = get_logger(__name__)

try:
    import joblib  # noqa: F401 - used to check availability
    JOBLIB_AVAILABLE = True
except ImportError:
    JOBLIB_AVAILABLE = False

if TYPE_CHECKING:
    from nirs4all.pipeline.runner import PipelineRunner
    from nirs4all.data.dataset import SpectroDataset
    from nirs4all.pipeline.config.context import ExecutionContext
    from nirs4all.pipeline.steps.parser import ParsedStep


[docs] @register_controller class SampleAugmentationController(OperatorController): """ Sample Augmentation Controller with delegation pattern. This controller orchestrates sample augmentation by: 1. Calculating augmentation distribution (standard or balanced mode) 2. Creating transformer→samples mapping 3. Emitting ONE run_step per transformer with target samples The actual augmentation work is delegated to TransformerMixinController. """ priority = 10
[docs] @staticmethod def normalize_generator_spec(spec: Any) -> Any: """Normalize generator spec for sample_augmentation context. In sample_augmentation context, multi-selection should use combinations by default since the order of transformers doesn't matter. Translates legacy 'size' to 'pick' for explicit semantics. Args: spec: Generator specification (may contain _or_, size, pick, arrange). Returns: Normalized spec with 'size' converted to 'pick' if needed. """ if not isinstance(spec, dict): return spec # If explicit pick/arrange specified, honor it if "pick" in spec or "arrange" in spec: return spec # Convert legacy size to pick (combinations) for sample_augmentation if "size" in spec and "_or_" in spec: result = dict(spec) result["pick"] = result.pop("size") return result return spec
[docs] @classmethod def matches(cls, step: Any, operator: Any, keyword: str) -> bool: return keyword == "sample_augmentation"
[docs] @classmethod def use_multi_source(cls) -> bool: """Check if the operator supports multi-source datasets.""" return True
[docs] @classmethod def supports_prediction_mode(cls) -> bool: """Sample augmentation only runs during training.""" return False
[docs] def execute( self, step_info: 'ParsedStep', dataset: 'SpectroDataset', context: 'ExecutionContext', runtime_context: 'RuntimeContext', source: int = -1, mode: str = "train", loaded_binaries: Optional[Any] = None, prediction_store: Optional[Any] = None ) -> Tuple['ExecutionContext', List]: """ Execute sample augmentation with standard or balanced mode. Step format for standard mode: { "sample_augmentation": { "transformers": [transformer1, transformer2, ...], "count": int, "selection": "random" or "all", # Default "random" "random_state": int # Optional } } Step format for balanced mode (choose one balancing strategy): Mode 1 - Fixed target size per class: { "sample_augmentation": { "transformers": [...], "balance": "y" or "metadata_column", # Default "y" "target_size": int, # Fixed target samples per class "selection": "random" or "all", "random_state": int } } Mode 2 - Multiplier for augmentation: { "sample_augmentation": { "transformers": [...], "balance": "y" or "metadata_column", "max_factor": float, # Multiplier (e.g., 3 means class grows 3x) "selection": "random" or "all", "random_state": int } } Mode 3 - Percentage of majority class: { "sample_augmentation": { "transformers": [...], "balance": "y" or "metadata_column", "ref_percentage": float, # Target as % of majority (0.0-1.0) "selection": "random" or "all", "random_state": int } } Binning for regression (automatic when balance="y" and task is regression): { "sample_augmentation": { "transformers": [...], "balance": "y", "bins": int, # Number of virtual classes (default: 10) "binning_strategy": "equal_width" or "quantile", # Default: "equal_width" "max_factor": float, # Choose one balancing mode "selection": "random" or "all", "random_state": int } } """ # Extract step config for compatibility step = step_info.original_step config = step["sample_augmentation"] transformers_raw = config.get("transformers", []) if not transformers_raw: raise ValueError("sample_augmentation requires at least one transformer") # Deserialize transformers (they may be stored as serialized class paths) transformers = [deserialize_component(t) for t in transformers_raw] # Determine mode is_balanced = "balance" in config if is_balanced: return self._execute_balanced(config, transformers, dataset, context, runtime_context, loaded_binaries) else: return self._execute_standard(config, transformers, dataset, context, runtime_context, loaded_binaries)
def _execute_standard( self, config: Dict, transformers: List, dataset: 'SpectroDataset', context: 'ExecutionContext', runtime_context: 'RuntimeContext', loaded_binaries: Optional[Any] ) -> Tuple['ExecutionContext', List]: """Execute standard count-based augmentation.""" count = config.get("count", 1) selection = config.get("selection", "random") random_state = config.get("random_state", None) # Get train samples (base only, no augmented) train_context = context.with_partition("train") # Get base samples only (exclude augmented) base_samples_idx = dataset._indexer.x_indices(train_context.selector, include_augmented=False) # noqa: SLF001 base_samples = base_samples_idx.tolist() if hasattr(base_samples_idx, 'tolist') else list(base_samples_idx) if not base_samples: return context, [] # Create augmentation plan: sample_id → number of augmentations augmentation_counts = {sample_id: count for sample_id in base_samples} # Build transformer distribution: sample_id → list of transformer indices if selection == "random": transformer_map = BalancingCalculator.apply_random_transformer_selection( transformers, augmentation_counts, random_state ) else: # "all" transformer_map = self._cycle_transformers(transformers, augmentation_counts) # Invert map: transformer_idx → list of sample_ids transformer_to_samples = self._invert_transformer_map(transformer_map, len(transformers)) # Emit ONE run_step per transformer self._emit_augmentation_steps( transformer_to_samples, transformers, context, dataset, runtime_context, loaded_binaries ) return context, [] def _execute_balanced( self, config: Dict, transformers: List, dataset: 'SpectroDataset', context: 'ExecutionContext', runtime_context: 'RuntimeContext', loaded_binaries: Optional[Any] ) -> Tuple['ExecutionContext', List]: """Execute balanced class-aware augmentation.""" balance_source = config.get("balance", "y") target_size = config.get("target_size", None) max_factor = config.get("max_factor", None) ref_percentage = config.get("ref_percentage", None) if target_size is None and ref_percentage is None and max_factor is None: ref_percentage = 1.0 # Default to ref_percentage=1.0 if none specified selection = config.get("selection", "random") random_state = config.get("random_state", None) bin_balancing = config.get("bin_balancing", "sample") # "sample" or "value" # Get train samples ONLY (ensure we're in train partition) train_context = context.with_partition("train") # train_context.pop("train_indices", None) # Remove any existing indices # train_context.pop("test_indices", None) # Get ALL TRAIN samples (base + augmented) all_train_samples = dataset._indexer.x_indices(train_context.selector, include_augmented=True) # noqa: SLF001 # Get only BASE TRAIN samples (these have actual data to augment) base_train_samples = dataset._indexer.x_indices(train_context.selector, include_augmented=False) # noqa: SLF001 if len(base_train_samples) == 0: return context, [] # Get labels for ALL TRAIN samples (to calculate target size) if balance_source == "y": labels_all_train = dataset.y(train_context.selector, include_augmented=True) # Flatten if necessary labels_all_train = labels_all_train.flatten() if labels_all_train.ndim > 1 else labels_all_train # Store original values before binning (needed for value-aware balancing) original_values_all = labels_all_train.copy() # Apply binning for regression tasks if dataset.is_regression: bins = config.get("bins", 10) strategy = config.get("binning_strategy", "equal_width") labels_all_train, _ = BinningCalculator.bin_continuous_targets( labels_all_train, bins=bins, strategy=strategy ) else: # Metadata column - map augmented samples to origins using y_indices if not isinstance(balance_source, str): raise ValueError(f"balance source must be 'y' or a metadata column name, got {balance_source}") # Get origin indices for all train samples (including augmented mapped to origins) origin_indices = dataset._indexer.y_indices(train_context.selector, include_augmented=True) # noqa: SLF001 # Get base metadata and index into it using origin indices base_metadata = dataset._metadata.get_column(balance_source) # noqa: SLF001 labels_all_train = base_metadata[origin_indices] original_values_all = None # Get labels for BASE TRAIN samples only (for calculating augmentation per base sample) labels_base_train = labels_all_train[:len(base_train_samples)] if original_values_all is not None: original_values_base = original_values_all[:len(base_train_samples)] else: original_values_base = None # Calculate augmentation counts per BASE TRAIN sample using specified mode if bin_balancing == "value" and dataset.is_regression and original_values_base is not None: # Use value-aware balancing for regression with binning augmentation_counts = BalancingCalculator.calculate_balanced_counts_value_aware( labels_base_train, base_train_samples, original_values_base, labels_all_train, all_train_samples, target_size=target_size, max_factor=max_factor, ref_percentage=ref_percentage, random_state=random_state ) else: # Use standard sample-aware balancing augmentation_counts = BalancingCalculator.calculate_balanced_counts( labels_base_train, base_train_samples, labels_all_train, all_train_samples, target_size=target_size, max_factor=max_factor, ref_percentage=ref_percentage, random_state=random_state ) # --- Debug Print --- logger.debug("--- Sample Augmentation Class Distribution ---") logger.debug("Before Augmentation:") before_counts = Counter(labels_all_train) for label, count in sorted(before_counts.items()): logger.debug(f" Class {label}: {count}") logger.debug("Planned Augmentation:") sample_to_label = {sid: lbl for sid, lbl in zip(base_train_samples, labels_base_train)} added_counts = Counter() for sample_id, count in augmentation_counts.items(): if count > 0: lbl = sample_to_label.get(sample_id) if lbl is not None: added_counts[lbl] += count logger.debug("After Augmentation (Expected):") all_labels = set(before_counts.keys()) | set(added_counts.keys()) for label in sorted(all_labels): before = before_counts[label] added = added_counts[label] total = before + added logger.debug(f" Class {label}: {before} + {added} = {total}") logger.debug("----------------------------------------------") # ------------------- # Check if any augmentation is needed if sum(augmentation_counts.values()) == 0: # All classes already balanced, no augmentation needed return context, [] # Build transformer distribution if selection == "random": transformer_map = BalancingCalculator.apply_random_transformer_selection( transformers, augmentation_counts, random_state ) else: transformer_map = self._cycle_transformers(transformers, augmentation_counts) # Invert map: transformer_idx → list of sample_ids transformer_to_samples = self._invert_transformer_map(transformer_map, len(transformers)) # Emit ONE run_step per transformer- self._emit_augmentation_steps( transformer_to_samples, transformers, context, dataset, runtime_context, loaded_binaries ) return context, [] def _invert_transformer_map( self, transformer_map: Dict[int, List[int]], n_transformers: int ) -> Dict[int, List[int]]: """ Invert sample→transformer map to transformer→samples map. Args: transformer_map: {sample_id: [trans_idx1, trans_idx2, ...]} n_transformers: Total number of transformers Returns: {trans_idx: [sample_id1, sample_id2, ...]} """ inverted = {i: [] for i in range(n_transformers)} for sample_id, trans_indices in transformer_map.items(): for trans_idx in trans_indices: inverted[trans_idx].append(sample_id) return inverted def _emit_augmentation_steps( self, transformer_to_samples: Dict[int, List[int]], transformers: List, context: 'ExecutionContext', dataset: 'SpectroDataset', runtime_context: 'RuntimeContext', loaded_binaries: Optional[Any] ): """ Execute transformers and add augmented samples to dataset. This method supports two modes: 1. Parallel mode (when joblib available and n_jobs > 1): Execute transformers in parallel, collect all augmented data, then batch insert. Much faster for many transformers. 2. Sequential mode: Execute transformers one by one (fallback). TransformerMixinController will: 1. Detect augment_sample action 2. Transform all target samples in batch 3. Return augmented data OR add to dataset directly """ # Check if parallel execution is possible and beneficial active_transformers = [(idx, samples) for idx, samples in transformer_to_samples.items() if samples and len(samples) > 0] n_transformers = len(active_transformers) if n_transformers == 0: return # Use parallel execution if joblib available and multiple transformers use_parallel = JOBLIB_AVAILABLE and n_transformers > 1 if use_parallel: self._emit_augmentation_steps_parallel( active_transformers, transformers, context, dataset, runtime_context, loaded_binaries ) else: self._emit_augmentation_steps_sequential( active_transformers, transformers, context, dataset, runtime_context, loaded_binaries ) def _emit_augmentation_steps_sequential( self, active_transformers: List[Tuple[int, List[int]]], transformers: List, context: 'ExecutionContext', dataset: 'SpectroDataset', runtime_context: 'RuntimeContext', loaded_binaries: Optional[Any] ): """Sequential execution of transformers (original implementation).""" for trans_idx, sample_ids in active_transformers: transformer = transformers[trans_idx] # Create context for this transformer's augmentation local_context = context.with_metadata( augment_sample=True, target_samples=sample_ids ).with_partition("train") # ONE run_step per transformer - it handles all target samples if runtime_context.step_runner: runtime_context.substep_number += 1 _ = runtime_context.step_runner.execute( transformer, dataset, local_context, runtime_context, loaded_binaries=loaded_binaries, prediction_store=None ) def _emit_augmentation_steps_parallel( self, active_transformers: List[Tuple[int, List[int]]], transformers: List, context: 'ExecutionContext', dataset: 'SpectroDataset', runtime_context: 'RuntimeContext', loaded_binaries: Optional[Any] ): """ Parallel execution of transformers using joblib. Flow: 1. Fetch train data once (for fitting) and all origin data (for transform) 2. Execute all transformers in parallel, each returning augmented data 3. Collect all results and batch insert into dataset """ from sklearn.base import clone from concurrent.futures import ThreadPoolExecutor, as_completed # Get train data for fitting (once for all transformers) train_context = context.with_partition("train") train_selector = train_context.selector.with_augmented(False) train_data = dataset.x(train_selector, "3d", concat_source=False) if not isinstance(train_data, list): train_data = [train_data] n_sources = len(train_data) n_processings = train_data[0].shape[1] if n_sources > 0 else 0 # Collect all unique sample IDs across all transformers all_sample_ids = set() for _, sample_ids in active_transformers: all_sample_ids.update(sample_ids) all_sample_ids_list = sorted(all_sample_ids) # Batch fetch all origin samples once batch_selector = {"sample": all_sample_ids_list} all_origin_data = dataset.x(batch_selector, "3d", concat_source=False, include_augmented=False) if not isinstance(all_origin_data, list): all_origin_data = [all_origin_data] # Create sample_id to index mapping for efficient lookup sample_id_to_idx = {sid: idx for idx, sid in enumerate(all_sample_ids_list)} # Pre-fit all transformer × source × processing combinations # This can be done in parallel too, but keep it simple for now all_fitted = {} # (trans_idx, source_idx, proc_idx) -> fitted transformer for trans_idx, _ in active_transformers: transformer = transformers[trans_idx] # Check if transformer is actual object or string reference if isinstance(transformer, str): raise ValueError(f"Transformer at index {trans_idx} is a string '{transformer}' instead of an object. " "Ensure transformers are instantiated before passing to sample_augmentation.") for source_idx in range(n_sources): for proc_idx in range(n_processings): cloned = clone(transformer) train_proc = train_data[source_idx][:, proc_idx, :] cloned.fit(train_proc) all_fitted[(trans_idx, source_idx, proc_idx)] = cloned def process_transformer(args): """Process a single transformer and return augmented data + index info.""" trans_idx, sample_ids = args transformer = transformers[trans_idx] operator_name = transformer.__class__.__name__ # Get indices for this transformer's samples local_indices = [sample_id_to_idx[sid] for sid in sample_ids] # Transform all samples for this transformer transformed_per_source = [] for source_idx in range(n_sources): source_origin = all_origin_data[source_idx] # (all_samples, procs, feats) local_source_data = source_origin[local_indices] # (n_local, procs, feats) transformed_procs = [] for proc_idx in range(n_processings): proc_data = local_source_data[:, proc_idx, :] # (n_local, feats) fitted = all_fitted[(trans_idx, source_idx, proc_idx)] transformed = fitted.transform(proc_data) # (n_local, feats) transformed_procs.append(transformed) # Stack processings: (n_local, n_processings, n_features) source_3d = np.stack(transformed_procs, axis=1) transformed_per_source.append(source_3d) # Prepare output data # For multi-source, return list of arrays (one per source) # For single source, return single array if n_sources == 1: batch_data = transformed_per_source[0] # (n_local, n_procs, n_feats) else: # For multi-source, return list of arrays batch_data = transformed_per_source # List of (n_local, n_procs, n_feats) # Build index dictionaries indexes_list = [ {"partition": "train", "origin": sid, "augmentation": operator_name} for sid in sample_ids ] return batch_data, indexes_list # Execute in parallel using ThreadPoolExecutor (no pickling issues) all_batch_data = [] all_indexes = [] max_workers = min(len(active_transformers), 16) # Cap at 16 threads with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = {executor.submit(process_transformer, args): args for args in active_transformers} for future in as_completed(futures): batch_data, indexes_list = future.result() all_batch_data.append(batch_data) all_indexes.extend(indexes_list) if not all_batch_data: return # Concatenate all augmented data # Handle both single-source (arrays) and multi-source (list of arrays) if n_sources == 1: # Single source: all_batch_data is list of arrays combined_data = np.concatenate(all_batch_data, axis=0) else: # Multi-source: all_batch_data is list of lists of arrays # Need to concatenate per-source, then return as list combined_data = [] for source_idx in range(n_sources): source_arrays = [batch[source_idx] for batch in all_batch_data] combined_source = np.concatenate(source_arrays, axis=0) combined_data.append(combined_source) # Single batch insert for ALL augmented samples from ALL transformers dataset.add_samples_batch(data=combined_data, indexes_list=all_indexes) def _cycle_transformers( self, transformers: List, augmentation_counts: Dict[int, int] ) -> Dict[int, List[int]]: """Cycle through transformers for 'all' selection mode.""" transformer_map = {} for sample_id, count in augmentation_counts.items(): if count > 0: transformer_map[sample_id] = [i % len(transformers) for i in range(count)] else: transformer_map[sample_id] = [] return transformer_map