"""
Wavelength resampling operators for NIRS spectral data.
This module provides resampling functionality to interpolate spectral data
to new wavelength grids using various scipy interpolation methods.
"""
import numpy as np
from typing import Optional, Union, Tuple, Literal
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_array, check_is_fitted
import warnings
[docs]
class Resampler(TransformerMixin, BaseEstimator):
"""
Resample spectral data to new wavelength grid using interpolation.
This transformer interpolates NIRS spectral data from the original wavelength
grid to a target wavelength grid using scipy interpolation methods.
Parameters
----------
target_wavelengths : array-like
Target wavelengths for resampling. Must be 1D array.
method : str, default='linear'
Interpolation method. Supported methods:
- 'linear': Linear interpolation
- 'nearest': Nearest neighbor interpolation
- 'cubic': Cubic spline interpolation
- 'quadratic': Quadratic spline interpolation
- 'slinear': Linear spline (order 1)
- 'zero': Zero-order spline (piecewise constant)
Future: May support additional scipy methods
crop_range : tuple of (float, float) or None, default=None
Optional (min_wavelength, max_wavelength) to crop original data before resampling.
fill_value : float or 'extrapolate', default=0.0
Value to use for target wavelengths outside the original range.
- float: Use this constant value for extrapolation
- 'extrapolate': Extrapolate using the interpolation method
- 0.0: Default padding with zeros (safe choice)
bounds_error : bool, default=False
If True, raise error when target wavelengths are outside original range.
If False, use fill_value for out-of-bounds points.
copy : bool, default=True
Whether to copy input data or modify in place.
Attributes
----------
original_wavelengths_ : ndarray of shape (n_features,)
Original wavelength grid from fit data
n_features_in_ : int
Number of features (wavelengths) in input data
n_features_out_ : int
Number of features (wavelengths) in output data
interpolator_params_ : dict
Stored interpolation parameters for reconstruction
Examples
--------
>>> from nirs4all.operators.transforms import Resampler
>>> import numpy as np
>>>
>>> # Original data at 1000-2500 nm with 200 points
>>> X = np.random.randn(100, 200)
>>> original_wl = np.linspace(1000, 2500, 200)
>>>
>>> # Resample to 100 evenly-spaced wavelengths
>>> target_wl = np.linspace(1000, 2500, 100)
>>> resampler = Resampler(target_wavelengths=target_wl, method='cubic')
>>> resampler.fit(X, wavelengths=original_wl)
>>> X_resampled = resampler.transform(X)
>>> X_resampled.shape
(100, 100)
Notes
-----
- Wavelengths must be strictly increasing
- Warns if target wavelengths extend beyond original range
- Raises error if no wavelengths overlap between original and target
"""
def __init__(
self,
target_wavelengths: np.ndarray,
method: Literal['linear', 'nearest', 'cubic', 'quadratic', 'slinear', 'zero'] = 'linear',
crop_range: Optional[Tuple[float, float]] = None,
fill_value: Union[float, str] = 0.0,
bounds_error: bool = False,
copy: bool = True
):
self.target_wavelengths = target_wavelengths
self.method = method
self.crop_range = crop_range
self.fill_value = fill_value
self.bounds_error = bounds_error
self.copy = copy
def _validate_wavelengths(self, wavelengths: np.ndarray) -> np.ndarray:
"""Validate and convert wavelengths to 1D float array."""
wavelengths = np.asarray(wavelengths, dtype=float)
if wavelengths.ndim != 1:
raise ValueError(f"Wavelengths must be 1D array, got shape {wavelengths.shape}")
# Check if strictly monotonic (increasing or decreasing)
diffs = np.diff(wavelengths)
is_increasing = np.all(diffs > 0)
is_decreasing = np.all(diffs < 0)
if not (is_increasing or is_decreasing):
# Provide detailed error message
non_monotonic_indices = np.where(diffs[:-1] * diffs[1:] <= 0)[0]
raise ValueError(
f"Wavelengths must be strictly monotonic (all increasing or all decreasing). "
f"Found non-monotonic behavior at indices: {non_monotonic_indices[:5]}. "
f"Wavelengths range: [{wavelengths.min():.2f}, {wavelengths.max():.2f}]. "
f"First 10 values: {wavelengths[:10]}. "
f"First 10 diffs: {diffs[:10]}"
)
return wavelengths
def _check_wavelength_overlap(self, original_wl: np.ndarray, target_wl: np.ndarray):
"""Check overlap between original and target wavelengths."""
orig_min, orig_max = original_wl.min(), original_wl.max()
target_min, target_max = target_wl.min(), target_wl.max()
# Check if there's any overlap
if target_max < orig_min or target_min > orig_max:
raise ValueError(
f"No overlap between original wavelengths ({orig_min:.1f}-{orig_max:.1f}) "
f"and target wavelengths ({target_min:.1f}-{target_max:.1f})"
)
# Warn if target extends beyond original range
if target_min < orig_min or target_max > orig_max:
out_of_bounds = []
if target_min < orig_min:
out_of_bounds.append(f"below {orig_min:.1f}")
if target_max > orig_max:
out_of_bounds.append(f"above {orig_max:.1f}")
warnings.warn(
f"Target wavelengths extend {' and '.join(out_of_bounds)} original range. "
f"Using fill_value={self.fill_value} for extrapolation.",
UserWarning
)
[docs]
def fit(self, X, y=None, wavelengths: Optional[np.ndarray] = None):
"""
Fit the resampler by storing original wavelength grid.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Training data.
y : None
Ignored. Present for API consistency.
wavelengths : array-like of shape (n_features,), optional
Original wavelength grid. If None, will be extracted from dataset headers
by the controller.
Returns
-------
self : Resampler
Fitted resampler.
"""
X = check_array(X, dtype=np.float64, ensure_all_finite='allow-nan', copy=self.copy)
if wavelengths is None:
raise ValueError(
"Wavelengths must be provided to fit(). "
"The controller should extract these from dataset.headers()."
)
# Validate wavelengths
self.original_wavelengths_ = self._validate_wavelengths(wavelengths)
if len(self.original_wavelengths_) != X.shape[1]:
raise ValueError(
f"Number of wavelengths ({len(self.original_wavelengths_)}) "
f"must match number of features ({X.shape[1]})"
)
# Store original number of features before cropping
original_n_features = X.shape[1]
# Apply crop range if specified
if self.crop_range is not None:
crop_min, crop_max = self.crop_range
self.crop_mask_ = (self.original_wavelengths_ >= crop_min) & (self.original_wavelengths_ <= crop_max)
# Store info about whether we need to crop features (done by controller for raw data)
# For preprocessed data, controller will pad with 0 instead
self.wavelengths_after_crop_ = self.original_wavelengths_[self.crop_mask_]
if len(self.wavelengths_after_crop_) == 0:
raise ValueError(f"Crop range {self.crop_range} excludes all wavelengths")
# Use cropped wavelengths for interpolation
self.original_wavelengths_ = self.wavelengths_after_crop_
else:
self.crop_mask_ = None
self.wavelengths_after_crop_ = self.original_wavelengths_
# Validate target wavelengths
target_wl = self._validate_wavelengths(self.target_wavelengths)
# Check for overlap and warn if needed
self._check_wavelength_overlap(self.original_wavelengths_, target_wl)
# Store dimensions
self.n_features_in_ = original_n_features # Original before crop
self.n_features_out_ = len(target_wl)
# Store interpolation parameters for serialization
self.interpolator_params_ = {
'target_wavelengths': target_wl,
'method': self.method,
'fill_value': self.fill_value,
'bounds_error': self.bounds_error
}
return self
[docs]
def get_feature_names_out(self, input_features=None):
"""
Get output feature names (target wavelengths as strings).
Parameters
----------
input_features : array-like of str or None, default=None
Ignored. Present for API consistency.
Returns
-------
feature_names_out : ndarray of str
Target wavelengths as strings.
"""
check_is_fitted(self, 'interpolator_params_')
target_wl = self.interpolator_params_['target_wavelengths']
return np.array([f"{wl:.2f}" for wl in target_wl])
[docs]
def __repr__(self):
"""String representation of the resampler."""
if hasattr(self, 'n_features_in_'):
return (f"Resampler(method='{self.method}', "
f"n_in={self.n_features_in_}, n_out={self.n_features_out_})")
else:
target_len = len(self.target_wavelengths) if hasattr(self.target_wavelengths, '__len__') else '?'
return f"Resampler(method='{self.method}', n_out={target_len}, unfitted)"