Source code for nirs4all.controllers.data.auto_transfer_preproc

"""
Auto Transfer Preprocessing Controller.

This module provides the AutoTransferPreprocessingController which automatically
selects optimal preprocessing for transfer learning scenarios. It uses the
TransferPreprocessingSelector to analyze source and target data and select
preprocessing that minimizes distributional distance while preserving signal.

Usage in pipeline:
    # Standalone operator
    pipeline = [
        {"auto_transfer_preproc": {"preset": "balanced"}},
        "PLSRegressor",
    ]

    # With explicit configuration
    pipeline = [
        {
            "auto_transfer_preproc": {
                "preset": "thorough",
                "source_partition": "train",
                "target_partition": "test",
                "apply_recommendation": True,
            }
        },
        {"model": "PLSRegressor"},
    ]
"""

from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
import numpy as np
from copy import deepcopy

from nirs4all.controllers.controller import OperatorController
from nirs4all.controllers.registry import register_controller
from nirs4all.core.logging import get_logger

logger = get_logger(__name__)

if TYPE_CHECKING:
    from nirs4all.pipeline.config.context import ExecutionContext, RuntimeContext
    from nirs4all.data.dataset import SpectroDataset
    from nirs4all.pipeline.steps.parser import ParsedStep


@register_controller
class AutoTransferPreprocessingController(OperatorController):
    """
    Controller for automatic transfer-optimized preprocessing selection.

    This controller analyzes the distributional distance between source and
    target datasets and automatically selects preprocessing that best aligns
    them while preserving predictive information.

    Configuration options:
        preset: Preset configuration for the selector.
            - "fast" (default): Quick evaluation of single preprocessings only
            - "balanced": Includes stacking evaluation
            - "thorough": Includes stacking and augmentation
            - "full": All stages including supervised validation
            - "exhaustive": Deep analysis for research/benchmarking

        source_partition: Partition to use as source data ("train" or "test").
            Default is "train".

        target_partition: Partition to use as target data ("train" or "test").
            Default is "test".

        apply_recommendation: Whether to apply the best preprocessing to the
            dataset. If False, only stores the recommendation in context.
            Default is True.

        top_k: Number of top recommendations to apply if using augmentation.
            Default is 1 (best single preprocessing).

        use_augmentation: If top_k > 1, whether to use feature augmentation
            to concatenate outputs. Default is False.

        n_components: Number of PCA components for metric computation.
            Default is 10.

        verbose: Verbosity level (0=silent, 1=progress, 2=detailed).
            Default is 1.

        # Stage-specific options (override preset; None keeps the preset value)
        run_stage2: Enable stacking evaluation.
        stage2_top_k: Number of top candidates for stacking.
        stage2_max_depth: Forwarded to the selector when set.
        run_stage3: Enable augmentation evaluation.
        stage3_top_k: Forwarded to the selector when set.
        run_stage4: Enable supervised validation.
        stage4_top_k: Forwarded to the selector when set.

        # Other pass-through options (None keeps the selector defaults)
        metric_weights: Forwarded to the selector when set.
        preprocessing_spec: Forwarded to the selector when set.

    Example pipeline configurations:
        # Simple - use defaults
        {"auto_transfer_preproc": {}}

        # With preset
        {"auto_transfer_preproc": {"preset": "balanced"}}

        # Full configuration
        {
            "auto_transfer_preproc": {
                "preset": "thorough",
                "source_partition": "train",
                "target_partition": "test",
                "apply_recommendation": True,
                "top_k": 1,
                "verbose": 2,
            }
        }

        # Multi-source with augmentation
        {
            "auto_transfer_preproc": {
                "preset": "balanced",
                "top_k": 3,
                "use_augmentation": True,
            }
        }
    """

    priority = 9  # Higher priority than feature_augmentation (10)

    # Artifact name used both when persisting and when reloading the recommendation.
    _RECOMMENDATION_KEY = "transfer_preprocessing_recommendation"

    @classmethod
    def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
        """Check if step is an auto_transfer_preproc operation."""
        return keyword == "auto_transfer_preproc"

    @classmethod
    def use_multi_source(cls) -> bool:
        """Supports multi-source datasets."""
        return True

    @classmethod
    def supports_prediction_mode(cls) -> bool:
        """
        Supports prediction mode for applying saved recommendations.

        In prediction mode, the controller loads the previously computed
        preprocessing recommendation and applies it to the new data.
        """
        return True

    def execute(
        self,
        step_info: "ParsedStep",
        dataset: "SpectroDataset",
        context: "ExecutionContext",
        runtime_context: "RuntimeContext",
        source: int = -1,
        mode: str = "train",
        loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
        prediction_store: Optional[Any] = None,
    ) -> Tuple["ExecutionContext", List[Tuple[str, Any]]]:
        """
        Execute auto transfer preprocessing selection.

        In train mode:
            1. Extract source and target data from the dataset
            2. Run TransferPreprocessingSelector to find best preprocessing
            3. Apply the recommended preprocessing if configured
            4. Store the recommendation as an artifact

        In predict mode:
            1. Load the saved preprocessing recommendation
            2. Apply it to the incoming data

        Args:
            step_info: Parsed step containing the auto_transfer_preproc config
            dataset: SpectroDataset to operate on
            context: Execution context with selector and metadata
            runtime_context: Runtime infrastructure (saver, step_number, etc.)
            source: Source index (-1 for all sources)
            mode: Execution mode ("train", "predict", "explain")
            loaded_binaries: Pre-loaded artifacts for predict/explain mode
            prediction_store: Not used by this controller

        Returns:
            Tuple of (updated_context, list_of_artifacts)
        """
        config = self._parse_config(
            step_info.original_step.get("auto_transfer_preproc", {})
        )

        if mode in ("predict", "explain"):
            return self._execute_predict_mode(
                config, dataset, context, runtime_context, source, loaded_binaries
            )

        # Train mode: run transfer selection
        return self._execute_train_mode(
            config, dataset, context, runtime_context, source
        )

    def _parse_config(self, config: Any) -> Dict[str, Any]:
        """
        Parse and normalize the auto_transfer_preproc configuration.

        Args:
            config: Configuration from the pipeline step (dict, None, or empty).
                Anything that is not a dict is treated as an empty configuration.

        Returns:
            Normalized configuration dictionary with defaults.
        """
        if not isinstance(config, dict):
            config = {}

        defaults = {
            "preset": "fast",
            "source_partition": "train",
            "target_partition": "test",
            "apply_recommendation": True,
            "top_k": 1,
            "use_augmentation": False,
            "n_components": 10,
            "verbose": 1,
            # Stage overrides (None means use preset defaults)
            "run_stage2": None,
            "stage2_top_k": None,
            "stage2_max_depth": None,
            "run_stage3": None,
            "stage3_top_k": None,
            "run_stage4": None,
            "stage4_top_k": None,
            # Metric weights (None means use defaults)
            "metric_weights": None,
            # Generator spec (optional)
            "preprocessing_spec": None,
        }

        # User-supplied keys win over defaults; unknown keys are kept as-is.
        return {**defaults, **config}

    def _execute_train_mode(
        self,
        config: Dict[str, Any],
        dataset: "SpectroDataset",
        context: "ExecutionContext",
        runtime_context: "RuntimeContext",
        source: int = -1,
    ) -> Tuple["ExecutionContext", List[Tuple[str, Any]]]:
        """
        Execute in train mode: run selection and apply recommendation.

        Args:
            config: Parsed configuration.
            dataset: SpectroDataset to operate on.
            context: Execution context.
            runtime_context: Runtime infrastructure.
            source: Source index.

        Returns:
            Tuple of (updated_context, artifacts).

        Raises:
            ValueError: If the source or target partition yields no data.
        """
        from nirs4all.analysis import TransferPreprocessingSelector

        verbose = config["verbose"]
        artifacts = []

        # Extract source and target data
        X_source, y_source = self._extract_partition_data(
            dataset, context, config["source_partition"], source
        )
        X_target, y_target = self._extract_partition_data(
            dataset, context, config["target_partition"], source
        )

        # Fail early with an actionable message instead of an obscure error
        # from inside the selector when one side has no data.
        for role, partition_key, X in (
            ("source", "source_partition", X_source),
            ("target", "target_partition", X_target),
        ):
            if X.size == 0:
                raise ValueError(
                    f"auto_transfer_preproc: no {role} data found in "
                    f"'{config[partition_key]}' partition."
                )

        if verbose >= 1:
            logger.info("Auto Transfer Preprocessing Selection")
            logger.info(f" Source: {X_source.shape[0]} samples from '{config['source_partition']}' partition")
            logger.info(f" Target: {X_target.shape[0]} samples from '{config['target_partition']}' partition")

        # Build selector kwargs from config
        selector_kwargs = self._build_selector_kwargs(config)

        # Run transfer preprocessing selection
        selector = TransferPreprocessingSelector(**selector_kwargs)
        results = selector.fit(X_source, X_target, y_source, y_target)

        # Get recommendation
        top_k = config["top_k"]
        use_augmentation = config["use_augmentation"]
        pipeline_spec = results.to_pipeline_spec(
            top_k=top_k,
            use_augmentation=use_augmentation,
        )

        best = results.best
        if verbose >= 1:
            logger.success(f"Best recommendation: {best.name}")
            logger.info(f" Transfer score: {best.transfer_score:.4f}")
            logger.info(f" Improvement: {best.improvement_pct:.1f}%")
            if top_k > 1:
                logger.info(f" Pipeline spec (top {top_k}): {pipeline_spec}")

        # Store recommendation as artifact
        recommendation_data = {
            "pipeline_spec": pipeline_spec,
            "best_name": best.name,
            "transfer_score": best.transfer_score,
            "improvement_pct": best.improvement_pct,
            "top_k": top_k,
            "use_augmentation": use_augmentation,
            "ranking": [
                r.to_dict() for r in results.top_k(min(5, len(results.ranking)))
            ],
        }

        if runtime_context.saver is not None:
            artifact = runtime_context.saver.persist_artifact(
                step_number=runtime_context.step_number,
                name=self._RECOMMENDATION_KEY,
                obj=recommendation_data,
                format_hint="json",
                branch_id=context.selector.branch_id,
                branch_name=context.selector.branch_name,
            )
            artifacts.append(artifact)

        # Store full results in context metadata for later use
        context = context.with_metadata(
            transfer_preprocessing_results=results,
            transfer_preprocessing_recommendation=recommendation_data,
        )

        # Apply recommendation if configured
        if config["apply_recommendation"]:
            context, apply_artifacts = self._apply_recommendation(
                pipeline_spec, dataset, context, runtime_context, source
            )
            artifacts.extend(apply_artifacts)

        return context, artifacts

    def _load_step_artifacts(
        self,
        context: "ExecutionContext",
        runtime_context: "RuntimeContext",
        loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """
        Collect the artifacts saved for this step, keyed by artifact name.

        ``loaded_binaries`` (legacy path) is used as the base; entries returned
        by ``runtime_context.artifact_provider`` take precedence when present.

        Args:
            context: Execution context (provides the branch path).
            runtime_context: Runtime infrastructure (provides the provider).
            loaded_binaries: Optional pre-loaded ``(name, object)`` pairs.

        Returns:
            Mapping of artifact name to loaded object (possibly empty).
        """
        saved: Dict[str, Any] = dict(loaded_binaries) if loaded_binaries else {}

        provider = runtime_context.artifact_provider
        if provider is not None:
            step_artifacts = provider.get_artifacts_for_step(
                runtime_context.step_number,
                branch_path=context.selector.branch_path,
            )
            if step_artifacts:
                saved.update(dict(step_artifacts))

        return saved

    def _execute_predict_mode(
        self,
        config: Dict[str, Any],
        dataset: "SpectroDataset",
        context: "ExecutionContext",
        runtime_context: "RuntimeContext",
        source: int = -1,
        loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
    ) -> Tuple["ExecutionContext", List[Tuple[str, Any]]]:
        """
        Execute in predict mode: load and apply saved recommendation.

        Args:
            config: Parsed configuration.
            dataset: SpectroDataset to operate on.
            context: Execution context.
            runtime_context: Runtime infrastructure.
            source: Source index.
            loaded_binaries: Pre-loaded artifacts (deprecated, use
                artifact_provider). Used as a fallback when the provider is
                absent or does not return the recommendation.

        Returns:
            Tuple of (updated_context, artifacts).

        Raises:
            ValueError: If no saved recommendation can be found.
        """
        verbose = config["verbose"]

        # V3: artifact_provider first, loaded_binaries as fallback
        saved = self._load_step_artifacts(context, runtime_context, loaded_binaries)
        recommendation_data = saved.get(self._RECOMMENDATION_KEY)

        if recommendation_data is None:
            raise ValueError(
                "transfer_preprocessing_recommendation not found. "
                "Ensure the model was trained with auto_transfer_preproc."
            )

        pipeline_spec = recommendation_data["pipeline_spec"]

        if verbose >= 1:
            logger.info("Loading saved transfer preprocessing recommendation")
            logger.info(f" Best: {recommendation_data['best_name']}")
            logger.info(f" Pipeline spec: {pipeline_spec}")

        # Apply recommendation
        if not config["apply_recommendation"]:
            return context, []

        # Hand over the saved artifacts so transformers fitted during training
        # are reused instead of being re-fitted on the prediction data.
        context, artifacts = self._apply_recommendation(
            pipeline_spec,
            dataset,
            context,
            runtime_context,
            source,
            mode="predict",
            fitted_transforms=saved,
        )
        return context, artifacts

    def _build_selector_kwargs(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build kwargs for TransferPreprocessingSelector from config.

        Args:
            config: Parsed configuration.

        Returns:
            Dictionary of kwargs for the selector.
        """
        kwargs = {
            "preset": config["preset"],
            "n_components": config["n_components"],
            "verbose": config["verbose"],
        }

        # Optional settings are forwarded only when explicitly set, so that the
        # preset / selector defaults stay in charge otherwise.
        optional_params = (
            # Stage overrides
            "run_stage2",
            "stage2_top_k",
            "stage2_max_depth",
            "run_stage3",
            "stage3_top_k",
            "run_stage4",
            "stage4_top_k",
            # Metric weights
            "metric_weights",
            # Generator spec
            "preprocessing_spec",
        )
        for param in optional_params:
            value = config.get(param)
            if value is not None:
                kwargs[param] = value

        return kwargs

    def _extract_partition_data(
        self,
        dataset: "SpectroDataset",
        context: "ExecutionContext",
        partition: str,
        source: int = -1,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Extract X and y data from a specific partition.

        Args:
            dataset: SpectroDataset to extract from.
            context: Execution context.
            partition: Partition name ("train" or "test").
            source: Source index (-1 for first/all sources).

        Returns:
            Tuple of (X, y) where y may be None.
        """
        # Create partition-specific selector
        selector = context.with_partition(partition).selector

        # Get X data (2D array for the specified source)
        X = dataset.x(selector, layout="2d", concat_source=True)

        if isinstance(X, list) and len(X) > 0:
            # If multiple sources, use first or specified source; an
            # out-of-range index falls back to the first source.
            src_idx = 0 if source < 0 else source
            X = X[src_idx] if src_idx < len(X) else X[0]

        # Ensure 2D
        X = np.atleast_2d(X)
        if X.ndim == 3:
            # Flatten processings if 3D:
            # (samples, processings, features) -> (samples, processings*features)
            X = X.reshape(X.shape[0], -1)

        # Get y data
        y = dataset.y(selector)
        if y is not None:
            y = np.asarray(y)
            if y.ndim > 1:
                y = y.flatten()

        return X, y

    def _apply_recommendation(
        self,
        pipeline_spec: Any,
        dataset: "SpectroDataset",
        context: "ExecutionContext",
        runtime_context: "RuntimeContext",
        source: int = -1,
        mode: str = "train",
        fitted_transforms: Optional[Dict[str, Any]] = None,
    ) -> Tuple["ExecutionContext", List[Tuple[str, Any]]]:
        """
        Apply the recommended preprocessing to the dataset.

        Dispatches on the recommendation type:
            - Single preprocessing (string): applied as a (possibly stacked) chain
            - List of preprocessings: each applied sequentially
            - Dict with "feature_augmentation": each entry added as a new processing

        Any other non-None spec is ignored with a warning.

        Args:
            pipeline_spec: The preprocessing specification to apply.
            dataset: SpectroDataset to operate on.
            context: Execution context.
            runtime_context: Runtime infrastructure.
            source: Source index.
            mode: Execution mode.
            fitted_transforms: Optional mapping of artifact name to an already
                fitted transformer; matching entries are reused instead of
                fitting a fresh copy.

        Returns:
            Tuple of (updated_context, artifacts).
        """
        from nirs4all.analysis import get_base_preprocessings

        artifacts = []
        verbose = getattr(context.metadata, "verbose", 1)

        # Get preprocessing transforms
        preprocessings = get_base_preprocessings()

        if isinstance(pipeline_spec, str):
            # Single preprocessing or stacked (e.g., "snv" or "snv>d1")
            if verbose >= 1:
                logger.info(f" Applying preprocessing: {pipeline_spec}")
            context, apply_artifacts = self._apply_stacked_preprocessing(
                pipeline_spec, preprocessings, dataset, context,
                runtime_context, source, mode,
                fitted_transforms=fitted_transforms,
            )
            artifacts.extend(apply_artifacts)

        elif isinstance(pipeline_spec, list):
            # List of preprocessings to apply sequentially
            for pp_name in pipeline_spec:
                if verbose >= 2:
                    logger.debug(f" Applying preprocessing: {pp_name}")
                context, apply_artifacts = self._apply_stacked_preprocessing(
                    pp_name, preprocessings, dataset, context,
                    runtime_context, source, mode,
                    fitted_transforms=fitted_transforms,
                )
                artifacts.extend(apply_artifacts)

        elif isinstance(pipeline_spec, dict) and "feature_augmentation" in pipeline_spec:
            # Feature augmentation - concatenate several preprocessing outputs
            pp_list = pipeline_spec["feature_augmentation"]
            if verbose >= 1:
                logger.info(f" Applying feature augmentation: {pp_list}")
            context, apply_artifacts = self._apply_feature_augmentation(
                pp_list, preprocessings, dataset, context,
                runtime_context, source, mode,
                fitted_transforms=fitted_transforms,
            )
            artifacts.extend(apply_artifacts)

        elif pipeline_spec is not None:
            # Previously ignored silently, which hid malformed recommendations.
            logger.warning(
                f" Unsupported pipeline spec, nothing applied: {pipeline_spec!r}"
            )

        return context, artifacts

    def _run_transform_chain(
        self,
        components: List[str],
        preprocessings: Dict[str, Any],
        fit_data: np.ndarray,
        all_data: np.ndarray,
        key_prefix: str,
        context: "ExecutionContext",
        runtime_context: "RuntimeContext",
        mode: str,
        fitted_transforms: Optional[Dict[str, Any]] = None,
    ) -> Tuple[np.ndarray, List[Tuple[str, Any]]]:
        """
        Run a chain of named transforms over ``all_data``.

        Each component is fitted on the (progressively transformed) ``fit_data``
        unless an already fitted transformer is available in
        ``fitted_transforms`` under the component's artifact name
        ``f"{key_prefix}_{index}_{name}"``. In train mode each fitted
        transformer is persisted under that same name.

        Args:
            components: Ordered preprocessing names.
            preprocessings: Dictionary of available transforms.
            fit_data: Data used to fit each component.
            all_data: Data to transform and return.
            key_prefix: Prefix of the artifact name for each component.
            context: Execution context (branch info for persistence).
            runtime_context: Runtime infrastructure (saver, step number).
            mode: Execution mode; artifacts are saved only in "train".
            fitted_transforms: Optional artifact-name -> fitted transformer map.

        Returns:
            Tuple of (transformed all_data, persisted artifacts).

        Raises:
            ValueError: If a component name is not in ``preprocessings``.
        """
        artifacts = []
        current_fit = fit_data
        current_all = all_data
        should_persist = mode == "train" and runtime_context.saver is not None

        for comp_idx, comp_name in enumerate(components):
            if comp_name not in preprocessings:
                raise ValueError(f"Unknown preprocessing: {comp_name}")

            binary_key = f"{key_prefix}_{comp_idx}_{comp_name}"

            transform = (
                fitted_transforms.get(binary_key) if fitted_transforms else None
            )
            if transform is None:
                # deepcopy keeps the shared registry instance unfitted.
                transform = deepcopy(preprocessings[comp_name])
                transform.fit(current_fit)

            # The fit-side data only feeds the fitting of later components;
            # with a reused transformer it may legitimately have no rows.
            if len(current_fit) > 0:
                current_fit = transform.transform(current_fit)
            current_all = transform.transform(current_all)

            # Save artifact in train mode
            if should_persist:
                artifact = runtime_context.saver.persist_artifact(
                    step_number=runtime_context.step_number,
                    name=binary_key,
                    obj=transform,
                    format_hint="sklearn",
                    branch_id=context.selector.branch_id,
                    branch_name=context.selector.branch_name,
                )
                artifacts.append(artifact)

        return current_all, artifacts

    def _refresh_processing(
        self,
        dataset: "SpectroDataset",
        context: "ExecutionContext",
    ) -> "ExecutionContext":
        """Return a context whose processing names mirror the dataset's current ones."""
        new_processing = [
            list(dataset.features_processings(sd_idx))
            for sd_idx in range(dataset.features_sources())
        ]
        return context.with_processing(new_processing)

    def _apply_stacked_preprocessing(
        self,
        pp_name: str,
        preprocessings: Dict[str, Any],
        dataset: "SpectroDataset",
        context: "ExecutionContext",
        runtime_context: "RuntimeContext",
        source: int = -1,
        mode: str = "train",
        fitted_transforms: Optional[Dict[str, Any]] = None,
    ) -> Tuple["ExecutionContext", List[Tuple[str, Any]]]:
        """
        Apply a stacked preprocessing (e.g., "snv>d1") to the dataset.

        Every existing processing of each selected source is replaced by its
        transformed version, named ``"<old name>_<pp_name with '>' as '_'>"``.

        Args:
            pp_name: Preprocessing name (may include ">" for stacking).
            preprocessings: Dictionary of available transforms.
            dataset: SpectroDataset to operate on.
            context: Execution context.
            runtime_context: Runtime infrastructure.
            source: Source index.
            mode: Execution mode.
            fitted_transforms: Optional artifact-name -> fitted transformer map;
                matching transformers are reused instead of being re-fitted.

        Returns:
            Tuple of (updated_context, artifacts).

        Raises:
            ValueError: If a component of ``pp_name`` is unknown.
        """
        artifacts = []
        components = pp_name.split(">")
        name_suffix = pp_name.replace(">", "_")

        # Get all source indices to process
        n_sources = dataset.features_sources()
        source_indices = [source] if source >= 0 else list(range(n_sources))

        train_selector = context.with_partition("train").selector

        for sd_idx in source_indices:
            # Get current processing names
            processing_ids = list(dataset.features_processings(sd_idx))

            # Get data for this source
            train_data = dataset.x(train_selector, "3d", concat_source=False)
            all_data = dataset.x(context.selector, "3d", concat_source=False)

            if isinstance(train_data, list):
                train_data = train_data[sd_idx]
            if isinstance(all_data, list):
                all_data = all_data[sd_idx]

            # Process each current processing
            for proc_idx, proc_name in enumerate(processing_ids):
                # 2D slices for this processing feed the transform chain
                transformed, chain_artifacts = self._run_transform_chain(
                    components,
                    preprocessings,
                    train_data[:, proc_idx, :],
                    all_data[:, proc_idx, :],
                    f"transfer_pp_{sd_idx}_{proc_name}",
                    context,
                    runtime_context,
                    mode,
                    fitted_transforms,
                )
                artifacts.extend(chain_artifacts)

                # Update dataset with transformed features
                dataset.replace_features(
                    source_processings=[proc_name],
                    features=[transformed],
                    processings=[f"{proc_name}_{name_suffix}"],
                    source=sd_idx,
                )

        # Update context with new processing names
        context = self._refresh_processing(dataset, context)

        return context, artifacts

    def _apply_feature_augmentation(
        self,
        pp_list: List[str],
        preprocessings: Dict[str, Any],
        dataset: "SpectroDataset",
        context: "ExecutionContext",
        runtime_context: "RuntimeContext",
        source: int = -1,
        mode: str = "train",
        fitted_transforms: Optional[Dict[str, Any]] = None,
    ) -> Tuple["ExecutionContext", List[Tuple[str, Any]]]:
        """
        Apply feature augmentation (add multiple preprocessing outputs).

        Each preprocessing in ``pp_list`` is applied to the first processing of
        each selected source and the result is added to the dataset as a new
        processing named after the preprocessing (">" replaced by "_").

        Args:
            pp_list: List of preprocessing names to apply and add.
            preprocessings: Dictionary of available transforms.
            dataset: SpectroDataset to operate on.
            context: Execution context.
            runtime_context: Runtime infrastructure.
            source: Source index.
            mode: Execution mode.
            fitted_transforms: Optional artifact-name -> fitted transformer map;
                matching transformers are reused instead of being re-fitted.

        Returns:
            Tuple of (updated_context, artifacts).

        Raises:
            ValueError: If a component of any preprocessing is unknown.
        """
        artifacts = []

        # Set add_feature mode in context
        context = context.with_metadata(add_feature=True)

        # Get all source indices to process
        n_sources = dataset.features_sources()
        source_indices = [source] if source >= 0 else list(range(n_sources))

        train_selector = context.with_partition("train").selector

        for sd_idx in source_indices:
            # Get data for this source
            train_data = dataset.x(train_selector, "3d", concat_source=False)
            all_data = dataset.x(context.selector, "3d", concat_source=False)

            if isinstance(train_data, list):
                train_data = train_data[sd_idx]
            if isinstance(all_data, list):
                all_data = all_data[sd_idx]

            # Get base 2D data (first processing)
            base_train_2d = train_data[:, 0, :]
            base_all_2d = all_data[:, 0, :]

            # Apply each preprocessing and add as new feature processing
            for pp_name in pp_list:
                transformed, chain_artifacts = self._run_transform_chain(
                    pp_name.split(">"),
                    preprocessings,
                    base_train_2d,
                    base_all_2d,
                    f"transfer_aug_{sd_idx}_{pp_name}",
                    context,
                    runtime_context,
                    mode,
                    fitted_transforms,
                )
                artifacts.extend(chain_artifacts)

                # Add as new processing
                dataset.update_features(
                    source_processings=[""],  # Empty string means add new
                    features=[transformed],
                    processings=[pp_name.replace(">", "_")],
                    source=sd_idx,
                )

        # Update context with new processing names
        context = self._refresh_processing(dataset, context)

        return context, artifacts