"""
Prototype-based global calibration for instrument parameters.
This module implements global parameter estimation using representative
prototype spectra (median + quantiles + k-medoids).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from scipy.optimize import minimize, differential_evolution
# =============================================================================
# Calibration Result
# =============================================================================
@dataclass
class CalibrationResult:
    """
    Result of global calibration.

    Attributes:
        wl_shift: Calibrated wavelength shift.
        wl_stretch: Calibrated wavelength stretch.
        ils_sigma: Calibrated ILS width.
        stray_light: Calibrated stray light fraction.
        gain: Calibrated photometric gain.
        offset: Calibrated photometric offset.
        prototype_residuals: Residuals for each prototype.
        prototype_r2: R² for each prototype.
        total_loss: Total calibration loss.
    """

    wl_shift: float = 0.0
    wl_stretch: float = 1.0
    ils_sigma: float = 4.0
    stray_light: float = 0.0
    gain: float = 1.0
    offset: float = 0.0
    prototype_residuals: Optional[np.ndarray] = None
    prototype_r2: Optional[np.ndarray] = None
    total_loss: float = float("inf")

    def to_dict(self) -> Dict[str, float]:
        """Return the six scalar parameters keyed by field name.

        The per-prototype diagnostics and ``total_loss`` are not included.
        """
        return {
            "wl_shift": self.wl_shift,
            "wl_stretch": self.wl_stretch,
            "ils_sigma": self.ils_sigma,
            "stray_light": self.stray_light,
            "gain": self.gain,
            "offset": self.offset,
        }

    @classmethod
    def from_array(cls, params: np.ndarray) -> "CalibrationResult":
        """Create a result from ``[wl_shift, wl_stretch, ils_sigma]``.

        Args:
            params: Array-like holding up to three values in the order
                above. Missing trailing entries (including an entirely
                empty array) keep the dataclass defaults; entries beyond
                the third are ignored.

        Returns:
            A new ``CalibrationResult`` whose supplied parameters are plain
            Python floats.
        """
        values = np.asarray(params, dtype=float).ravel()
        names = ("wl_shift", "wl_stretch", "ils_sigma")
        # zip() stops at the shorter sequence, so only the entries actually
        # present override the field defaults declared on the class.
        supplied = {name: float(value) for name, value in zip(names, values)}
        return cls(**supplied)
# =============================================================================
# Prototype Selector
# =============================================================================
@dataclass
class PrototypeSelector:
    """
    Select representative prototype spectra from a dataset.

    Uses multiple strategies to ensure robust global calibration:
        1. Median spectrum (the real sample closest to the element-wise
           median).
        2. Quantile spectra (samples closest to the 25% and 75% percentiles
           of the first principal component).
        3. Greedy farthest-point ("k-medoids-like") picks in PCA space to
           capture diversity.

    Attributes:
        n_prototypes: Number of prototypes to select.
        include_median: Always include median spectrum.
        include_quantiles: Include quantile spectra.
        pca_components: Number of PCA components for clustering.
    """

    n_prototypes: int = 5
    include_median: bool = True
    include_quantiles: bool = True
    pca_components: int = 5

    def select(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Select prototype spectra.

        Args:
            X: Spectra matrix (n_samples, n_wavelengths).

        Returns:
            Tuple of (prototype_spectra, prototype_indices). At least one
            and at most ``max(1, min(n_prototypes, n_samples))`` rows are
            returned, every row being an actual sample of ``X``.

        Raises:
            ValueError: If ``X`` is not two-dimensional or has no samples.
        """
        X = np.asarray(X)
        if X.ndim != 2 or X.shape[0] == 0:
            raise ValueError(
                "X must be a non-empty 2-D array of shape (n_samples, n_wavelengths)"
            )

        n_samples = X.shape[0]
        # There cannot be more distinct prototypes than samples.
        target = min(self.n_prototypes, n_samples)
        indices: List[int] = []

        # 1. Real sample closest to the (synthetic) element-wise median.
        if self.include_median:
            median_spectrum = np.median(X, axis=0)
            sq_dist = np.sum((X - median_spectrum) ** 2, axis=1)
            indices.append(int(np.argmin(sq_dist)))

        # 2. PCA-based picks. PCA needs at least two samples (the component
        #    count is capped at n_samples - 1), so a single-sample dataset
        #    skips this step instead of requesting zero components.
        if len(indices) < target and n_samples >= 2:
            from sklearn.decomposition import PCA

            n_comp = min(self.pca_components, n_samples - 1, X.shape[1])
            scores = PCA(n_components=n_comp).fit_transform(X)

            # 2a. Samples nearest the 25% / 75% percentiles along PC1.
            if self.include_quantiles:
                pc1 = scores[:, 0]
                for quantile in (25, 75):
                    idx = int(np.argmin(np.abs(pc1 - np.percentile(pc1, quantile))))
                    if idx not in indices and len(indices) < target:
                        indices.append(idx)

            # 2b. Fill the remaining slots with mutually distant samples.
            remaining = target - len(indices)
            if remaining > 0:
                diverse = self._kmedoids_selection(
                    scores, n_select=remaining, exclude=indices
                )
                for idx in diverse:
                    if idx not in indices:
                        indices.append(idx)

        # Ensure we have at least one prototype.
        if not indices:
            indices = [0]

        index_array = np.asarray(indices, dtype=int)
        return X[index_array], index_array

    def _kmedoids_selection(
        self,
        scores: np.ndarray,
        n_select: int,
        exclude: List[int],
    ) -> List[int]:
        """Select diverse samples using a greedy farthest-point approach.

        Each pick maximizes the minimum squared distance to every sample
        chosen so far, where "chosen" covers both ``exclude`` (prototypes
        picked by earlier strategies) and the picks made here. With an empty
        ``exclude`` the first pick is the sample closest to the centroid.

        Args:
            scores: PCA scores, shape (n_samples, n_components).
            n_select: Maximum number of indices to return.
            exclude: Indices that are already selected and must not be
                returned.

        Returns:
            Up to ``n_select`` distinct indices into ``scores``, none of
            them in ``exclude``.
        """
        scores = np.asarray(scores, dtype=float)
        excluded = {int(i) for i in exclude}
        available = np.array(
            [i for i in range(scores.shape[0]) if i not in excluded], dtype=int
        )
        if n_select <= 0 or available.size == 0:
            return []

        candidates = scores[available]
        if excluded:
            # Running minimum distance starts from the already chosen
            # prototypes so later picks stay away from them too.
            anchors = scores[sorted(excluded)]
            min_dist = np.min(
                np.sum((candidates[:, np.newaxis, :] - anchors) ** 2, axis=2),
                axis=1,
            )
            pos = int(np.argmax(min_dist))
        else:
            centroid = scores.mean(axis=0)
            pos = int(np.argmin(np.sum((candidates - centroid) ** 2, axis=1)))
            min_dist = np.full(available.size, np.inf)

        n_target = min(n_select, available.size)
        selected: List[int] = []
        while True:
            selected.append(int(available[pos]))
            if len(selected) >= n_target:
                break
            # Fold the new pick into the running minimum instead of
            # rebuilding the full distance matrix on every iteration.
            to_new = np.sum((candidates - candidates[pos]) ** 2, axis=1)
            min_dist = np.minimum(min_dist, to_new)
            min_dist[pos] = -np.inf  # a picked sample is never picked again
            pos = int(np.argmax(min_dist))
        return selected
