"""
Complete reconstruction pipeline for end-to-end workflow.
Provides a unified interface for:
1. Dataset configuration and preprocessing detection
2. Global calibration
3. Batch inversion
4. Parameter distribution learning
5. Synthetic generation
6. Validation
"""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
# =============================================================================
# Dataset Configuration
# =============================================================================
[docs]
@dataclass
class DatasetConfig:
"""
Configuration for a dataset to be reconstructed.
Captures all dataset-specific information needed for reconstruction:
- Wavelength grid
- Signal type (absorbance, reflectance)
- Preprocessing applied
- Application domain (for component selection)
Attributes:
wavelengths: Wavelength grid in nm.
signal_type: Signal type ('absorbance', 'reflectance').
preprocessing: Detected or specified preprocessing type.
domain: Application domain for component selection.
sg_window: Savitzky-Golay window (for derivatives).
sg_polyorder: Savitzky-Golay polynomial order.
name: Optional dataset name.
"""
wavelengths: np.ndarray
signal_type: Literal["absorbance", "reflectance", "unknown"] = "absorbance"
preprocessing: Literal[
"none", "first_derivative", "second_derivative", "snv", "msc", "unknown"
] = "none"
domain: str = "unknown"
sg_window: int = 15
sg_polyorder: int = 2
name: str = "dataset"
[docs]
@classmethod
def from_data(
cls,
X: np.ndarray,
wavelengths: np.ndarray,
name: str = "dataset",
) -> "DatasetConfig":
"""
Create configuration by auto-detecting properties from data.
Args:
X: Spectra matrix (n_samples, n_wavelengths).
wavelengths: Wavelength grid.
name: Dataset name.
Returns:
DatasetConfig with detected properties.
"""
# Detect signal type
signal_type = cls._detect_signal_type(X, wavelengths)
# Detect preprocessing
preprocessing = cls._detect_preprocessing(X)
return cls(
wavelengths=wavelengths.copy(),
signal_type=signal_type,
preprocessing=preprocessing,
name=name,
)
@staticmethod
def _detect_signal_type(
X: np.ndarray, wavelengths: np.ndarray
) -> Literal["absorbance", "reflectance", "unknown"]:
"""Detect signal type from data characteristics."""
mean_val = X.mean()
max_val = X.max()
min_val = X.min()
# Derivative detection
if min_val < -0.5 or (min_val < 0 and abs(mean_val) < 0.1):
return "unknown" # Derivative data
# Reflectance: typically 0-1 or 0-100
if 0 <= min_val and max_val <= 1.1:
return "reflectance"
if 0 <= min_val and 20 < mean_val < 80:
return "reflectance" # Percent reflectance
# Absorbance: positive, typical range 0-3
if min_val >= 0 and 0.1 < mean_val < 3.0:
return "absorbance"
return "unknown"
@staticmethod
def _detect_preprocessing(
X: np.ndarray,
) -> Literal["none", "first_derivative", "second_derivative", "snv", "msc", "unknown"]:
"""Detect preprocessing type from data characteristics."""
mean_val = float(np.mean(X))
min_val = float(np.min(X))
max_val = float(np.max(X))
global_range = max_val - min_val
# Zero-crossing ratio
zero_crossings = np.sum(np.diff(np.sign(X), axis=1) != 0)
total_transitions = (X.shape[0] * (X.shape[1] - 1))
zero_crossing_ratio = zero_crossings / max(total_transitions, 1)
# Second derivative: very small range, zero mean, high oscillation
if global_range < 0.3 and abs(mean_val) < 0.05 and zero_crossing_ratio > 0.15:
return "second_derivative"
# First derivative: bipolar, near-zero mean
if min_val < -0.001 and max_val > 0.001 and abs(mean_val) < 0.1:
if global_range < 1.0 or zero_crossing_ratio > 0.05:
return "first_derivative"
# SNV: per-sample std ~1
sample_stds = X.std(axis=1)
if 0.8 < np.mean(sample_stds) < 1.2 and np.std(sample_stds) < 0.2:
sample_means = X.mean(axis=1)
if abs(np.mean(sample_means)) < 0.1:
return "snv"
# Raw data
if min_val >= 0 and mean_val > 0.1:
return "none"
return "unknown"
# =============================================================================
# Pipeline Result
# =============================================================================
[docs]
@dataclass
class PipelineResult:
"""
Result of reconstruction pipeline.
Contains all outputs from the reconstruction workflow:
- Calibration results
- Inversion results
- Learned distributions
- Generated synthetic data
- Validation metrics
Attributes:
config: Dataset configuration used.
calibration: Global calibration result.
inversion_results: Per-sample inversion results.
distribution: Learned parameter distributions.
X_synthetic: Generated synthetic spectra.
validation: Validation result.
forward_chain: Calibrated forward chain.
"""
config: DatasetConfig
calibration: Optional["CalibrationResult"] = None
inversion_results: Optional[List["InversionResult"]] = None
distribution: Optional["DistributionResult"] = None
X_synthetic: Optional[np.ndarray] = None
validation: Optional["ValidationResult"] = None
forward_chain: Optional["ForwardChain"] = None
[docs]
def summary(self) -> str:
"""Generate pipeline summary."""
lines = [
"=" * 70,
f"Reconstruction Pipeline Result: {self.config.name}",
"=" * 70,
"",
"Dataset Configuration:",
f" Signal type: {self.config.signal_type}",
f" Preprocessing: {self.config.preprocessing}",
f" Wavelengths: {len(self.config.wavelengths)} points",
f" Range: {self.config.wavelengths.min():.0f} - {self.config.wavelengths.max():.0f} nm",
]
if self.calibration:
lines.extend([
"",
"Global Calibration:",
f" Wavelength shift: {self.calibration.wl_shift:.2f} nm",
f" ILS sigma: {self.calibration.ils_sigma:.2f} nm",
f" Total loss: {self.calibration.total_loss:.4f}",
])
if self.inversion_results:
r2_values = [r.r_squared for r in self.inversion_results]
lines.extend([
"",
"Inversion Results:",
f" Samples fitted: {len(self.inversion_results)}",
f" Mean R²: {np.mean(r2_values):.4f}",
f" Min R²: {np.min(r2_values):.4f}",
])
if self.validation:
lines.extend([
"",
"Validation:",
f" Overall score: {self.validation.overall_score:.1f}/100",
f" Status: {'PASSED' if self.validation.passed else 'NEEDS REVIEW'}",
])
if self.validation.warnings:
for w in self.validation.warnings:
lines.append(f" Warning: {w}")
lines.append("=" * 70)
return "\n".join(lines)
# =============================================================================
# Reconstruction Pipeline
# =============================================================================
[docs]
@dataclass
class ReconstructionPipeline:
"""
Complete reconstruction pipeline.
Orchestrates the full workflow:
1. Configuration and component selection
2. Prototype selection and global calibration
3. Per-sample inversion (optionally with environmental parameters)
4. Parameter distribution learning
5. Synthetic generation
6. Validation
Attributes:
config: Dataset configuration.
component_names: Components to use (auto-selected if None).
canonical_resolution: Resolution of canonical grid (nm).
baseline_order: Baseline polynomial order.
n_prototypes: Number of prototypes for calibration.
fit_environmental: Whether to fit environmental parameters.
verbose: Print progress.
"""
config: DatasetConfig
component_names: Optional[List[str]] = None
canonical_resolution: float = 0.5
baseline_order: int = 5
continuum_order: int = 3
n_prototypes: int = 5
fit_environmental: bool = False
verbose: bool = True
[docs]
def __post_init__(self):
"""Initialize components if not provided."""
if self.component_names is None:
self.component_names = self._select_components_for_domain()
def _select_components_for_domain(self) -> List[str]:
"""Select appropriate components based on domain."""
# Default components that appear in most NIR datasets
default_components = [
"water", "protein", "lipid", "starch", "cellulose",
]
# Domain-specific additions
domain_components = {
"food_dairy": ["casein", "lactose", "whey", "lipid"],
"food_bakery": ["starch", "gluten", "lipid", "glucose"],
"agriculture_grain": ["starch", "protein", "cellulose", "moisture"],
"agriculture_fruit": ["fructose", "glucose", "cellulose", "water"],
"environmental_soil": ["humic_acid", "cellulose", "clay_minerals"],
"pharma_tablets": ["lactose", "cellulose", "starch"],
"petrochem_fuels": ["paraffin", "aromatic_hydrocarbons"],
"beverage_wine": ["ethanol", "glucose", "water"],
}
domain = self.config.domain
if domain in domain_components:
components = domain_components[domain]
else:
components = default_components
# Filter to available components
from ..components import available_components
available = set(available_components())
return [c for c in components if c in available][:10]
[docs]
def fit(
self,
X: np.ndarray,
max_samples: Optional[int] = None,
) -> PipelineResult:
"""
Run full reconstruction pipeline.
Args:
X: Spectra matrix (n_samples, n_wavelengths).
max_samples: Max samples to invert (for speed).
Returns:
PipelineResult with all outputs.
"""
from .forward import ForwardChain
from .calibration import PrototypeSelector, GlobalCalibrator, multistage_calibration
from .inversion import VariableProjectionSolver, MultiscaleSchedule
from .distributions import ParameterDistributionFitter, ParameterSampler
from .generator import ReconstructionGenerator, estimate_noise_from_residuals
from .validation import ReconstructionValidator
n_samples = X.shape[0]
if self.verbose:
print(f"Starting reconstruction pipeline for {self.config.name}")
print(f" Samples: {n_samples}, Wavelengths: {X.shape[1]}")
print(f" Signal type: {self.config.signal_type}")
print(f" Preprocessing: {self.config.preprocessing}")
print(f" Components: {self.component_names}")
print(f" Environmental fitting: {self.fit_environmental}")
# 1. Create canonical grid
wl_min = self.config.wavelengths.min() - 50
wl_max = self.config.wavelengths.max() + 50
canonical_grid = np.arange(wl_min, wl_max, self.canonical_resolution)
# 2. Create forward chain
forward_chain = ForwardChain.create(
canonical_grid=canonical_grid,
target_grid=self.config.wavelengths,
component_names=self.component_names,
domain=self.config.signal_type if self.config.signal_type != "unknown" else "absorbance",
preprocessing_type=self.config.preprocessing if self.config.preprocessing != "unknown" else "none",
baseline_order=self.baseline_order,
continuum_order=self.continuum_order,
sg_window=self.config.sg_window,
sg_polyorder=self.config.sg_polyorder,
include_environmental=self.fit_environmental,
)
if self.verbose:
print("\n1. Global Calibration...")
# 3. Global calibration
calibration = multistage_calibration(
X, forward_chain, n_prototypes=self.n_prototypes
)
if self.verbose:
print(f" Wavelength shift: {calibration.wl_shift:.2f} nm")
print(f" ILS sigma: {calibration.ils_sigma:.2f} nm")
print(f" Prototype R²: {np.mean(calibration.prototype_r2):.4f}")
# 4. Per-sample inversion
if self.verbose:
print("\n2. Per-sample Inversion...")
if max_samples is not None and n_samples > max_samples:
# Subsample for speed
idx = np.random.choice(n_samples, max_samples, replace=False)
X_invert = X[idx]
else:
X_invert = X
solver = VariableProjectionSolver(
verbose=False,
fit_environmental=self.fit_environmental,
)
schedule = MultiscaleSchedule.quick() if n_samples > 100 else MultiscaleSchedule()
inversion_results = solver.fit_batch(X_invert, forward_chain, schedule)
r2_values = [r.r_squared for r in inversion_results]
if self.verbose:
print(f" Fitted {len(inversion_results)} samples")
print(f" Mean R²: {np.mean(r2_values):.4f}")
print(f" Min R²: {np.min(r2_values):.4f}")
# 5. Learn parameter distributions
if self.verbose:
print("\n3. Learning Parameter Distributions...")
params = {
"concentrations": np.array([r.concentrations for r in inversion_results]),
"baseline_coeffs": np.array([r.baseline_coeffs for r in inversion_results]),
"path_lengths": np.array([r.path_length for r in inversion_results]),
"wl_shifts": np.array([r.wl_shift_residual for r in inversion_results]),
}
# Add environmental parameters if fitted
if self.fit_environmental:
params["temperature_deltas"] = np.array(
[r.temperature_delta for r in inversion_results]
)
params["water_activities"] = np.array(
[r.water_activity for r in inversion_results]
)
params["scattering_powers"] = np.array(
[r.scattering_power for r in inversion_results]
)
params["scattering_amplitudes"] = np.array(
[r.scattering_amplitude for r in inversion_results]
)
# Configure distribution fitter for environmental params
bounded_params = {"wl_shifts": (-5.0, 5.0)}
positive_params = ["concentrations", "path_lengths"]
if self.fit_environmental:
bounded_params["water_activities"] = (0.0, 1.0)
bounded_params["scattering_powers"] = (0.5, 3.0)
positive_params.append("scattering_amplitudes")
dist_fitter = ParameterDistributionFitter(
positive_params=positive_params,
bounded_params=bounded_params,
)
distribution = dist_fitter.fit(params)
if self.verbose:
print(f" Fitted distributions for {len(distribution.param_names)} parameters")
# 6. Generate synthetic data
if self.verbose:
print("\n4. Generating Synthetic Data...")
sampler = ParameterSampler(distribution, use_correlations=True)
# Estimate noise from residuals
noise_add, noise_mult = estimate_noise_from_residuals(inversion_results)
generator = ReconstructionGenerator(
noise_level=noise_add,
multiplicative_noise=noise_mult,
add_noise=True,
)
gen_result = generator.generate(
n_samples=len(X_invert),
forward_chain=forward_chain,
sampler=sampler,
random_state=42,
)
X_synthetic = gen_result.X
if self.verbose:
print(f" Generated {len(X_synthetic)} synthetic samples")
# 7. Validation
if self.verbose:
print("\n5. Validation...")
validator = ReconstructionValidator()
validation = validator.validate(inversion_results, X_invert, X_synthetic)
if self.verbose:
print(f" Overall score: {validation.overall_score:.1f}/100")
print(f" Status: {'PASSED' if validation.passed else 'NEEDS REVIEW'}")
return PipelineResult(
config=self.config,
calibration=calibration,
inversion_results=inversion_results,
distribution=distribution,
X_synthetic=X_synthetic,
validation=validation,
forward_chain=forward_chain,
)
[docs]
def generate(
self,
n_samples: int,
result: PipelineResult,
random_state: Optional[int] = None,
) -> np.ndarray:
"""
Generate additional synthetic samples using fitted pipeline.
Args:
n_samples: Number of samples to generate.
result: PipelineResult from fit().
random_state: Random seed.
Returns:
Synthetic spectra matrix.
"""
from .distributions import ParameterSampler
from .generator import ReconstructionGenerator, estimate_noise_from_residuals
if result.distribution is None or result.forward_chain is None:
raise ValueError("Pipeline not fitted. Call fit() first.")
sampler = ParameterSampler(result.distribution, use_correlations=True)
noise_add, noise_mult = estimate_noise_from_residuals(result.inversion_results)
generator = ReconstructionGenerator(
noise_level=noise_add,
multiplicative_noise=noise_mult,
)
gen_result = generator.generate(
n_samples=n_samples,
forward_chain=result.forward_chain,
sampler=sampler,
random_state=random_state,
)
return gen_result.X
# =============================================================================
# Convenience Functions
# =============================================================================
[docs]
def reconstruct_and_generate(
X: np.ndarray,
wavelengths: np.ndarray,
n_synthetic: Optional[int] = None,
domain: str = "unknown",
component_names: Optional[List[str]] = None,
fit_environmental: bool = False,
verbose: bool = True,
) -> Tuple[np.ndarray, PipelineResult]:
"""
Convenience function for end-to-end reconstruction and generation.
Args:
X: Real spectra matrix.
wavelengths: Wavelength grid.
n_synthetic: Number of synthetic samples (default: same as X).
domain: Application domain.
component_names: Components to use.
fit_environmental: Whether to fit environmental parameters
(temperature, water activity, scattering).
verbose: Print progress.
Returns:
Tuple of (X_synthetic, PipelineResult).
"""
# Create configuration
config = DatasetConfig.from_data(X, wavelengths)
config.domain = domain
# Create and run pipeline
pipeline = ReconstructionPipeline(
config=config,
component_names=component_names,
fit_environmental=fit_environmental,
verbose=verbose,
)
result = pipeline.fit(X)
# Generate additional samples if requested
if n_synthetic is not None and n_synthetic != len(X):
X_synth = pipeline.generate(n_synthetic, result)
else:
X_synth = result.X_synthetic
return X_synth, result