Source code for nirs4all.data.synthetic.reconstruction.inversion

"""
Variable projection inversion for per-sample spectral fitting.

Implements the core inversion algorithm:
- Outer loop: optimize nonlinear parameters (path_length, per-sample shift)
- Inner loop: solve linear parameters via NNLS/QP (concentrations, baseline)

Uses multiscale schedule to avoid local minima.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from scipy.optimize import minimize, lsq_linear
from scipy.ndimage import gaussian_filter1d


# =============================================================================
# Inversion Result
# =============================================================================


[docs] @dataclass class InversionResult: """ Result of per-sample inversion. Attributes: concentrations: Fitted component concentrations. baseline_coeffs: Fitted baseline coefficients. continuum_coeffs: Fitted continuum coefficients. path_length: Fitted path length factor. wl_shift_residual: Per-sample wavelength shift correction. scatter_coeffs: Fitted scatter coefficients (reflectance). fitted_spectrum: Reconstructed spectrum. residuals: Fitting residuals. r_squared: Coefficient of determination. rmse: Root mean squared error. converged: Whether optimization converged. temperature_delta: Fitted temperature deviation (°C from reference). water_activity: Fitted water activity (0-1 scale). scattering_power: Fitted scattering wavelength exponent. scattering_amplitude: Fitted scattering baseline amplitude. """ concentrations: np.ndarray baseline_coeffs: np.ndarray continuum_coeffs: Optional[np.ndarray] = None path_length: float = 1.0 wl_shift_residual: float = 0.0 scatter_coeffs: Optional[np.ndarray] = None fitted_spectrum: Optional[np.ndarray] = None residuals: Optional[np.ndarray] = None r_squared: float = 0.0 rmse: float = float("inf") converged: bool = False # Environmental parameters temperature_delta: float = 0.0 water_activity: float = 0.5 scattering_power: float = 1.5 scattering_amplitude: float = 0.0
[docs] def to_dict(self) -> Dict[str, Any]: """Convert to dictionary.""" return { "concentrations": self.concentrations.tolist(), "baseline_coeffs": self.baseline_coeffs.tolist(), "continuum_coeffs": ( self.continuum_coeffs.tolist() if self.continuum_coeffs is not None else None ), "path_length": self.path_length, "wl_shift_residual": self.wl_shift_residual, "r_squared": self.r_squared, "rmse": self.rmse, "converged": self.converged, "temperature_delta": self.temperature_delta, "water_activity": self.water_activity, "scattering_power": self.scattering_power, "scattering_amplitude": self.scattering_amplitude, }
@property def linear_params(self) -> np.ndarray: """Get all linear parameters as single array.""" params = [self.concentrations, self.baseline_coeffs] if self.continuum_coeffs is not None: params.append(self.continuum_coeffs) return np.concatenate(params)
# ============================================================================= # Multiscale Schedule # =============================================================================
[docs] @dataclass class MultiscaleSchedule: """ Configuration for multiscale fitting curriculum. Fits coarse features first, then progressively adds detail: 1. Smooth target + no derivatives + strong baseline prior 2. Less smooth + partial derivative weight 3. Full resolution + full preprocessing Attributes: smooth_sigmas: Gaussian sigma values for each stage (0 = no smoothing). derivative_weights: Weight on derivative space at each stage. baseline_regularization: Baseline regularization at each stage. max_iterations: Max iterations at each stage. """ smooth_sigmas: List[float] = field(default_factory=lambda: [15, 8, 4, 0]) derivative_weights: List[float] = field(default_factory=lambda: [0.0, 0.3, 0.7, 1.0]) baseline_regularization: List[float] = field( default_factory=lambda: [1e-3, 1e-4, 1e-5, 1e-6] ) max_iterations: List[int] = field(default_factory=lambda: [50, 100, 100, 200]) @property def n_stages(self) -> int: """Number of stages in schedule.""" return len(self.smooth_sigmas)
[docs] @classmethod def quick(cls) -> "MultiscaleSchedule": """Quick schedule for fast fitting.""" return cls( smooth_sigmas=[10, 0], derivative_weights=[0.0, 1.0], baseline_regularization=[1e-4, 1e-6], max_iterations=[50, 100], )
[docs] @classmethod def thorough(cls) -> "MultiscaleSchedule": """Thorough schedule for best accuracy.""" return cls( smooth_sigmas=[20, 12, 6, 3, 0], derivative_weights=[0.0, 0.2, 0.5, 0.8, 1.0], baseline_regularization=[1e-2, 1e-3, 1e-4, 1e-5, 1e-6], max_iterations=[50, 100, 100, 150, 200], )
# ============================================================================= # Variable Projection Solver # =============================================================================
[docs] @dataclass class VariableProjectionSolver: """ Variable projection solver for spectral inversion. Separates optimization into: - Nonlinear params: path_length, per-sample wl_shift, [environmental] (outer loop) - Linear params: concentrations, baseline, continuum (inner NNLS/QP) Attributes: path_length_bounds: Bounds for path length. wl_shift_bounds: Per-sample wavelength shift bounds. concentration_regularization: L2 reg on concentrations. baseline_smoothness_penalty: Penalty on baseline curvature. use_derivatives: Fit in derivative space (for derivative data). fit_environmental: Whether to fit environmental parameters. temperature_bounds: Bounds for temperature deviation (°C). water_activity_bounds: Bounds for water activity. scattering_power_bounds: Bounds for scattering exponent. scattering_amplitude_bounds: Bounds for scattering amplitude. environmental_prior_weight: Weight for environmental parameter priors. """ path_length_bounds: Tuple[float, float] = (0.5, 2.0) wl_shift_bounds: Tuple[float, float] = (-2.0, 2.0) concentration_regularization: float = 1e-6 baseline_smoothness_penalty: float = 1e-4 use_derivatives: bool = False verbose: bool = False # Environmental fitting options fit_environmental: bool = False temperature_bounds: Tuple[float, float] = (-15.0, 15.0) water_activity_bounds: Tuple[float, float] = (0.1, 0.9) scattering_power_bounds: Tuple[float, float] = (0.5, 3.0) scattering_amplitude_bounds: Tuple[float, float] = (0.0, 0.2) environmental_prior_weight: float = 0.1
[docs] def fit( self, target: np.ndarray, forward_chain: "ForwardChain", schedule: Optional[MultiscaleSchedule] = None, initial_params: Optional[Dict[str, float]] = None, ) -> InversionResult: """ Fit forward model to target spectrum. Args: target: Target spectrum to fit. forward_chain: Forward chain with calibrated global params. schedule: Multiscale fitting schedule. initial_params: Initial nonlinear parameters. Returns: InversionResult with fitted parameters. """ if schedule is None: schedule = MultiscaleSchedule() # Initialize parameters if initial_params is None: path_length = 1.0 wl_shift = 0.0 temp_delta = 0.0 water_act = 0.5 scatter_power = 1.5 scatter_amp = 0.0 else: path_length = initial_params.get("path_length", 1.0) wl_shift = initial_params.get("wl_shift_residual", 0.0) temp_delta = initial_params.get("temperature_delta", 0.0) water_act = initial_params.get("water_activity", 0.5) scatter_power = initial_params.get("scattering_power", 1.5) scatter_amp = initial_params.get("scattering_amplitude", 0.0) # Build parameter vector if self.fit_environmental: current_params = np.array([ path_length, wl_shift, temp_delta, water_act, scatter_power, scatter_amp ]) else: current_params = np.array([path_length, wl_shift]) # Run multiscale schedule for stage in range(schedule.n_stages): sigma = schedule.smooth_sigmas[stage] deriv_weight = schedule.derivative_weights[stage] baseline_reg = schedule.baseline_regularization[stage] max_iter = schedule.max_iterations[stage] # Smooth target for this stage if sigma > 0: target_stage = gaussian_filter1d(target, sigma=sigma) else: target_stage = target # Optimize nonlinear params result = self._optimize_stage( target_stage, forward_chain, current_params, deriv_weight=deriv_weight, baseline_reg=baseline_reg, max_iter=max_iter, ) # Update current params if self.fit_environmental: current_params = np.array([ result.path_length, result.wl_shift_residual, result.temperature_delta, result.water_activity, result.scattering_power, result.scattering_amplitude ]) else: current_params = np.array([result.path_length, result.wl_shift_residual]) if self.verbose: msg = f"Stage {stage + 1}/{schedule.n_stages}: R²={result.r_squared:.4f}, path_length={result.path_length:.3f}" if self.fit_environmental: msg += f", T={result.temperature_delta:.1f}°C, aw={result.water_activity:.2f}" print(msg) # Final fit with full resolution final_result = self._final_fit(target, forward_chain, current_params) return final_result
def _optimize_stage( self, target: np.ndarray, forward_chain: "ForwardChain", initial_params: np.ndarray, deriv_weight: float = 1.0, baseline_reg: float = 1e-6, max_iter: int = 100, ) -> InversionResult: """Optimize nonlinear params for one stage.""" # Store original instrument wl_shift and environmental state orig_wl_shift = forward_chain.instrument_model.wl_shift orig_env_state = None if forward_chain.environmental_model is not None: orig_env_state = forward_chain.environmental_model.to_dict() def objective(params: np.ndarray) -> float: path_length = params[0] wl_shift_delta = params[1] # Update instrument model with per-sample shift forward_chain.instrument_model.wl_shift = orig_wl_shift + wl_shift_delta # Update environmental model if fitting environmental if self.fit_environmental and forward_chain.environmental_model is not None: temp_delta = params[2] water_act = params[3] scatter_power = params[4] scatter_amp = params[5] forward_chain.environmental_model.temperature_delta = temp_delta forward_chain.environmental_model.water_activity = water_act forward_chain.environmental_model.scattering_power = scatter_power forward_chain.environmental_model.scattering_amplitude = scatter_amp # Get design matrix A = forward_chain.forward_design_matrix(path_length=path_length) # Solve inner linear problem try: linear_params, residual_norm = self._inner_solve( A, target, forward_chain, baseline_reg ) loss = residual_norm # Regularization on deviation from nominal loss += 0.01 * (path_length - 1.0) ** 2 loss += 0.1 * wl_shift_delta ** 2 # Environmental parameter priors if self.fit_environmental: weight = self.environmental_prior_weight # Gaussian prior on temperature (mean 0, std 5) loss += weight * (params[2] / 5.0) ** 2 # Beta-like prior on water_activity (centered at 0.5) aw_centered = params[3] - 0.5 loss += weight * 4 * aw_centered ** 2 # Gaussian prior on scattering power (mean 1.5, std 0.5) loss += weight * ((params[4] - 1.5) / 0.5) ** 2 # Exponential prior on scattering amplitude (encourage small) loss += weight * 10 * params[5] return loss except Exception: return 1e10 # Build bounds bounds = [self.path_length_bounds, self.wl_shift_bounds] if self.fit_environmental: bounds.extend([ self.temperature_bounds, self.water_activity_bounds, self.scattering_power_bounds, self.scattering_amplitude_bounds, ]) result = minimize( objective, initial_params, method="L-BFGS-B", bounds=bounds, options={"maxiter": max_iter}, ) # Extract final parameters path_length = result.x[0] wl_shift_delta = result.x[1] if self.fit_environmental: temp_delta = result.x[2] water_act = result.x[3] scatter_power = result.x[4] scatter_amp = result.x[5] else: temp_delta = 0.0 water_act = 0.5 scatter_power = 1.5 scatter_amp = 0.0 # Apply final parameters to forward chain forward_chain.instrument_model.wl_shift = orig_wl_shift + wl_shift_delta if self.fit_environmental and forward_chain.environmental_model is not None: forward_chain.environmental_model.temperature_delta = temp_delta forward_chain.environmental_model.water_activity = water_act forward_chain.environmental_model.scattering_power = scatter_power forward_chain.environmental_model.scattering_amplitude = scatter_amp A = forward_chain.forward_design_matrix(path_length=path_length) linear_params, _ = self._inner_solve(A, target, forward_chain, baseline_reg) # Compute fit statistics fitted = A @ linear_params residuals = target - fitted ss_res = np.sum(residuals ** 2) ss_tot = np.sum((target - target.mean()) ** 2) r_squared = 1 - ss_res / ss_tot if ss_tot > 0 else 0 # Restore original state forward_chain.instrument_model.wl_shift = orig_wl_shift if orig_env_state is not None and forward_chain.environmental_model is not None: forward_chain.environmental_model.temperature_delta = orig_env_state["temperature_delta"] forward_chain.environmental_model.water_activity = orig_env_state["water_activity"] forward_chain.environmental_model.scattering_power = orig_env_state["scattering_power"] forward_chain.environmental_model.scattering_amplitude = orig_env_state["scattering_amplitude"] # Parse linear params n_comp = forward_chain.canonical_model.n_components n_baseline = forward_chain.canonical_model.n_baseline n_continuum = forward_chain.canonical_model.n_continuum concentrations = linear_params[:n_comp] baseline_coeffs = linear_params[n_comp : n_comp + n_baseline] continuum_coeffs = ( linear_params[n_comp + n_baseline :] if n_continuum > 0 else None ) return InversionResult( concentrations=concentrations, baseline_coeffs=baseline_coeffs, continuum_coeffs=continuum_coeffs, path_length=path_length, wl_shift_residual=wl_shift_delta, fitted_spectrum=fitted, residuals=residuals, r_squared=r_squared, rmse=np.sqrt(np.mean(residuals ** 2)), converged=result.success, temperature_delta=temp_delta, water_activity=water_act, scattering_power=scatter_power, scattering_amplitude=scatter_amp, ) def _inner_solve( self, A: np.ndarray, target: np.ndarray, forward_chain: "ForwardChain", baseline_reg: float, ) -> Tuple[np.ndarray, float]: """ Solve inner linear problem with constraints. Constraints: - Concentrations >= 0 (NNLS) - Baseline/continuum: free (can be negative) - Optional smoothness penalty on baseline """ n_comp = forward_chain.canonical_model.n_components n_params = A.shape[1] # Bounds: concentrations >= 0, baseline free lb = np.concatenate([ np.zeros(n_comp), # concentrations >= 0 -np.inf * np.ones(n_params - n_comp), # baseline free ]) ub = np.inf * np.ones(n_params) # Add regularization if needed if baseline_reg > 0: # Augment system with regularization rows n_wl = A.shape[0] reg_matrix = baseline_reg * np.eye(n_params) reg_matrix[:n_comp, :n_comp] = self.concentration_regularization * np.eye( n_comp ) A_aug = np.vstack([A, reg_matrix]) target_aug = np.concatenate([target, np.zeros(n_params)]) result = lsq_linear(A_aug, target_aug, bounds=(lb, ub)) else: result = lsq_linear(A, target, bounds=(lb, ub)) residual_norm = np.sum((A @ result.x - target) ** 2) return result.x, residual_norm def _final_fit( self, target: np.ndarray, forward_chain: "ForwardChain", params: np.ndarray, ) -> InversionResult: """Final fit with full resolution and statistics.""" path_length = params[0] wl_shift_delta = params[1] # Extract environmental parameters if present if self.fit_environmental and len(params) >= 6: temp_delta = params[2] water_act = params[3] scatter_power = params[4] scatter_amp = params[5] else: temp_delta = 0.0 water_act = 0.5 scatter_power = 1.5 scatter_amp = 0.0 # Store original state orig_wl_shift = forward_chain.instrument_model.wl_shift orig_env_state = None if forward_chain.environmental_model is not None: orig_env_state = forward_chain.environmental_model.to_dict() # Apply per-sample parameters forward_chain.instrument_model.wl_shift = orig_wl_shift + wl_shift_delta if self.fit_environmental and forward_chain.environmental_model is not None: forward_chain.environmental_model.temperature_delta = temp_delta forward_chain.environmental_model.water_activity = water_act forward_chain.environmental_model.scattering_power = scatter_power forward_chain.environmental_model.scattering_amplitude = scatter_amp A = forward_chain.forward_design_matrix(path_length=path_length) linear_params, _ = self._inner_solve( A, target, forward_chain, self.baseline_smoothness_penalty ) # Compute fit fitted = A @ linear_params residuals = target - fitted ss_res = np.sum(residuals ** 2) ss_tot = np.sum((target - target.mean()) ** 2) r_squared = 1 - ss_res / ss_tot if ss_tot > 0 else 0 # Restore original state forward_chain.instrument_model.wl_shift = orig_wl_shift if orig_env_state is not None and forward_chain.environmental_model is not None: forward_chain.environmental_model.temperature_delta = orig_env_state["temperature_delta"] forward_chain.environmental_model.water_activity = orig_env_state["water_activity"] forward_chain.environmental_model.scattering_power = orig_env_state["scattering_power"] forward_chain.environmental_model.scattering_amplitude = orig_env_state["scattering_amplitude"] # Parse params n_comp = forward_chain.canonical_model.n_components n_baseline = forward_chain.canonical_model.n_baseline n_continuum = forward_chain.canonical_model.n_continuum return InversionResult( concentrations=linear_params[:n_comp], baseline_coeffs=linear_params[n_comp : n_comp + n_baseline], continuum_coeffs=( linear_params[n_comp + n_baseline :] if n_continuum > 0 else None ), path_length=path_length, wl_shift_residual=wl_shift_delta, fitted_spectrum=fitted, residuals=residuals, r_squared=r_squared, rmse=np.sqrt(np.mean(residuals ** 2)), converged=True, temperature_delta=temp_delta, water_activity=water_act, scattering_power=scatter_power, scattering_amplitude=scatter_amp, )
[docs] def fit_batch( self, X: np.ndarray, forward_chain: "ForwardChain", schedule: Optional[MultiscaleSchedule] = None, n_jobs: int = 1, ) -> List[InversionResult]: """ Fit multiple spectra. Args: X: Spectra matrix (n_samples, n_wavelengths). forward_chain: Forward chain with calibrated global params. schedule: Multiscale fitting schedule. n_jobs: Number of parallel jobs (1 = sequential). Returns: List of InversionResult for each sample. """ n_samples = X.shape[0] results = [] # Use quick schedule for batch to save time if schedule is None: schedule = MultiscaleSchedule.quick() for i in range(n_samples): result = self.fit(X[i], forward_chain, schedule) results.append(result) if self.verbose and (i + 1) % 10 == 0: print(f"Fitted {i + 1}/{n_samples} samples") return results
# ============================================================================= # Batch Inversion Helper # =============================================================================
[docs] def invert_dataset( X: np.ndarray, forward_chain: "ForwardChain", schedule: Optional[MultiscaleSchedule] = None, verbose: bool = False, fit_environmental: bool = False, ) -> Tuple[List[InversionResult], Dict[str, np.ndarray]]: """ Invert full dataset and collect parameter arrays. Args: X: Spectra matrix (n_samples, n_wavelengths). forward_chain: Forward chain. schedule: Fitting schedule. verbose: Print progress. fit_environmental: Whether to fit environmental parameters. Returns: Tuple of (results_list, params_dict) where params_dict contains: - concentrations: (n_samples, n_components) - baseline_coeffs: (n_samples, n_baseline) - path_lengths: (n_samples,) - wl_shifts: (n_samples,) - r_squared: (n_samples,) - temperature_deltas: (n_samples,) [if fit_environmental] - water_activities: (n_samples,) [if fit_environmental] - scattering_powers: (n_samples,) [if fit_environmental] - scattering_amplitudes: (n_samples,) [if fit_environmental] """ solver = VariableProjectionSolver(verbose=verbose, fit_environmental=fit_environmental) results = solver.fit_batch(X, forward_chain, schedule) n_samples = len(results) n_comp = forward_chain.canonical_model.n_components n_baseline = forward_chain.canonical_model.n_baseline params = { "concentrations": np.zeros((n_samples, n_comp)), "baseline_coeffs": np.zeros((n_samples, n_baseline)), "path_lengths": np.zeros(n_samples), "wl_shifts": np.zeros(n_samples), "r_squared": np.zeros(n_samples), } if fit_environmental: params["temperature_deltas"] = np.zeros(n_samples) params["water_activities"] = np.zeros(n_samples) params["scattering_powers"] = np.zeros(n_samples) params["scattering_amplitudes"] = np.zeros(n_samples) for i, result in enumerate(results): params["concentrations"][i] = result.concentrations params["baseline_coeffs"][i] = result.baseline_coeffs params["path_lengths"][i] = result.path_length params["wl_shifts"][i] = result.wl_shift_residual params["r_squared"][i] = result.r_squared if fit_environmental: params["temperature_deltas"][i] = result.temperature_delta params["water_activities"][i] = result.water_activity params["scattering_powers"][i] = result.scattering_power params["scattering_amplitudes"][i] = result.scattering_amplitude return results, params