Source code for nirs4all.data.synthetic.reconstruction.forward

"""
Forward model components for physical signal-chain reconstruction.

This module implements the forward measurement chain:
    Canonical physical model → Instrument effects → Domain transform → Preprocessing

Key principle: Keep one latent physical model on a canonical grid, then apply
dataset-specific transforms to match observed data.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
from scipy.ndimage import gaussian_filter1d
from scipy.signal import savgol_filter
from scipy.interpolate import interp1d

if TYPE_CHECKING:
    from .environmental import EnvironmentalEffectsModel


# =============================================================================
# Canonical Forward Model
# =============================================================================


[docs] @dataclass class CanonicalForwardModel: """ Physical model on canonical high-resolution wavelength grid. Computes absorption coefficient K(λ) from chemical components: K(λ) = Σ c_k * ε_k(λ) + K0(λ) where: - c_k: concentration of component k - ε_k(λ): molar absorptivity (from component library) - K0(λ): continuum/background absorption (low-frequency) Attributes: canonical_grid: High-resolution wavelength grid (nm). component_names: Names of components to include. component_spectra: Pre-computed component spectra on canonical grid. baseline_order: Order of Chebyshev baseline polynomial. continuum_order: Order of continuum absorption polynomial. """ canonical_grid: np.ndarray component_names: List[str] = field(default_factory=list) baseline_order: int = 5 continuum_order: int = 3 _component_spectra: Optional[np.ndarray] = field(default=None, repr=False) _baseline_basis: Optional[np.ndarray] = field(default=None, repr=False) _continuum_basis: Optional[np.ndarray] = field(default=None, repr=False)
[docs] def __post_init__(self): """Initialize component spectra and basis matrices.""" self._build_component_spectra() self._build_basis_matrices()
def _build_component_spectra(self) -> None: """Pre-compute component spectra on canonical grid.""" from ..components import get_component n_wl = len(self.canonical_grid) n_comp = len(self.component_names) if n_comp == 0: self._component_spectra = np.zeros((0, n_wl)) return self._component_spectra = np.zeros((n_comp, n_wl)) for k, name in enumerate(self.component_names): try: comp = get_component(name) self._component_spectra[k] = comp.compute(self.canonical_grid) except (ValueError, KeyError): # Component not found, leave as zeros pass def _build_basis_matrices(self) -> None: """Build Chebyshev polynomial basis matrices.""" n_wl = len(self.canonical_grid) wl_norm = self._normalize_wavelengths(self.canonical_grid) # Baseline basis (Chebyshev polynomials) self._baseline_basis = np.zeros((self.baseline_order + 1, n_wl)) for i in range(self.baseline_order + 1): self._baseline_basis[i] = np.polynomial.chebyshev.chebval( wl_norm, [0] * i + [1] ) # Continuum basis (lower order for smooth background) self._continuum_basis = np.zeros((self.continuum_order + 1, n_wl)) for i in range(self.continuum_order + 1): self._continuum_basis[i] = np.polynomial.chebyshev.chebval( wl_norm, [0] * i + [1] ) def _normalize_wavelengths(self, wl: np.ndarray) -> np.ndarray: """Normalize wavelengths to [-1, 1] for Chebyshev basis.""" wl_min, wl_max = self.canonical_grid.min(), self.canonical_grid.max() return 2 * (wl - wl_min) / (wl_max - wl_min) - 1
[docs] def compute_absorption( self, concentrations: np.ndarray, path_length: float = 1.0, baseline_coeffs: Optional[np.ndarray] = None, continuum_coeffs: Optional[np.ndarray] = None, ) -> np.ndarray: """ Compute absorption coefficient on canonical grid. Args: concentrations: Component concentrations, shape (n_components,). path_length: Optical path length factor. baseline_coeffs: Baseline polynomial coefficients. continuum_coeffs: Continuum absorption coefficients. Returns: Absorbance spectrum on canonical grid. """ n_wl = len(self.canonical_grid) # Component contribution: A = L * Σ c_k * ε_k(λ) if self._component_spectra is not None and len(concentrations) > 0: absorption = path_length * (concentrations @ self._component_spectra) else: absorption = np.zeros(n_wl) # Add baseline if baseline_coeffs is not None and self._baseline_basis is not None: absorption += baseline_coeffs @ self._baseline_basis # Add continuum absorption if continuum_coeffs is not None and self._continuum_basis is not None: absorption += continuum_coeffs @ self._continuum_basis return absorption
[docs] def get_design_matrix(self, path_length: float = 1.0) -> np.ndarray: """ Get full design matrix for linear fitting. Returns: Design matrix of shape (n_wavelengths, n_components + n_baseline + n_continuum). """ matrices = [] # Component spectra (scaled by path length) if self._component_spectra is not None and self._component_spectra.shape[0] > 0: matrices.append(path_length * self._component_spectra.T) # Baseline basis if self._baseline_basis is not None: matrices.append(self._baseline_basis.T) # Continuum basis if self._continuum_basis is not None: matrices.append(self._continuum_basis.T) if matrices: return np.hstack(matrices) return np.zeros((len(self.canonical_grid), 0))
@property def n_components(self) -> int: """Number of chemical components.""" return len(self.component_names) @property def n_baseline(self) -> int: """Number of baseline coefficients.""" return self.baseline_order + 1 @property def n_continuum(self) -> int: """Number of continuum coefficients.""" return self.continuum_order + 1 @property def n_linear_params(self) -> int: """Total number of linear parameters.""" return self.n_components + self.n_baseline + self.n_continuum
# ============================================================================= # Instrument Model # =============================================================================
[docs] @dataclass class InstrumentModel: """ Instrument effects: warp, ILS convolution, gain/offset, resampling. Transforms spectrum from canonical grid to target instrument grid: 1. Wavelength warp: λ* → λ' (shift + stretch + optional higher order) 2. ILS convolution: Gaussian or Voigt line shape 3. Stray light / gain / offset 4. Resample to target grid Attributes: target_grid: Target wavelength grid (dataset grid). wl_shift: Wavelength shift in nm (default 0). wl_stretch: Wavelength scale factor (default 1). wl_poly_coeffs: Higher-order polynomial warp coefficients. ils_sigma: Instrument line shape Gaussian sigma in nm. stray_light: Stray light fraction (default 0). gain: Photometric gain (default 1). offset: Photometric offset (default 0). """ target_grid: np.ndarray wl_shift: float = 0.0 wl_stretch: float = 1.0 wl_poly_coeffs: Optional[np.ndarray] = None ils_sigma: float = 4.0 stray_light: float = 0.0 gain: float = 1.0 offset: float = 0.0
[docs] def apply( self, spectrum: np.ndarray, canonical_grid: np.ndarray, ) -> np.ndarray: """ Apply instrument chain to transform spectrum. Args: spectrum: Input spectrum on canonical grid. canonical_grid: Canonical wavelength grid. Returns: Transformed spectrum on target grid. """ # 1. Wavelength warp warped_wl = self._apply_wavelength_warp(canonical_grid) # 2. ILS convolution wl_step = np.median(np.diff(canonical_grid)) if self.ils_sigma > 0 and wl_step > 0: sigma_idx = self.ils_sigma / wl_step spectrum_ils = gaussian_filter1d(spectrum, sigma=max(0.5, sigma_idx)) else: spectrum_ils = spectrum # 3. Stray light / gain / offset spectrum_phot = self.gain * spectrum_ils + self.offset if self.stray_light > 0: spectrum_phot = spectrum_phot + self.stray_light * np.mean(spectrum_ils) # 4. Resample to target grid # Use linear interpolation with extrapolation handling valid_mask = (warped_wl >= canonical_grid.min()) & (warped_wl <= canonical_grid.max()) if not np.all(valid_mask): # Extend with edge values for extrapolation spectrum_resampled = np.interp( self.target_grid, warped_wl, spectrum_phot, left=spectrum_phot[0], right=spectrum_phot[-1], ) else: spectrum_resampled = np.interp(self.target_grid, warped_wl, spectrum_phot) return spectrum_resampled
def _apply_wavelength_warp(self, canonical_grid: np.ndarray) -> np.ndarray: """Apply wavelength warp transformation.""" warped = self.wl_shift + self.wl_stretch * canonical_grid if self.wl_poly_coeffs is not None and len(self.wl_poly_coeffs) > 0: # Normalize grid for polynomial wl_norm = (canonical_grid - canonical_grid.mean()) / ( canonical_grid.max() - canonical_grid.min() ) for i, coeff in enumerate(self.wl_poly_coeffs): warped += coeff * wl_norm ** (i + 2) # Start from quadratic return warped
[docs] def get_jacobian_wrt_wl_shift( self, spectrum: np.ndarray, canonical_grid: np.ndarray, eps: float = 0.1, ) -> np.ndarray: """Numerical Jacobian w.r.t. wavelength shift.""" orig = self.wl_shift self.wl_shift = orig + eps spec_plus = self.apply(spectrum, canonical_grid) self.wl_shift = orig - eps spec_minus = self.apply(spectrum, canonical_grid) self.wl_shift = orig return (spec_plus - spec_minus) / (2 * eps)
[docs] def get_jacobian_wrt_ils_sigma( self, spectrum: np.ndarray, canonical_grid: np.ndarray, eps: float = 0.1, ) -> np.ndarray: """Numerical Jacobian w.r.t. ILS sigma.""" orig = self.ils_sigma self.ils_sigma = orig + eps spec_plus = self.apply(spectrum, canonical_grid) self.ils_sigma = max(0.5, orig - eps) spec_minus = self.apply(spectrum, canonical_grid) self.ils_sigma = orig return (spec_plus - spec_minus) / (2 * eps)
[docs] @classmethod def from_params( cls, target_grid: np.ndarray, params: Dict[str, float], ) -> "InstrumentModel": """Create InstrumentModel from parameter dictionary.""" return cls( target_grid=target_grid, wl_shift=params.get("wl_shift", 0.0), wl_stretch=params.get("wl_stretch", 1.0), ils_sigma=params.get("ils_sigma", 4.0), stray_light=params.get("stray_light", 0.0), gain=params.get("gain", 1.0), offset=params.get("offset", 0.0), )
# ============================================================================= # Domain Transform # =============================================================================
[docs] @dataclass class DomainTransform: """ Transform between physical domains (absorbance, reflectance, etc.). For absorbance datasets: A(λ) = absorption coefficient (direct) For reflectance datasets: R(λ) computed via Kubelka-Munk or approximation Attributes: domain: Domain type ('absorbance', 'reflectance', 'transmittance', 'km'). scatter_coeffs: Scattering coefficients for KM model (reflectance). scatter_wavelength_dep: Wavelength-dependent scatter (λ^-n). """ domain: Literal["absorbance", "reflectance", "transmittance", "km"] = "absorbance" scatter_coeffs: Optional[np.ndarray] = None scatter_wavelength_exp: float = 0.0 # For wavelength-dependent scatter
[docs] def transform( self, absorption: np.ndarray, wavelengths: np.ndarray, scatter: Optional[np.ndarray] = None, ) -> np.ndarray: """ Transform absorption to target domain. Args: absorption: Absorption coefficient K(λ). wavelengths: Wavelength grid. scatter: Scattering coefficient S(λ) for reflectance. Returns: Spectrum in target domain representation. """ if self.domain == "absorbance": return absorption elif self.domain == "transmittance": # T = exp(-A) for optical density return np.exp(-np.clip(absorption, 0, 10)) elif self.domain in ("reflectance", "km"): # Kubelka-Munk: f(R) = (1-R)²/(2R) = K/S # Solve for R given K and S # Get or compute scattering coefficient if scatter is not None: S = scatter elif self.scatter_coeffs is not None: # Use provided scatter coefficients with wavelength dependence S = self._compute_scatter(wavelengths) else: # Default scattering (smooth baseline) S = np.ones_like(absorption) * 0.5 # Avoid division by zero S = np.maximum(S, 1e-6) K = np.maximum(absorption, 0) # KM ratio km_ratio = K / S # Solve quadratic: (1-R)²/(2R) = km_ratio # R² - 2R(1 + km_ratio) + 1 = 0 # R = (1 + km_ratio) - sqrt((1 + km_ratio)² - 1) a = 1 + km_ratio discriminant = np.maximum(a**2 - 1, 0) R = a - np.sqrt(discriminant) # Clip to valid reflectance range R = np.clip(R, 0.01, 0.99) if self.domain == "km": # Return KM function value return km_ratio return R return absorption
[docs] def inverse_transform( self, spectrum: np.ndarray, wavelengths: np.ndarray, scatter: Optional[np.ndarray] = None, ) -> np.ndarray: """ Inverse transform from domain to absorption. Args: spectrum: Spectrum in domain representation. wavelengths: Wavelength grid. scatter: Scattering coefficient for reflectance. Returns: Absorption coefficient. """ if self.domain == "absorbance": return spectrum elif self.domain == "transmittance": return -np.log(np.clip(spectrum, 1e-6, 1)) elif self.domain in ("reflectance", "km"): if self.domain == "km": km_ratio = spectrum else: R = np.clip(spectrum, 0.01, 0.99) km_ratio = (1 - R) ** 2 / (2 * R) if scatter is not None: S = scatter elif self.scatter_coeffs is not None: S = self._compute_scatter(wavelengths) else: S = np.ones_like(spectrum) * 0.5 return km_ratio * S return spectrum
def _compute_scatter(self, wavelengths: np.ndarray) -> np.ndarray: """Compute wavelength-dependent scattering coefficient.""" if self.scatter_coeffs is None: return np.ones_like(wavelengths) * 0.5 # Baseline scatter with wavelength dependence wl_norm = wavelengths / 1000.0 # Normalize to μm S = self.scatter_coeffs[0] * np.ones_like(wavelengths) if len(self.scatter_coeffs) > 1: S += self.scatter_coeffs[1] * (wl_norm ** (-self.scatter_wavelength_exp)) if len(self.scatter_coeffs) > 2: # Polynomial terms wl_centered = (wavelengths - wavelengths.mean()) / 1000.0 for i, coeff in enumerate(self.scatter_coeffs[2:]): S += coeff * wl_centered ** (i + 1) return np.maximum(S, 1e-6)
# ============================================================================= # Preprocessing Operator # =============================================================================
[docs] @dataclass class PreprocessingOperator: """ Apply dataset preprocessing to match stored representation. Implements exact preprocessing steps: - Savitzky-Golay derivatives (1st, 2nd order) - SNV (Standard Normal Variate) - MSC (Multiplicative Scatter Correction) - Detrend - Mean centering Attributes: preprocessing_type: Type of preprocessing. sg_window: Savitzky-Golay window length. sg_polyorder: Savitzky-Golay polynomial order. sg_deriv: Derivative order (0, 1, 2). reference_spectrum: Reference for MSC (mean of calibration set). """ preprocessing_type: Literal[ "none", "first_derivative", "second_derivative", "snv", "msc", "detrend", "mean_centered" ] = "none" sg_window: int = 15 sg_polyorder: int = 2 sg_deriv: int = 0 reference_spectrum: Optional[np.ndarray] = None
[docs] def apply(self, spectrum: np.ndarray) -> np.ndarray: """ Apply preprocessing to spectrum. Args: spectrum: Input spectrum, shape (n_wavelengths,) or (n_samples, n_wavelengths). Returns: Preprocessed spectrum(a). """ is_1d = spectrum.ndim == 1 if is_1d: spectrum = spectrum.reshape(1, -1) result = spectrum.copy() if self.preprocessing_type == "none": pass elif self.preprocessing_type == "first_derivative": window = min(self.sg_window, spectrum.shape[1] - 1) | 1 polyorder = min(self.sg_polyorder, window - 1) result = savgol_filter( spectrum, window, polyorder, deriv=1, axis=1, mode="interp" ) elif self.preprocessing_type == "second_derivative": window = min(self.sg_window, spectrum.shape[1] - 1) | 1 polyorder = min(self.sg_polyorder + 1, window - 1) result = savgol_filter( spectrum, window, polyorder, deriv=2, axis=1, mode="interp" ) elif self.preprocessing_type == "snv": means = spectrum.mean(axis=1, keepdims=True) stds = spectrum.std(axis=1, keepdims=True) stds = np.where(stds < 1e-10, 1.0, stds) result = (spectrum - means) / stds elif self.preprocessing_type == "msc": if self.reference_spectrum is None: ref = spectrum.mean(axis=0) else: ref = self.reference_spectrum for i in range(spectrum.shape[0]): # Linear fit: spectrum[i] = a * ref + b coeffs = np.polyfit(ref, spectrum[i], 1) a, b = coeffs[0], coeffs[1] if abs(a) > 1e-10: result[i] = (spectrum[i] - b) / a elif self.preprocessing_type == "detrend": from scipy.signal import detrend result = detrend(spectrum, axis=1, type="linear") elif self.preprocessing_type == "mean_centered": result = spectrum - spectrum.mean(axis=1, keepdims=True) if is_1d: return result.ravel() return result
[docs] def apply_to_matrix(self, X: np.ndarray) -> np.ndarray: """Apply preprocessing to design matrix columns.""" return self.apply(X.T).T
[docs] @classmethod def from_detection( cls, preprocessing_type: str, sg_window: int = 15, sg_polyorder: int = 2, ) -> "PreprocessingOperator": """Create PreprocessingOperator from detected preprocessing type.""" type_map = { "raw_absorbance": "none", "raw_reflectance": "none", "first_derivative": "first_derivative", "second_derivative": "second_derivative", "snv_corrected": "snv", "msc_corrected": "msc", "mean_centered": "mean_centered", "normalized": "none", # Min-max scaling doesn't need special handling } prep_type = type_map.get(preprocessing_type, "none") return cls( preprocessing_type=prep_type, sg_window=sg_window, sg_polyorder=sg_polyorder, )
# ============================================================================= # Forward Chain # =============================================================================
[docs] @dataclass class ForwardChain: """ Complete forward measurement chain combining all components. Chain: CanonicalForwardModel → [EnvironmentalEffects] → DomainTransform → InstrumentModel → PreprocessingOperator Attributes: canonical_model: Physical model on canonical grid. environmental_model: Optional environmental effects (temperature, moisture, scattering). instrument_model: Instrument effects. domain_transform: Domain conversion. preprocessing: Dataset preprocessing. """ canonical_model: CanonicalForwardModel instrument_model: InstrumentModel domain_transform: DomainTransform preprocessing: PreprocessingOperator environmental_model: Optional["EnvironmentalEffectsModel"] = None
[docs] def forward( self, concentrations: np.ndarray, path_length: float = 1.0, baseline_coeffs: Optional[np.ndarray] = None, continuum_coeffs: Optional[np.ndarray] = None, scatter: Optional[np.ndarray] = None, ) -> np.ndarray: """ Run full forward chain. Args: concentrations: Component concentrations. path_length: Optical path length factor. baseline_coeffs: Baseline polynomial coefficients. continuum_coeffs: Continuum absorption coefficients. scatter: Scattering coefficients for reflectance. Returns: Spectrum on target grid with preprocessing applied. """ # 1. Compute absorption on canonical grid absorption = self.canonical_model.compute_absorption( concentrations=concentrations, path_length=path_length, baseline_coeffs=baseline_coeffs, continuum_coeffs=continuum_coeffs, ) # 2. Apply environmental effects (temperature, moisture, scattering) if self.environmental_model is not None and self.environmental_model.enabled: absorption = self.environmental_model.apply( absorption, self.canonical_model.canonical_grid, ) # 3. Apply domain transform (absorbance → reflectance if needed) domain_spectrum = self.domain_transform.transform( absorption, self.canonical_model.canonical_grid, scatter=scatter, ) # 4. Apply instrument effects and resample instrument_spectrum = self.instrument_model.apply( domain_spectrum, self.canonical_model.canonical_grid, ) # 5. Apply preprocessing preprocessed = self.preprocessing.apply(instrument_spectrum) return preprocessed
[docs] def forward_design_matrix( self, path_length: float = 1.0, ) -> np.ndarray: """ Get transformed design matrix for linear fitting. Returns the design matrix after applying instrument and preprocessing transforms. Note: Domain transform is not applied here as it may be nonlinear (KM). """ # Get canonical design matrix A_canonical = self.canonical_model.get_design_matrix(path_length) # Apply instrument transform to each column A_instrument = np.zeros( (len(self.instrument_model.target_grid), A_canonical.shape[1]) ) for j in range(A_canonical.shape[1]): A_instrument[:, j] = self.instrument_model.apply( A_canonical[:, j], self.canonical_model.canonical_grid, ) # Apply preprocessing A_preprocessed = self.preprocessing.apply_to_matrix(A_instrument) return A_preprocessed
[docs] @classmethod def create( cls, canonical_grid: np.ndarray, target_grid: np.ndarray, component_names: List[str], domain: str = "absorbance", preprocessing_type: str = "none", instrument_params: Optional[Dict[str, float]] = None, baseline_order: int = 5, continuum_order: int = 3, sg_window: int = 15, sg_polyorder: int = 2, include_environmental: bool = False, ) -> "ForwardChain": """ Convenience factory method to create ForwardChain. Args: canonical_grid: High-resolution canonical wavelength grid. target_grid: Target dataset wavelength grid. component_names: Names of components to include. domain: Domain type ('absorbance', 'reflectance'). preprocessing_type: Preprocessing type. instrument_params: Instrument parameters dict. baseline_order: Baseline polynomial order. continuum_order: Continuum polynomial order. sg_window: Savitzky-Golay window. sg_polyorder: Savitzky-Golay polynomial order. include_environmental: Whether to include environmental effects model. Returns: Configured ForwardChain instance. """ canonical_model = CanonicalForwardModel( canonical_grid=canonical_grid, component_names=component_names, baseline_order=baseline_order, continuum_order=continuum_order, ) instrument_params = instrument_params or {} instrument_model = InstrumentModel.from_params(target_grid, instrument_params) domain_transform = DomainTransform(domain=domain) preprocessing = PreprocessingOperator.from_detection( preprocessing_type, sg_window, sg_polyorder ) # Create environmental model if requested environmental_model = None if include_environmental: from .environmental import EnvironmentalEffectsModel environmental_model = EnvironmentalEffectsModel() return cls( canonical_model=canonical_model, instrument_model=instrument_model, domain_transform=domain_transform, preprocessing=preprocessing, environmental_model=environmental_model, )