# =============================================================================
# Global Calibrator
# =============================================================================
@dataclass
class GlobalCalibrator:
    """
    Calibrate global instrument parameters using prototype spectra.

    Optimizes θ_global = {wl_shift, wl_stretch, ils_sigma} to minimize the
    total squared fitting residual across all prototypes. For every trial
    θ_global the per-prototype linear parameters are solved with bounded
    linear least squares (``scipy.optimize.lsq_linear``): the first
    ``n_components`` coefficients are constrained to be non-negative, the
    remaining ones are free.

    The forward chain is passed to :meth:`calibrate` / :meth:`refine`; it is
    not a dataclass field.

    Attributes:
        wl_shift_bounds: Bounds for wavelength shift.
        wl_stretch_bounds: Bounds for wavelength stretch.
        ils_sigma_bounds: Bounds for ILS sigma.
        regularization: Strength of the quadratic penalty
            ``wl_shift**2 + 100*(wl_stretch-1)**2 + 0.1*(ils_sigma-6)**2``.
        use_global_search: Use differential evolution for global search.
    """

    wl_shift_bounds: Tuple[float, float] = (-10.0, 10.0)
    wl_stretch_bounds: Tuple[float, float] = (0.98, 1.02)
    ils_sigma_bounds: Tuple[float, float] = (2.0, 20.0)
    regularization: float = 1e-6
    use_global_search: bool = False

    def calibrate(
        self,
        prototypes: np.ndarray,
        forward_chain: "ForwardChain",
        initial_guess: Optional[np.ndarray] = None,
    ) -> "CalibrationResult":
        """
        Calibrate global parameters on prototype spectra.

        Side effect: ``forward_chain.instrument_model`` is left holding the
        calibrated ``wl_shift``, ``wl_stretch`` and ``ils_sigma``.

        Args:
            prototypes: Prototype spectra (n_prototypes, n_wavelengths).
            forward_chain: Forward chain for model evaluation. It must expose
                ``instrument_model`` (with writable ``wl_shift``,
                ``wl_stretch``, ``ils_sigma``), ``canonical_model.n_components``
                and ``forward_design_matrix(path_length=...)``.
            initial_guess: Initial [wl_shift, wl_stretch, ils_sigma]. Defaults
                to ``[0.0, 1.0, 6.0]``. Only used by the local optimizer; the
                differential-evolution search ignores it.

        Returns:
            CalibrationResult with optimized parameters.
        """
        bounds = [
            self.wl_shift_bounds,
            self.wl_stretch_bounds,
            self.ils_sigma_bounds,
        ]
        return self._optimize(prototypes, forward_chain, initial_guess, bounds)

    def refine(
        self,
        current_result: "CalibrationResult",
        prototypes: np.ndarray,
        forward_chain: "ForwardChain",
    ) -> "CalibrationResult":
        """
        Refine calibration with tighter bounds around current estimate.

        The search window is ±1.0 for ``wl_shift``, ±0.005 for ``wl_stretch``
        and ±2 for ``ils_sigma``, intersected with the bounds configured on
        this calibrator. The configured bounds themselves are not modified,
        so the calibrator can be reused afterwards.

        Args:
            current_result: Current calibration result.
            prototypes: Prototype spectra.
            forward_chain: Forward chain.

        Returns:
            Refined CalibrationResult.
        """
        margin = 0.5
        windows = (
            (current_result.wl_shift, margin * 2, self.wl_shift_bounds),
            (current_result.wl_stretch, 0.005, self.wl_stretch_bounds),
            (current_result.ils_sigma, 2, self.ils_sigma_bounds),
        )

        start = []
        tightened = []
        for estimate, half_width, (lower, upper) in windows:
            # Pull the estimate inside the configured range first so the
            # tightened interval can never end up with lower > upper.
            center = min(max(float(estimate), lower), upper)
            start.append(center)
            tightened.append(
                (max(lower, center - half_width), min(upper, center + half_width))
            )

        return self._optimize(
            prototypes, forward_chain, np.array(start), tightened
        )

    @staticmethod
    def _set_instrument_params(forward_chain, wl_shift, wl_stretch, ils_sigma):
        """Write the three global parameters onto the instrument model."""
        forward_chain.instrument_model.wl_shift = wl_shift
        forward_chain.instrument_model.wl_stretch = wl_stretch
        forward_chain.instrument_model.ils_sigma = ils_sigma

    @staticmethod
    def _fit_linear(A, spectrum, n_comp):
        """Bounded least-squares fit of one spectrum; returns (fitted, residuals)."""
        from scipy.optimize import lsq_linear

        n_linear = A.shape[1]
        # First n_comp coefficients (concentrations) >= 0, the rest
        # (baseline terms) unconstrained.
        lower = np.concatenate([
            np.zeros(n_comp),
            -np.inf * np.ones(n_linear - n_comp),
        ])
        upper = np.inf * np.ones(n_linear)
        fit = lsq_linear(A, spectrum, bounds=(lower, upper))
        fitted = A @ fit.x
        return fitted, spectrum - fitted

    def _optimize(self, prototypes, forward_chain, initial_guess, bounds):
        """Run the optimizer within ``bounds`` and build the CalibrationResult."""
        n_prototypes = prototypes.shape[0]

        # Store forward chain reference
        self._forward_chain = forward_chain

        def objective(params: np.ndarray) -> float:
            """Total loss across all prototypes plus the regularization term."""
            wl_shift, wl_stretch, ils_sigma = params
            self._set_instrument_params(
                forward_chain, wl_shift, wl_stretch, ils_sigma
            )

            # Deliberately broad handlers: any failure at a trial point is
            # turned into a large finite penalty (1e6 per prototype) so the
            # optimizer moves on instead of aborting the calibration.
            try:
                # The design matrix depends only on the instrument
                # parameters, so it is built once per evaluation rather than
                # once per prototype.
                A = forward_chain.forward_design_matrix(path_length=1.0)
                n_comp = forward_chain.canonical_model.n_components
            except Exception:
                total_loss = 1e6 * n_prototypes
            else:
                total_loss = 0.0
                for spectrum in prototypes:
                    try:
                        _, residuals = self._fit_linear(A, spectrum, n_comp)
                        # Equal weight for every prototype.
                        total_loss += np.sum(residuals ** 2)
                    except Exception:
                        total_loss += 1e6

            # Regularization on parameter deviation from defaults
            reg_loss = self.regularization * (
                wl_shift ** 2 +
                100 * (wl_stretch - 1.0) ** 2 +
                0.1 * (ils_sigma - 6.0) ** 2
            )
            return total_loss + reg_loss

        if initial_guess is None:
            initial_guess = np.array([0.0, 1.0, 6.0])

        if self.use_global_search:
            result = differential_evolution(
                objective,
                bounds=bounds,
                seed=42,
                maxiter=50,
                polish=True,
            )
        else:
            result = minimize(
                objective,
                initial_guess,
                method="L-BFGS-B",
                bounds=bounds,
                options={"maxiter": 200},
            )

        # Plain floats keep the result independent of NumPy scalar types.
        wl_shift, wl_stretch, ils_sigma = (float(value) for value in result.x)

        # Leave the forward chain at the calibrated parameters.
        self._set_instrument_params(forward_chain, wl_shift, wl_stretch, ils_sigma)

        # Per-prototype diagnostics at the optimum. Unlike the objective,
        # failures here are not caught and propagate to the caller.
        prototype_r2 = np.zeros(n_prototypes)
        prototype_residuals = []
        if n_prototypes:
            A = forward_chain.forward_design_matrix(path_length=1.0)
            n_comp = forward_chain.canonical_model.n_components
        for i in range(n_prototypes):
            spectrum = prototypes[i]
            _, residuals = self._fit_linear(A, spectrum, n_comp)
            ss_res = np.sum(residuals ** 2)
            ss_tot = np.sum((spectrum - spectrum.mean()) ** 2)
            # A constant spectrum has zero total variance; report R² = 0.
            prototype_r2[i] = 1 - ss_res / ss_tot if ss_tot > 0 else 0
            prototype_residuals.append(residuals)

        return CalibrationResult(
            wl_shift=wl_shift,
            wl_stretch=wl_stretch,
            ils_sigma=ils_sigma,
            prototype_residuals=np.array(prototype_residuals),
            prototype_r2=prototype_r2,
            total_loss=float(result.fun),
        )
