Source code for nirs4all.controllers.data.resampler

"""
Controller for wavelength resampling operations.

This controller handles the Resampler operator, extracting wavelengths from
dataset headers and managing the resampling process across multiple sources.
"""

from typing import Any, List, Tuple, Optional, TYPE_CHECKING
import numpy as np

from nirs4all.controllers.controller import OperatorController
from nirs4all.controllers.registry import register_controller
from nirs4all.core.logging import get_logger
from nirs4all.operators.transforms.resampler import Resampler

logger = get_logger(__name__)

if TYPE_CHECKING:
    from nirs4all.pipeline.runner import PipelineRunner
    from nirs4all.data.dataset import SpectroDataset
    from nirs4all.pipeline.config.context import ExecutionContext
    from nirs4all.pipeline.steps.parser import ParsedStep


@register_controller
class ResamplerController(OperatorController):
    """
    Controller for Resampler operators.

    This controller:
    1. Extracts wavelengths from dataset headers
    2. Validates that headers are convertible to float (wavelengths in cm-1)
    3. Fits the resampler with original wavelengths
    4. Transforms all data to the target wavelength grid
    5. Updates dataset with new features and headers
    6. Supports multi-source datasets with per-source or shared parameters
    """

    priority = 5  # Higher priority than TransformerMixin (10) to match Resampler first

    @classmethod
    def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
        """
        Match Resampler objects.

        The candidate object is ``step['model']`` when ``step`` is a dict
        holding that key, otherwise ``operator`` when it is not ``None``,
        otherwise ``step`` itself.

        Args:
            step: Raw pipeline step (dict or object)
            operator: Operator already resolved for the step, if any
            keyword: Step keyword (not used for matching)

        Returns:
            True when the candidate is a ``Resampler`` instance or its class
            is named ``'Resampler'``.
        """
        if isinstance(step, dict) and 'model' in step:
            model_obj = step['model']
        elif operator is not None:
            model_obj = operator
        else:
            model_obj = step

        # The name comparison additionally accepts objects whose class is
        # called "Resampler" without being the class imported in this module.
        return isinstance(model_obj, Resampler) or model_obj.__class__.__name__ == 'Resampler'

    @classmethod
    def use_multi_source(cls) -> bool:
        """Resampler supports multi-source datasets."""
        return True

    @classmethod
    def supports_prediction_mode(cls) -> bool:
        """Resampler supports prediction mode."""
        return True

    def _extract_wavelengths(self, dataset: 'SpectroDataset', source_idx: int) -> np.ndarray:
        """
        Extract and validate wavelengths from dataset headers.

        Args:
            dataset: The spectroscopic dataset
            source_idx: Index of the data source

        Returns:
            Array of wavelengths in cm-1 units, as returned by
            ``dataset.wavelengths_cm1``

        Raises:
            ValueError: If the header unit is "text", "none" or "index", or
                if the dataset fails to convert the headers to wavelengths
        """
        header_unit = dataset.header_unit(source_idx)

        # Resampler requires actual wavelength data (not text, indices, or none)
        if header_unit in ["text", "none", "index"]:
            headers = dataset.headers(source_idx)
            raise ValueError(
                f"Cannot resample data with header_unit='{header_unit}' for source {source_idx}. "
                f"Resampler requires numeric wavelength headers (cm-1 or nm). "
                f"Got headers: {headers[:5]}..."
            )

        # Use the dataset's wavelength conversion methods
        try:
            return dataset.wavelengths_cm1(source_idx)
        except (ValueError, TypeError) as e:
            # Re-raise with the offending headers to make the failure actionable
            headers = dataset.headers(source_idx)
            raise ValueError(
                f"Failed to extract wavelengths from headers for source {source_idx}. "
                f"Header unit: {header_unit}. Headers: {headers[:5]}... "
                f"Error: {str(e)}"
            ) from e

    def _get_target_wavelengths_for_source(
        self,
        operator: Resampler,
        source_idx: int,
        n_sources: int
    ) -> np.ndarray:
        """
        Get target wavelengths for a specific source.

        If target_wavelengths is a list whose elements are themselves
        sequences/arrays, it is interpreted as per-source targets. Anything
        else (array, tuple, or a flat list of numbers) is used as the same
        target grid for all sources.

        Args:
            operator: The Resampler instance
            source_idx: Current source index
            n_sources: Total number of sources

        Returns:
            Target wavelengths for this source

        Raises:
            ValueError: If per-source targets are given but their count
                differs from ``n_sources``
        """
        target_wl = operator.target_wavelengths

        # A flat list of numbers such as [1000.0, 1100.0] is a single shared
        # grid, not one target per source; only nested containers are
        # per-source. An empty list keeps raising the length-mismatch error.
        is_per_source = isinstance(target_wl, list) and all(
            isinstance(item, (list, tuple, np.ndarray)) for item in target_wl
        )
        if not is_per_source:
            # Same targets for all sources
            return np.asarray(target_wl)

        if len(target_wl) != n_sources:
            raise ValueError(
                f"If target_wavelengths is a list, it must have {n_sources} elements "
                f"(one per source), but got {len(target_wl)} elements"
            )
        return np.asarray(target_wl[source_idx])

    @staticmethod
    def _load_fitted_resampler(
        runtime_context: 'RuntimeContext',
        context: 'ExecutionContext',
        source_idx: int,
        operator_name: str
    ) -> Any:
        """
        Load the fitted resampler saved under ``operator_name`` for the current step.

        Raises:
            ValueError: If no artifact provider is available or no artifact id
                contains ``operator_name``
        """
        resampler = None

        # V3: Use artifact_provider for chain-based loading
        if runtime_context.artifact_provider is not None:
            step_artifacts = runtime_context.artifact_provider.get_artifacts_for_step(
                runtime_context.step_number,
                branch_path=context.selector.branch_path,
                source_index=source_idx
            )
            # NOTE(review): matching is a substring test on the artifact id;
            # the first artifact whose id contains the name wins.
            resampler = next(
                (obj for artifact_id, obj in step_artifacts if operator_name in artifact_id),
                None
            )

        # Raised whether the provider is missing or simply has no match, so
        # the caller never calls .transform() on None.
        if resampler is None:
            raise ValueError(
                f"Resampler {operator_name} not found at step {runtime_context.step_number}"
            )
        return resampler

    @staticmethod
    def _crop_raw_features(resampler: Any, processing_name: str, transformed_2d: Any) -> Any:
        """
        Crop features of raw processings with the resampler's stored crop mask.

        Non-raw processings, resamplers without a ``crop_mask_`` and masks
        selecting nothing leave ``transformed_2d`` untouched.
        """
        # Raw data: crop features directly using the stored crop mask
        # Preprocessed data: padding with 0 is already handled by fill_value in interpolation
        is_raw = processing_name.lower() == "raw" or processing_name.startswith("raw")
        if not is_raw or getattr(resampler, 'crop_mask_', None) is None:
            return transformed_2d

        from nirs4all.operators.transforms.features import CropTransformer

        crop_indices = np.where(resampler.crop_mask_)[0]
        if len(crop_indices) == 0:
            return transformed_2d

        # Crop to the contiguous span from the first to the last selected index
        cropper = CropTransformer(start=crop_indices[0], end=crop_indices[-1] + 1)
        return cropper.transform(transformed_2d)

    def execute(
        self,
        step_info: 'ParsedStep',
        dataset: 'SpectroDataset',
        context: 'ExecutionContext',
        runtime_context: 'RuntimeContext',
        source: int = -1,
        mode: str = "train",
        loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
        prediction_store: Optional[Any] = None
    ) -> Tuple['ExecutionContext', List]:
        """
        Execute resampling operation.

        Args:
            step_info: Pipeline step configuration
            dataset: Dataset to operate on
            context: Pipeline execution context
            runtime_context: Runtime context
            source: Data source index (-1 for all sources); not read by this
                implementation, which always iterates over every source
            mode: Execution mode. "predict" and "explain" load fitted
                resamplers from the artifact provider; any other mode fits new
                ones, and only "train" persists them
            loaded_binaries: Pre-loaded binary objects for prediction mode (unused)
            prediction_store: External prediction store (unused)

        Returns:
            Tuple of (updated_context, fitted_resamplers)

        Raises:
            ValueError: If wavelengths cannot be extracted, per-source targets
                do not match the number of sources, or a fitted resampler
                cannot be found in predict/explain mode
        """
        op = step_info.operator
        operator_name = op.__class__.__name__
        loading_fitted = mode in ("predict", "explain")

        # Get train and all data as lists of 3D arrays (one per source)
        train_context = context.with_partition("train")
        train_data = dataset.x(train_context.selector, "3d", concat_source=False)
        all_data = dataset.x(context.selector, "3d", concat_source=False)

        # Ensure data is in list format
        if not isinstance(train_data, list):
            train_data = [train_data]
        if not isinstance(all_data, list):
            all_data = [all_data]

        n_sources = len(train_data)

        fitted_resamplers = []
        transformed_features_list = []
        new_processing_names = []
        processing_names = []
        new_headers_list = []

        # Loop through each data source
        for sd_idx, (train_x, all_x) in enumerate(zip(train_data, all_data)):
            # Processings present in the source vs. processings selected by the context
            processing_ids = dataset.features_processings(sd_idx)
            source_processings = processing_ids
            if context.selector.processing:
                source_processings = context.selector.processing[sd_idx]

            original_wavelengths = self._extract_wavelengths(dataset, sd_idx)
            target_wavelengths = self._get_target_wavelengths_for_source(
                op, sd_idx, n_sources
            )

            source_transformed_features = []
            source_new_processing_names = []
            source_processing_names = []
            source_resamplers = []  # Track resamplers to determine final wavelengths

            # Loop through each processing in the 3D data (samples, processings, features)
            for processing_idx in range(train_x.shape[1]):
                processing_name = processing_ids[processing_idx]
                if processing_name not in source_processings:
                    continue

                train_2d = train_x[:, processing_idx, :]  # Training data
                all_2d = all_x[:, processing_idx, :]      # All data to transform

                # next_op() is called once per processed slice, in loop order
                new_operator_name = f"{operator_name}_{runtime_context.next_op()}"

                if loading_fitted:
                    resampler = self._load_fitted_resampler(
                        runtime_context, context, sd_idx, new_operator_name
                    )
                else:
                    # Create new resampler with target wavelengths for this source
                    from sklearn.base import clone
                    resampler = clone(op)
                    resampler.target_wavelengths = target_wavelengths

                    # Fit the resampler with original wavelengths
                    resampler.fit(train_2d, wavelengths=original_wavelengths)

                # Transform all data, then crop raw processings if needed
                transformed_2d = resampler.transform(all_2d)
                transformed_2d = self._crop_raw_features(
                    resampler, processing_name, transformed_2d
                )

                # Store results
                source_transformed_features.append(transformed_2d)
                source_new_processing_names.append(f"{processing_name}_{new_operator_name}")
                source_processing_names.append(processing_name)
                source_resamplers.append(resampler)

                # Persist fitted resampler using new serializer
                if mode == "train":
                    artifact = runtime_context.saver.persist_artifact(
                        step_number=runtime_context.step_number,
                        name=new_operator_name,
                        obj=resampler,
                        format_hint='sklearn',
                        branch_id=context.selector.branch_id,
                        branch_name=context.selector.branch_name
                    )
                    fitted_resamplers.append(artifact)

            # Determine final wavelengths for headers
            # Use the OUTPUT wavelengths (target_wavelengths from interpolator_params_)
            # NOT the input wavelengths (wavelengths_after_crop_)
            final_wavelengths = target_wavelengths
            for resampler in source_resamplers:
                if getattr(resampler, 'interpolator_params_', None) is not None:
                    final_wavelengths = resampler.interpolator_params_['target_wavelengths']
                    break

            new_headers_list.append([f"{wl:.2f}" for wl in final_wavelengths])
            transformed_features_list.append(source_transformed_features)
            new_processing_names.append(source_new_processing_names)
            processing_names.append(source_processing_names)

        # Start from the selector's processing when it has one; otherwise from
        # an empty list, padded so every source has a slot to assign below.
        selected_processing = context.selector.processing
        new_processing_list = list(selected_processing) if selected_processing else []
        new_processing_list.extend(
            [] for _ in range(len(new_processing_names) - len(new_processing_list))
        )

        # Update dataset with resampled features
        for sd_idx, (source_features, src_new_processing_names, new_headers) in enumerate(
            zip(transformed_features_list, new_processing_names, new_headers_list)
        ):
            # Replace features first (resampling changes the wavelength grid)
            # Note: When feature count changes, the dataset system will handle it properly
            dataset.replace_features(
                source_processings=processing_names[sd_idx],
                features=source_features,
                processings=src_new_processing_names,
                source=sd_idx
            )
            new_processing_list[sd_idx] = src_new_processing_names

            # Update headers AFTER replacing features (so they don't get reset)
            # Resampler always outputs wavelengths in cm-1
            dataset._features.sources[sd_idx].set_headers(new_headers, unit="cm-1")  # noqa: SLF001

            if runtime_context.saver.save_artifacts:
                logger.debug(f"Exporting resampled features for dataset '{dataset.name}', source {sd_idx} to CSV...")
                logger.debug(dataset.features_processings(sd_idx))
                train_context = context.with_partition("train")
                train_x_full = dataset.x(train_context.selector, "2d", concat_source=True)
                test_context = context.with_partition("test")
                test_x_full = dataset.x(test_context.selector, "2d", concat_source=True)

                # Save train and test features to CSV for debugging, create folder if needed.
                # NOTE(review): the same two files are rewritten on every source
                # iteration, so the last source's export is the one that remains.
                import os
                root_path = runtime_context.saver.base_path
                os.makedirs(f"{root_path}/{dataset.name}", exist_ok=True)
                np.savetxt(f"{root_path}/{dataset.name}/Export_X_train.csv", train_x_full, delimiter=",")
                np.savetxt(f"{root_path}/{dataset.name}/Export_X_test.csv", test_x_full, delimiter=",")

        context = context.with_processing(new_processing_list)
        context = context.with_metadata(add_feature=False)

        return context, fitted_resamplers