# =============================================================================
# Multi-stage Calibration
# =============================================================================
def multistage_calibration(
    X: np.ndarray,
    forward_chain: "ForwardChain",
    n_prototypes: int = 5,
    stages: int = 2,
) -> "CalibrationResult":
    """
    Multi-stage calibration with progressive refinement.

    Prototypes are selected once from ``X``. They are then fitted along the
    smoothing schedule ``[10, 5, 0][:stages + 1]`` (Gaussian sigma along the
    wavelength axis; 0 means unsmoothed): the first entry is a full
    calibration, every later entry a refinement around the previous result.
    Consequently ``stages=2`` (the default) ends on the original prototypes,
    ``stages < 2`` stops on smoothed prototypes, and values above 2 behave
    like 2.

    Args:
        X: Full dataset (n_samples, n_wavelengths).
        forward_chain: Forward chain for model evaluation.
        n_prototypes: Number of prototypes to select.
        stages: Number of refinement stages after the initial calibration.

    Returns:
        Final CalibrationResult.

    Raises:
        ValueError: If ``stages`` is negative (the schedule would be empty
            and there would be no result to return).
    """
    from scipy.ndimage import gaussian_filter1d

    if stages < 0:
        raise ValueError(f"stages must be >= 0, got {stages}")

    # Select prototypes
    selector = PrototypeSelector(n_prototypes=n_prototypes)
    prototypes, _ = selector.select(X)
    # gaussian_filter1d returns an array of the input dtype by default, so
    # integer spectra would be truncated by the smoothing; work in float.
    prototypes = np.asarray(prototypes, dtype=float)

    calibrator = GlobalCalibrator()
    smoothing_schedule = [10, 5, 0][:stages + 1]

    result = None
    for sigma in smoothing_schedule:
        if sigma > 0:
            stage_prototypes = gaussian_filter1d(prototypes, sigma=sigma, axis=1)
        else:
            stage_prototypes = prototypes

        if result is None:
            # First pass: search within the calibrator's full bounds.
            result = calibrator.calibrate(stage_prototypes, forward_chain)
        else:
            # Later passes: search near the previous estimate.
            result = calibrator.refine(result, stage_prototypes, forward_chain)
    return result