Source code for nirs4all.data.synthetic.fitter

"""
Real data fitting utilities for synthetic NIRS spectra generation.

This module provides tools to analyze real NIRS datasets and fit generator
parameters to match their statistical and spectral properties.

Key Features:
    - Statistical property analysis (mean, std, skewness, kurtosis)
    - Spectral shape analysis (slope, curvature, noise)
    - PCA structure analysis
    - Parameter estimation for SyntheticNIRSGenerator
    - Comparison between synthetic and real data
    - **Phase 1-4 Enhanced Features:**
        - Instrument archetype inference (InGaAs, PbS, MEMS, etc.)
        - Measurement mode detection (transmittance, reflectance, ATR)
        - Application domain suggestion (agriculture, pharmaceutical, etc.)
        - Environmental effects estimation (temperature, moisture)
        - Scattering parameter estimation (particle size, EMSC)
        - Wavenumber-based peak analysis for component identification

Example:
    >>> from nirs4all.data.synthetic import RealDataFitter, SyntheticNIRSGenerator
    >>>
    >>> # Analyze real data
    >>> fitter = RealDataFitter()
    >>> params = fitter.fit(X_real, wavelengths=wavelengths)
    >>>
    >>> # Create generator with fitted parameters (includes all Phase 1-4 features)
    >>> generator = fitter.create_matched_generator()
    >>> X_synthetic, _, _ = generator.generate(n_samples=1000)
    >>>
    >>> # Or get all inferred characteristics
    >>> print(f"Inferred instrument: {params.inferred_instrument}")
    >>> print(f"Inferred domain: {params.inferred_domain}")
    >>> print(f"Measurement mode: {params.measurement_mode}")

References:
    - Based on comparator.py from bench/synthetic/
    - Enhanced with Phase 1-4 synthetic generator features
"""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
from scipy import stats
from scipy.signal import savgol_filter, find_peaks
from scipy.ndimage import gaussian_filter1d

if TYPE_CHECKING:
    from nirs4all.data.dataset import SpectroDataset
    from .generator import SyntheticNIRSGenerator


# ============================================================================
# Inference Result Classes
# ============================================================================


class MeasurementModeInference(str, Enum):
    """Measurement geometry inferred from spectral analysis.

    Members are also plain ``str`` instances, so they compare equal to
    their lowercase string value and can be looked up from that value.
    """

    TRANSMITTANCE = "transmittance"
    REFLECTANCE = "reflectance"
    TRANSFLECTANCE = "transflectance"
    ATR = "atr"
    UNKNOWN = "unknown"
@dataclass
class InstrumentInference:
    """Outcome of matching a dataset against instrument archetypes.

    Attributes:
        archetype_name: Name of the best matching instrument archetype.
        detector_type: Inferred detector type.
        wavelength_range: Detected (start, end) wavelength range.
        estimated_resolution: Estimated spectral resolution (nm).
        confidence: Confidence score (0-1).
        alternative_archetypes: Other candidate archetypes mapped to
            their scores.
    """

    archetype_name: str = "unknown"
    detector_type: str = "unknown"
    wavelength_range: Tuple[float, float] = (1000.0, 2500.0)
    estimated_resolution: float = 8.0
    confidence: float = 0.0
    # default_factory gives every instance its own dict.
    alternative_archetypes: Dict[str, float] = field(default_factory=dict)
@dataclass
class DomainInference:
    """Outcome of application-domain inference.

    Attributes:
        domain_name: Name of the best matching domain.
        category: Category the domain belongs to.
        confidence: Confidence score (0-1).
        detected_components: Components detected from peak analysis.
        alternative_domains: Other candidate domains mapped to their
            scores.
    """

    domain_name: str = "unknown"
    category: str = "unknown"
    confidence: float = 0.0
    # default_factory gives every instance its own containers.
    detected_components: List[str] = field(default_factory=list)
    alternative_domains: Dict[str, float] = field(default_factory=dict)
@dataclass
class EnvironmentalInference:
    """Outcome of environmental-effects inference.

    Attributes:
        estimated_temperature_variation: Estimated temperature
            variation (°C).
        has_temperature_effects: Whether temperature effects are
            detectable.
        estimated_moisture_variation: Estimated moisture variation.
        has_moisture_effects: Whether moisture effects are detectable.
        water_band_shift: Detected shift in water bands (nm).
    """

    estimated_temperature_variation: float = 0.0
    has_temperature_effects: bool = False
    estimated_moisture_variation: float = 0.0
    has_moisture_effects: bool = False
    water_band_shift: float = 0.0
@dataclass
class ScatteringInference:
    """Outcome of scattering-effects inference.

    Attributes:
        has_scatter_effects: Whether significant scatter is detected.
        estimated_particle_size_um: Estimated mean particle size (μm).
        multiplicative_scatter_std: Estimated MSC-style multiplicative
            scatter.
        additive_scatter_std: Estimated SNV-style additive scatter.
        baseline_curvature: Detected baseline curvature intensity.
        snv_correctable: Whether SNV would improve spectra.
        msc_correctable: Whether MSC would improve spectra.
    """

    has_scatter_effects: bool = False
    estimated_particle_size_um: float = 50.0
    multiplicative_scatter_std: float = 0.0
    additive_scatter_std: float = 0.0
    baseline_curvature: float = 0.0
    snv_correctable: bool = False
    msc_correctable: bool = False
@dataclass
class SpectralProperties:
    """Statistical and spectral descriptors computed for one dataset.

    Every field has a neutral default, so an instance can be created
    empty and filled in step by step while a dataset is analysed.

    Attributes:
        name: Dataset identifier.
        n_samples: Number of samples.
        n_wavelengths: Number of wavelengths.
        wavelengths: Wavelength grid.

        mean_spectrum: Mean spectrum across samples.
        std_spectrum: Standard deviation spectrum.
        global_mean: Overall mean absorbance.
        global_std: Overall standard deviation.
        global_range: (min, max) absorbance range.

        mean_slope: Average spectral slope (per 1000nm).
        slope_std: Standard deviation of slopes.
        slopes: Per-sample slope values.
        mean_curvature: Average curvature (second derivative).
        curvature_std: Standard deviation of the curvature values.

        skewness: Skewness of absorbance distribution.
        kurtosis: Kurtosis of absorbance distribution.

        noise_estimate: Estimated noise level.
        snr_estimate: Signal-to-noise ratio estimate.

        pca_explained_variance: Explained variance ratios.
        pca_n_components_95: Components for 95% variance.

        n_peaks_mean: Mean number of peaks.
        peak_positions: Wavelengths of detected peaks.
        peak_wavenumbers: Wavenumber positions of peaks.

        effective_resolution: Estimated spectral resolution from peak
            widths.
        noise_correlation_length: Correlation length of noise
            (detector indicator).
        wavelength_range: Actual wavelength range of data.

        baseline_offset: Mean baseline offset (transmittance indicator).
        kubelka_munk_linearity: K-M linearity score (reflectance
            indicator).
        baseline_convexity: Convexity of baseline (ATR indicator).

        water_band_variation: Variation in water band region.
        oh_band_positions: Detected O-H band positions.
        temperature_sensitivity_score: Score for temperature effect
            detection.

        scatter_baseline_slope: Wavelength-dependent scatter slope.
        scatter_baseline_curvature: Curvature from scattering.
        sample_to_sample_offset_std: Sample-to-sample offset variation.
        sample_to_sample_slope_std: Sample-to-sample slope variation.

        protein_band_intensity: Intensity in protein band regions.
        carbohydrate_band_intensity: Intensity in carbohydrate regions.
        lipid_band_intensity: Intensity in lipid band regions.
        water_band_intensity: Intensity in water band regions.
    """

    # --- Identification -------------------------------------------------
    name: str = "dataset"
    n_samples: int = 0
    n_wavelengths: int = 0
    wavelengths: Optional[np.ndarray] = None

    # --- Basic statistics -----------------------------------------------
    mean_spectrum: Optional[np.ndarray] = None
    std_spectrum: Optional[np.ndarray] = None
    global_mean: float = 0.0
    global_std: float = 0.0
    global_range: Tuple[float, float] = (0.0, 0.0)

    # --- Shape properties -----------------------------------------------
    mean_slope: float = 0.0
    slope_std: float = 0.0
    slopes: Optional[np.ndarray] = None
    mean_curvature: float = 0.0
    curvature_std: float = 0.0

    # --- Distribution statistics ----------------------------------------
    skewness: float = 0.0
    kurtosis: float = 0.0

    # --- Noise characteristics ------------------------------------------
    noise_estimate: float = 0.0
    snr_estimate: float = 0.0

    # --- PCA properties -------------------------------------------------
    pca_explained_variance: Optional[np.ndarray] = None
    pca_n_components_95: int = 0

    # --- Peak analysis --------------------------------------------------
    n_peaks_mean: float = 0.0
    peak_positions: Optional[np.ndarray] = None
    peak_wavenumbers: Optional[np.ndarray] = None

    # --- Phase 1-4: instrument indicators -------------------------------
    effective_resolution: float = 8.0
    noise_correlation_length: float = 1.0
    wavelength_range: Tuple[float, float] = (1000.0, 2500.0)

    # --- Phase 1-4: measurement mode indicators -------------------------
    baseline_offset: float = 0.0
    kubelka_munk_linearity: float = 0.0
    baseline_convexity: float = 0.0

    # --- Phase 1-4: environmental indicators ----------------------------
    water_band_variation: float = 0.0
    oh_band_positions: Optional[np.ndarray] = None
    temperature_sensitivity_score: float = 0.0

    # --- Phase 1-4: scattering indicators -------------------------------
    scatter_baseline_slope: float = 0.0
    scatter_baseline_curvature: float = 0.0
    sample_to_sample_offset_std: float = 0.0
    sample_to_sample_slope_std: float = 0.0

    # --- Phase 1-4: domain indicators -----------------------------------
    protein_band_intensity: float = 0.0
    carbohydrate_band_intensity: float = 0.0
    lipid_band_intensity: float = 0.0
    water_band_intensity: float = 0.0
@dataclass
class FittedParameters:
    """Parameters fitted from real data for synthetic generation.

    Holds the values needed to configure a SyntheticNIRSGenerator so that
    it produces spectra similar to a real dataset, including the Phase 1-4
    enhanced inferences.

    Attributes:
        wavelength_start: Start wavelength (nm).
        wavelength_end: End wavelength (nm).
        wavelength_step: Wavelength step (nm).

        global_slope_mean: Mean global slope.
        global_slope_std: Slope standard deviation.

        noise_base: Base noise level.
        noise_signal_dep: Signal-dependent noise factor.

        path_length_std: Path length variation.
        baseline_amplitude: Baseline drift amplitude.
        scatter_alpha_std: Multiplicative scatter std.
        scatter_beta_std: Additive scatter std.
        tilt_std: Spectral tilt standard deviation.

        complexity: Suggested complexity level.
        source_name: Name of source dataset.
        source_properties: Full SpectralProperties of source.

        inferred_instrument: Inferred instrument archetype.
        instrument_inference: Full instrument inference result.
        measurement_mode: Inferred measurement mode.
        measurement_mode_confidence: Confidence of inference.
        inferred_domain: Inferred application domain.
        domain_inference: Full domain inference result.
        environmental_inference: Environmental effects inference.
        temperature_config: Suggested temperature config parameters.
        moisture_config: Suggested moisture config parameters.
        scattering_inference: Scattering effects inference.
        particle_size_config: Suggested particle size config parameters.
        emsc_config: Suggested EMSC config parameters.
        detected_components: List of detected/inferred component names.
        suggested_n_components: Suggested number of components.

    Note:
        ``source_properties`` and the ``*_inference`` objects are not part
        of ``to_dict()`` and therefore do not survive ``save``/``load``.
    """

    # Wavelength grid
    wavelength_start: float = 1000.0
    wavelength_end: float = 2500.0
    wavelength_step: float = 2.0

    # Slope parameters
    global_slope_mean: float = 0.0
    global_slope_std: float = 0.02

    # Noise parameters
    noise_base: float = 0.001
    noise_signal_dep: float = 0.005

    # Variation parameters
    path_length_std: float = 0.05
    baseline_amplitude: float = 0.02
    scatter_alpha_std: float = 0.05
    scatter_beta_std: float = 0.01
    tilt_std: float = 0.01

    # Metadata
    complexity: str = "realistic"
    source_name: str = ""
    source_properties: Optional[SpectralProperties] = field(default=None, repr=False)

    # Phase 1-4 Enhanced Parameters
    # Instrument inference (Phase 2)
    inferred_instrument: str = "unknown"
    instrument_inference: Optional[InstrumentInference] = field(default=None, repr=False)

    # Measurement mode (Phase 2)
    measurement_mode: str = "transmittance"
    measurement_mode_confidence: float = 0.0

    # Domain inference (Phase 1)
    inferred_domain: str = "unknown"
    domain_inference: Optional[DomainInference] = field(default=None, repr=False)

    # Environmental effects (Phase 3)
    environmental_inference: Optional[EnvironmentalInference] = field(default=None, repr=False)
    temperature_config: Dict[str, Any] = field(default_factory=dict)
    moisture_config: Dict[str, Any] = field(default_factory=dict)

    # Scattering effects (Phase 3)
    scattering_inference: Optional[ScatteringInference] = field(default=None, repr=False)
    particle_size_config: Dict[str, Any] = field(default_factory=dict)
    emsc_config: Dict[str, Any] = field(default_factory=dict)

    # Components (Phase 1)
    detected_components: List[str] = field(default_factory=list)
    suggested_n_components: int = 5

    # Fields that are (de)serialized, in output order. Left unannotated on
    # purpose so that @dataclass does not treat it as an instance field.
    # Keeping one list guarantees to_dict() and from_dict() stay in sync.
    _SERIALIZED_FIELDS = (
        "wavelength_start",
        "wavelength_end",
        "wavelength_step",
        "global_slope_mean",
        "global_slope_std",
        "noise_base",
        "noise_signal_dep",
        "path_length_std",
        "baseline_amplitude",
        "scatter_alpha_std",
        "scatter_beta_std",
        "tilt_std",
        "complexity",
        "source_name",
        # Phase 1-4 enhanced
        "inferred_instrument",
        "measurement_mode",
        "measurement_mode_confidence",
        "inferred_domain",
        "detected_components",
        "suggested_n_components",
        "temperature_config",
        "moisture_config",
        "particle_size_config",
        "emsc_config",
    )

    @staticmethod
    def _detach(value: Any) -> Any:
        """Return a shallow copy of a list/dict; pass anything else through.

        Used at the import/export boundary so that mutating an exported
        container does not silently mutate the parameters (and vice versa).
        Nested containers are still shared because the copy is shallow.
        """
        if isinstance(value, list):
            return list(value)
        if isinstance(value, dict):
            return dict(value)
        return value

    def to_generator_kwargs(self) -> Dict[str, Any]:
        """Convert fitted parameters to kwargs for SyntheticNIRSGenerator.

        Only the wavelength grid and the complexity level are included.

        Returns:
            Dictionary of keyword arguments.

        Example:
            >>> params = fitter.fit(X_real)
            >>> generator = SyntheticNIRSGenerator(**params.to_generator_kwargs())
        """
        return {
            "wavelength_start": self.wavelength_start,
            "wavelength_end": self.wavelength_end,
            "wavelength_step": self.wavelength_step,
            "complexity": self.complexity,
        }

    def to_full_config(self) -> Dict[str, Any]:
        """Convert all fitted parameters to a comprehensive configuration.

        Includes the Phase 1-4 parameters under generator-oriented key
        names (for example ``instrument`` instead of
        ``inferred_instrument``). List and dict values are shallow copies.

        Returns:
            Dictionary with all configuration parameters.

        Example:
            >>> params = fitter.fit(X_real)
            >>> config = params.to_full_config()
            >>> # Use with builder pattern or advanced configuration
        """
        detach = self._detach
        return {
            # Basic parameters
            "wavelength_start": self.wavelength_start,
            "wavelength_end": self.wavelength_end,
            "wavelength_step": self.wavelength_step,
            "complexity": self.complexity,
            # Noise and scatter
            "noise_base": self.noise_base,
            "noise_signal_dep": self.noise_signal_dep,
            "scatter_alpha_std": self.scatter_alpha_std,
            "scatter_beta_std": self.scatter_beta_std,
            "path_length_std": self.path_length_std,
            "baseline_amplitude": self.baseline_amplitude,
            "tilt_std": self.tilt_std,
            "global_slope_mean": self.global_slope_mean,
            "global_slope_std": self.global_slope_std,
            # Phase 1-4 enhanced
            "instrument": self.inferred_instrument,
            "measurement_mode": self.measurement_mode,
            "domain": self.inferred_domain,
            "components": detach(self.detected_components),
            "n_components": self.suggested_n_components,
            "temperature_config": detach(self.temperature_config),
            "moisture_config": detach(self.moisture_config),
            "particle_size_config": detach(self.particle_size_config),
            "emsc_config": detach(self.emsc_config),
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert the serializable parameters to a dictionary.

        List and dict values are shallow copies, so editing the result
        does not modify this instance.

        Returns:
            Dictionary with all parameter values.
        """
        return {
            key: self._detach(getattr(self, key))
            for key in self._SERIALIZED_FIELDS
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "FittedParameters":
        """Create FittedParameters from a dictionary.

        Keys missing from ``data`` fall back to the dataclass defaults;
        keys that are not serialized fields are ignored. List and dict
        values are shallow-copied so the new instance does not share
        containers with ``data``.

        Args:
            data: Dictionary with parameter values.

        Returns:
            FittedParameters instance.
        """
        known = {
            key: cls._detach(data[key])
            for key in cls._SERIALIZED_FIELDS
            if key in data
        }
        return cls(**known)

    def save(self, path: str) -> None:
        """Save parameters to JSON file.

        Missing parent directories are created. The file is written as
        UTF-8 regardless of the platform default encoding.

        Args:
            path: Output file path.
        """
        import json
        from pathlib import Path

        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, path: str) -> "FittedParameters":
        """Load parameters from JSON file.

        Args:
            path: Input file path.

        Returns:
            FittedParameters instance.
        """
        import json

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls.from_dict(data)

    def summary(self) -> str:
        """Generate a human-readable summary of fitted parameters.

        At most the first five detected components are listed.

        Returns:
            Multi-line summary string.
        """
        lines = [
            "=" * 60,
            f"Fitted Parameters Summary: {self.source_name}",
            "=" * 60,
            "",
            "Wavelength Grid:",
            f" Range: {self.wavelength_start:.0f} - {self.wavelength_end:.0f} nm",
            f" Step: {self.wavelength_step:.1f} nm",
            "",
            "Noise & Scatter:",
            f" Base noise: {self.noise_base:.5f}",
            f" Signal-dependent noise: {self.noise_signal_dep:.5f}",
            f" Multiplicative scatter (α): {self.scatter_alpha_std:.4f}",
            f" Additive scatter (β): {self.scatter_beta_std:.4f}",
            "",
            "Baseline & Slope:",
            f" Global slope: {self.global_slope_mean:.4f} ± {self.global_slope_std:.4f}",
            f" Baseline amplitude: {self.baseline_amplitude:.4f}",
            f" Path length std: {self.path_length_std:.4f}",
            "",
            f"Complexity: {self.complexity}",
            "",
            "Phase 1-4 Inferences:",
            f" Instrument: {self.inferred_instrument}",
            f" Measurement mode: {self.measurement_mode} "
            f"(confidence: {self.measurement_mode_confidence:.2f})",
            f" Domain: {self.inferred_domain}",
            f" Detected components: {', '.join(self.detected_components[:5]) or 'None'}",
            f" Suggested n_components: {self.suggested_n_components}",
            "=" * 60,
        ]
        return "\n".join(lines)
def compute_spectral_properties(
    X: np.ndarray,
    wavelengths: Optional[np.ndarray] = None,
    name: str = "dataset",
    n_pca_components: int = 20,
) -> SpectralProperties:
    """Compute comprehensive spectral properties of a dataset.

    Analyzes a matrix of spectra to extract statistical and spectral
    properties useful for fitting and comparison. Includes Phase 1-4
    enhanced properties for instrument, mode, domain, and effect inference.

    Args:
        X: Spectra matrix (n_samples, n_wavelengths). Any array-like is
            accepted and converted with ``np.asarray``.
        wavelengths: Optional wavelength grid (array-like). Defaults to
            ``np.arange(n_wavelengths)``.
        name: Dataset identifier.
        n_pca_components: Maximum PCA components to compute.

    Returns:
        SpectralProperties with computed metrics.

    Raises:
        ValueError: If X is not two-dimensional.

    Example:
        >>> props = compute_spectral_properties(X_real, wavelengths)
        >>> print(f"Mean slope: {props.mean_slope:.4f}")
        >>> print(f"Inferred resolution: {props.effective_resolution:.1f} nm")
    """
    # Coerce inputs up front so lists/tuples work and so that `.copy()`,
    # `.min()` etc. below are always called on real arrays.
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"X must be 2D, got shape {X.shape}")
    n_samples, n_wavelengths = X.shape

    if wavelengths is None:
        wavelengths = np.arange(n_wavelengths)
    else:
        wavelengths = np.asarray(wavelengths)

    props = SpectralProperties(
        name=name,
        n_samples=n_samples,
        n_wavelengths=n_wavelengths,
        wavelengths=wavelengths.copy(),
        wavelength_range=(float(wavelengths.min()), float(wavelengths.max())),
    )

    # Basic statistics
    props.mean_spectrum = X.mean(axis=0)
    props.std_spectrum = X.std(axis=0)
    props.global_mean = float(X.mean())
    props.global_std = float(X.std())
    props.global_range = (float(X.min()), float(X.max()))

    # Slope analysis: fit a line to each spectrum on a [0, 1] axis, then
    # rescale the slope to "per 1000 wavelength units".
    wl_range = np.ptp(wavelengths)
    if wl_range > 0:
        x_norm = (wavelengths - wavelengths.min()) / wl_range
        slopes = [
            np.polyfit(x_norm, X[i], 1)[0] * 1000.0 / wl_range
            for i in range(n_samples)
        ]
        props.slopes = np.array(slopes)
        props.mean_slope = float(np.mean(slopes))
        props.slope_std = float(np.std(slopes))

    # Odd Savitzky-Golay window, capped at 21 points; shared by the
    # curvature and peak analyses below.
    window_size = min(21, n_wavelengths // 10 * 2 + 1)

    # Curvature analysis (at most the first 100 samples, for speed)
    if window_size >= 5:
        curvatures = []
        for i in range(min(n_samples, 100)):
            try:
                smoothed = savgol_filter(X[i], window_size, 2)
                d2 = np.gradient(np.gradient(smoothed))
                curvatures.append(np.mean(np.abs(d2)))
            except Exception:
                # Best effort: a spectrum that cannot be smoothed is left
                # out of the curvature statistics.
                continue
        if curvatures:
            props.mean_curvature = float(np.mean(curvatures))
            props.curvature_std = float(np.std(curvatures))

    # Distribution statistics over all absorbance values
    flat_data = X.flatten()
    props.skewness = float(stats.skew(flat_data))
    props.kurtosis = float(stats.kurtosis(flat_data))

    # Noise estimation (from first difference along the wavelength axis)
    first_diff = np.diff(X, axis=1)
    props.noise_estimate = float(first_diff.std() / np.sqrt(2))

    # SNR estimation
    signal_power = props.std_spectrum.mean()
    if props.noise_estimate > 0:
        props.snr_estimate = float(signal_power / props.noise_estimate)
    else:
        props.snr_estimate = float("inf")

    # PCA analysis (skipped silently when scikit-learn is not installed)
    try:
        from sklearn.decomposition import PCA

        n_comp = min(n_pca_components, n_samples, n_wavelengths)
        pca = PCA(n_components=n_comp)
        pca.fit(X)
        props.pca_explained_variance = pca.explained_variance_ratio_

        # Components for 95% variance. When the fitted components never
        # reach 95%, searchsorted returns n_comp and the "+ 1" would report
        # more components than were computed, so clamp to n_comp.
        cumsum = np.cumsum(pca.explained_variance_ratio_)
        n_95 = int(np.searchsorted(cumsum, 0.95) + 1)
        props.pca_n_components_95 = min(n_95, n_comp)
    except ImportError:
        pass

    # Peak analysis on the smoothed mean spectrum
    try:
        if window_size >= 5 and props.std_spectrum is not None:
            smoothed_mean = savgol_filter(props.mean_spectrum, window_size, 2)
            prominence = props.std_spectrum.mean() * 0.5
            peaks, _ = find_peaks(smoothed_mean, prominence=prominence)
            props.peak_positions = wavelengths[peaks] if len(peaks) > 0 else np.array([])
            props.n_peaks_mean = float(len(peaks))

            # Convert to wavenumbers (1e7 / wavelength)
            if len(peaks) > 0:
                props.peak_wavenumbers = 1e7 / wavelengths[peaks]
            else:
                props.peak_wavenumbers = np.array([])
    except Exception:
        props.peak_positions = np.array([])
        props.peak_wavenumbers = np.array([])
        props.n_peaks_mean = 0.0

    # =========================================================================
    # Phase 1-4 Enhanced Properties
    # =========================================================================

    # Effective resolution estimation (from peak widths)
    props.effective_resolution = _estimate_spectral_resolution(
        props.mean_spectrum, wavelengths
    )

    # Noise correlation length (detector indicator)
    props.noise_correlation_length = _compute_noise_correlation_length(X, wavelengths)

    # Baseline offset (transmittance indicator): minimum of the mean spectrum
    props.baseline_offset = float(props.mean_spectrum.min())

    # Baseline convexity (ATR indicator)
    props.baseline_convexity = _compute_baseline_convexity(props.mean_spectrum, wavelengths)

    # Kubelka-Munk linearity score (reflectance indicator)
    props.kubelka_munk_linearity = _compute_km_linearity(X)

    # Sample-to-sample scatter indicators
    sample_means = X.mean(axis=1)
    props.sample_to_sample_offset_std = float(np.std(sample_means))

    # Sample-to-sample slope variation
    if props.slopes is not None:
        props.sample_to_sample_slope_std = float(np.std(props.slopes))

    # Scatter baseline analysis
    props.scatter_baseline_slope, props.scatter_baseline_curvature = (
        _analyze_scatter_baseline(props.mean_spectrum, wavelengths)
    )

    # Water band analysis (environmental indicators)
    props.water_band_variation, props.oh_band_positions = _analyze_water_bands(
        X, wavelengths
    )

    # Temperature sensitivity score
    props.temperature_sensitivity_score = _compute_temperature_sensitivity(X, wavelengths)

    # Domain indicators - mean intensity over fixed wavelength windows
    props.protein_band_intensity = _compute_band_intensity(
        props.mean_spectrum, wavelengths, [(1480, 1560), (2040, 2180)]
    )
    props.carbohydrate_band_intensity = _compute_band_intensity(
        props.mean_spectrum, wavelengths, [(2050, 2150), (2270, 2350)]
    )
    props.lipid_band_intensity = _compute_band_intensity(
        props.mean_spectrum, wavelengths, [(1720, 1780), (2300, 2380)]
    )
    props.water_band_intensity = _compute_band_intensity(
        props.mean_spectrum, wavelengths, [(1400, 1500), (1900, 2000)]
    )

    return props
def _estimate_spectral_resolution(
    mean_spectrum: np.ndarray,
    wavelengths: np.ndarray,
) -> float:
    """Estimate spectral resolution from peak widths.

    The spectrum is smoothed, peaks are detected, and the median peak width
    (in samples) is converted to wavelength units using the median grid
    spacing. The result is clipped to [0.5, 50.0].

    Args:
        mean_spectrum: Mean spectrum to analyse.
        wavelengths: Wavelength grid; may be ascending or descending.

    Returns:
        Estimated resolution, or the fallback 8.0 when the grid is too
        short, fewer than two peaks are found, or any step fails.
    """
    try:
        window_size = min(21, len(wavelengths) // 10 * 2 + 1)
        if window_size < 5:
            return 8.0

        smoothed = savgol_filter(mean_spectrum, window_size, 2)
        peaks, properties = find_peaks(
            smoothed,
            prominence=np.std(smoothed) * 0.3,
            width=3,
        )
        if len(peaks) < 2:
            return 8.0

        # Peak widths are reported by find_peaks in units of samples.
        widths = properties.get("widths", [])
        if len(widths) == 0:
            return 8.0

        # Use the magnitude of the grid spacing: on a descending grid the
        # signed step is negative, which would make the width negative and
        # always collapse to the 0.5 clip floor.
        wl_step = abs(np.median(np.diff(wavelengths)))
        width_nm = np.median(widths) * wl_step
        return float(np.clip(width_nm, 0.5, 50.0))
    except Exception:
        return 8.0


def _compute_noise_correlation_length(
    X: np.ndarray,
    wavelengths: np.ndarray,
) -> float:
    """Compute correlation length of noise (detector/instrument indicator).

    Noise is approximated by the first difference of each spectrum. For
    each of the first (at most) 100 spectra, the lag at which the
    normalized autocorrelation first drops below 1/e is recorded; the
    median of those lags is returned.

    Args:
        X: Spectra matrix (n_samples, n_wavelengths).
        wavelengths: Wavelength grid (not used by the computation).

    Returns:
        Median correlation length in samples, or 1.0 when no spectrum has
        usable noise or any step fails.
    """
    try:
        noise = np.diff(X, axis=1)

        corr_lengths = []
        for row in noise[:100]:
            if np.std(row) < 1e-10:
                continue  # flat difference signal: nothing to correlate
            centered = row - row.mean()
            acf = np.correlate(centered, centered, mode="full")
            acf = acf[len(acf) // 2:]  # keep non-negative lags only
            acf = acf / acf[0]

            below_threshold = np.where(acf < 1 / np.e)[0]
            if len(below_threshold) > 0:
                corr_lengths.append(below_threshold[0])
            else:
                corr_lengths.append(len(acf))

        if len(corr_lengths) == 0:
            return 1.0
        return float(np.median(corr_lengths))
    except Exception:
        return 1.0


def _compute_baseline_convexity(
    mean_spectrum: np.ndarray,
    wavelengths: np.ndarray,
) -> float:
    """Compute baseline convexity of the mean spectrum.

    Fits a quadratic on a centred wavelength axis scaled by the wavelength
    span and returns its leading coefficient.

    Args:
        mean_spectrum: Mean spectrum to analyse.
        wavelengths: Wavelength grid.

    Returns:
        Quadratic coefficient of the fit, or 0.0 if the fit fails.
    """
    try:
        span = wavelengths.max() - wavelengths.min()
        x_norm = (wavelengths - wavelengths.mean()) / span
        coeffs = np.polyfit(x_norm, mean_spectrum, 2)
        # polyfit returns the highest power first.
        return float(coeffs[0])
    except Exception:
        return 0.0


def _compute_km_linearity(X: np.ndarray) -> float:
    """Compute the Kubelka-Munk linearity score (reflectance indicator).

    The score is the absolute correlation between per-spectrum mean and
    per-spectrum standard deviation. Data outside the range [-0.5, 3.0] is
    treated as absorbance rather than reflectance and scores 0.0.

    When more than 100 spectra are given, 100 are drawn with a private,
    fixed-seed generator so the score is reproducible and the global
    ``np.random`` state is left untouched.

    Args:
        X: Spectra matrix (n_samples, n_wavelengths).

    Returns:
        Score in [0, 1]; 0.0 when not applicable or on any failure.
    """
    try:
        if X.min() < -0.5 or X.max() > 3.0:
            # Looks like absorbance, not reflectance
            return 0.0

        n_total = X.shape[0]
        if n_total > 100:
            rng = np.random.default_rng(0)
            X_sample = X[rng.choice(n_total, 100, replace=False)]
        else:
            # Every spectrum is used, so no sampling is needed.
            X_sample = X

        mean_intensity = X_sample.mean(axis=1)
        std_intensity = X_sample.std(axis=1)

        if np.std(mean_intensity) < 1e-10:
            return 0.0

        corr = np.corrcoef(mean_intensity, std_intensity)[0, 1]
        return float(abs(corr)) if not np.isnan(corr) else 0.0
    except Exception:
        return 0.0


def _analyze_scatter_baseline(
    mean_spectrum: np.ndarray,
    wavelengths: np.ndarray,
) -> Tuple[float, float]:
    """Analyze baseline for scatter effects.

    Fits a cubic on a centred wavelength axis scaled by the wavelength
    span.

    Args:
        mean_spectrum: Mean spectrum to analyse.
        wavelengths: Wavelength grid.

    Returns:
        Tuple ``(slope, curvature)``: the linear coefficient and the
        absolute quadratic coefficient of the fit, or ``(0.0, 0.0)`` if the
        fit fails.
    """
    try:
        span = wavelengths.max() - wavelengths.min()
        x_norm = (wavelengths - wavelengths.mean()) / span
        # Coefficients are ordered cubic, quadratic, linear, constant.
        coeffs = np.polyfit(x_norm, mean_spectrum, 3)

        baseline_slope = float(coeffs[2])
        baseline_curvature = float(abs(coeffs[1]))
        return baseline_slope, baseline_curvature
    except Exception:
        return 0.0, 0.0


def _analyze_water_bands(
    X: np.ndarray,
    wavelengths: np.ndarray,
) -> Tuple[float, np.ndarray]:
    """Analyze water band regions for environmental effects.

    Looks at the windows 1400-1500 and 1900-2000 on the wavelength axis.
    Windows not covered by the grid are skipped.

    Args:
        X: Spectra matrix (n_samples, n_wavelengths).
        wavelengths: Wavelength grid.

    Returns:
        Tuple of (sum of the standard deviations of all values in each
        covered window, array of wavelengths at which the mean spectrum
        peaks within each covered window). ``(0.0, empty array)`` on
        failure.
    """
    try:
        water_regions = [(1400, 1500), (1900, 2000)]

        water_variation = 0.0
        oh_positions = []
        for wl_min, wl_max in water_regions:
            mask = (wavelengths >= wl_min) & (wavelengths <= wl_max)
            if not np.any(mask):
                continue

            region_data = X[:, mask]
            water_variation += float(region_data.std())

            # Peak position of the mean spectrum inside the window
            region_mean = region_data.mean(axis=0)
            region_wl = wavelengths[mask]
            oh_positions.append(region_wl[np.argmax(region_mean)])

        return water_variation, np.array(oh_positions)
    except Exception:
        return 0.0, np.array([])


def _compute_temperature_sensitivity(
    X: np.ndarray,
    wavelengths: np.ndarray,
) -> float:
    """Estimate temperature sensitivity from peak-position scatter.

    Within the 1400-1500 window, the index of the maximum is taken for
    each of the first (at most) 100 spectra. The standard deviation of
    those indices, divided by 5 and capped at 1.0, is the score.

    Args:
        X: Spectra matrix (n_samples, n_wavelengths).
        wavelengths: Wavelength grid.

    Returns:
        Score in [0, 1]; 0.0 when the window is not covered, fewer than 10
        usable spectra exist, or any step fails.
    """
    try:
        oh_region = (wavelengths >= 1400) & (wavelengths <= 1500)
        if not np.any(oh_region):
            return 0.0

        region_data = X[:, oh_region]

        peak_positions = []
        for row in region_data[:100]:
            if np.std(row) < 1e-10:
                continue  # flat row: argmax would be meaningless
            peak_positions.append(np.argmax(row))

        if len(peak_positions) < 10:
            return 0.0

        # Higher variation in peak position gives a higher score.
        position_std = np.std(peak_positions)
        return float(min(1.0, position_std / 5.0))
    except Exception:
        return 0.0


def _compute_band_intensity(
    mean_spectrum: np.ndarray,
    wavelengths: np.ndarray,
    regions: List[Tuple[float, float]],
) -> float:
    """Compute mean intensity in specified wavelength regions.

    Args:
        mean_spectrum: Mean spectrum to analyse.
        wavelengths: Wavelength grid.
        regions: Inclusive ``(min, max)`` wavelength windows.

    Returns:
        Average of the per-window mean intensities over the windows that
        the grid covers; 0.0 when none is covered or on failure.
    """
    try:
        total_intensity = 0.0
        n_regions = 0
        for wl_min, wl_max in regions:
            mask = (wavelengths >= wl_min) & (wavelengths <= wl_max)
            if not np.any(mask):
                continue
            total_intensity += float(np.mean(mean_spectrum[mask]))
            n_regions += 1
        return total_intensity / max(n_regions, 1)
    except Exception:
        return 0.0
class RealDataFitter:
    """
    Fit generator parameters to match real dataset properties.

    This class analyzes real NIRS data and estimates parameters for
    the SyntheticNIRSGenerator to produce similar spectra. Includes
    Phase 1-4 enhanced inference for instruments, domains, and effects.

    Attributes:
        source_properties: SpectralProperties of the analyzed data.
        fitted_params: FittedParameters after fitting.

    Example:
        >>> fitter = RealDataFitter()
        >>> params = fitter.fit(X_real, wavelengths=wavelengths)
        >>>
        >>> # Access inferred characteristics
        >>> print(f"Instrument: {params.inferred_instrument}")
        >>> print(f"Domain: {params.inferred_domain}")
        >>>
        >>> # Create matched generator
        >>> generator = fitter.create_matched_generator()
        >>> X_synth, _, _ = generator.generate(1000)
    """

    # (keyword found in the domain name, reported category), checked in order.
    _DOMAIN_CATEGORIES = (
        ("agriculture", "agriculture"),
        ("food", "food"),
        ("pharma", "pharmaceutical"),
        ("beverage", "beverage"),
        ("environmental", "environmental"),
        ("biomedical", "biomedical"),
    )

    def __init__(self) -> None:
        """Initialize the fitter with no analyzed data."""
        self.source_properties: Optional[SpectralProperties] = None
        self.fitted_params: Optional[FittedParameters] = None
        self._X_array: Optional[np.ndarray] = None
        self._wavelengths: Optional[np.ndarray] = None

    def fit(
        self,
        X: Union[np.ndarray, "SpectroDataset"],
        *,
        wavelengths: Optional[np.ndarray] = None,
        name: str = "source",
        infer_instrument: bool = True,
        infer_domain: bool = True,
        infer_measurement_mode: bool = True,
        infer_environmental: bool = True,
        infer_scattering: bool = True,
    ) -> FittedParameters:
        """
        Fit generator parameters to real data.

        Analyzes the input data and estimates parameters for generating
        synthetic spectra with similar properties. The result is also stored
        on ``self.fitted_params`` and the computed properties on
        ``self.source_properties``.

        Args:
            X: Real spectra matrix (n_samples, n_wavelengths) or SpectroDataset
                (detected by the presence of a callable ``x`` attribute).
            wavelengths: Wavelength grid. When omitted, the dataset's
                ``wavelengths`` attribute is used if available, otherwise
                the column indices ``0..n_wavelengths-1``.
            name: Dataset name for reference. A non-empty ``X.name`` takes
                precedence when X is a SpectroDataset.
            infer_instrument: Whether to infer instrument archetype.
            infer_domain: Whether to infer application domain.
            infer_measurement_mode: Whether to infer measurement mode.
            infer_environmental: Whether to infer environmental effects.
            infer_scattering: Whether to infer scattering parameters.

        Returns:
            FittedParameters object with estimated parameters.

        Raises:
            ValueError: If X is not 2D or has fewer than 5 samples.

        Example:
            >>> fitter = RealDataFitter()
            >>> params = fitter.fit(X_real, wavelengths=wl, name="wheat")
            >>> print(params.summary())
        """
        # Handle SpectroDataset input (duck-typed on a callable ``x``)
        if hasattr(X, "x") and callable(X.x):
            X_array = X.x({}, layout="2d")
            if wavelengths is None:
                try:
                    wavelengths = X.wavelengths
                except (AttributeError, TypeError):
                    wavelengths = np.arange(X_array.shape[1])
            if hasattr(X, "name"):
                name = X.name or name
        else:
            X_array = np.asarray(X)

        # Validate input
        if X_array.ndim != 2:
            raise ValueError(f"X must be 2D, got shape {X_array.shape}")
        if X_array.shape[0] < 5:
            raise ValueError(f"Need at least 5 samples, got {X_array.shape[0]}")

        # Create default wavelengths if not provided
        if wavelengths is None:
            wavelengths = np.arange(X_array.shape[1])
        wavelengths = np.asarray(wavelengths)

        # Store for later use
        self._X_array = X_array
        self._wavelengths = wavelengths

        # Compute spectral properties (includes Phase 1-4 enhanced properties)
        self.source_properties = compute_spectral_properties(
            X_array, wavelengths, name
        )
        props = self.source_properties

        # Estimate basic parameters
        params = FittedParameters(
            source_name=name,
            source_properties=props,
        )

        # Wavelength grid
        params.wavelength_start = float(wavelengths.min())
        params.wavelength_end = float(wavelengths.max())
        if len(wavelengths) > 1:
            # start/end are min/max, so the step must be a positive spacing
            # even when the input axis is stored in descending order.
            params.wavelength_step = float(abs(np.median(np.diff(wavelengths))))

        # Slope parameters
        params.global_slope_mean = props.mean_slope
        params.global_slope_std = props.slope_std

        # Noise parameters: the noise estimate is split evenly between a
        # constant part and a signal-dependent part.
        params.noise_base = props.noise_estimate * 0.5
        params.noise_signal_dep = (
            props.noise_estimate * 0.5 / max(props.global_std, 0.01)
        )

        # Scatter parameters
        params.scatter_alpha_std = min(
            0.15, props.global_std / max(props.global_mean, 0.1) * 0.3
        )
        params.scatter_beta_std = props.global_std * 0.1

        # Path length variation, from the spread of per-sample mean intensity
        sample_means = X_array.mean(axis=1)
        intensity_variation = np.std(sample_means) / max(np.mean(sample_means), 0.1)
        params.path_length_std = min(0.2, intensity_variation * 0.5)

        # Baseline amplitude
        params.baseline_amplitude = props.global_std * 0.2

        # Tilt standard deviation
        params.tilt_std = abs(props.mean_slope) * 0.1

        # Determine complexity
        if props.snr_estimate > 50 and props.pca_n_components_95 <= 5:
            params.complexity = "simple"
        elif props.snr_estimate < 20 or props.pca_n_components_95 > 15:
            params.complexity = "complex"
        else:
            params.complexity = "realistic"

        # Suggested number of components, kept within [3, 10]
        params.suggested_n_components = max(
            3, min(10, props.pca_n_components_95 + 2)
        )

        # =====================================================================
        # Phase 1-4 Enhanced Inference
        # =====================================================================

        # Instrument inference
        if infer_instrument:
            params.instrument_inference = self._infer_instrument(props)
            params.inferred_instrument = params.instrument_inference.archetype_name

        # Measurement mode inference
        if infer_measurement_mode:
            mode, confidence = self._infer_measurement_mode(
                X_array, wavelengths, props
            )
            params.measurement_mode = mode
            params.measurement_mode_confidence = confidence

        # Domain inference
        if infer_domain:
            params.domain_inference = self._infer_domain(props)
            params.inferred_domain = params.domain_inference.domain_name
            params.detected_components = params.domain_inference.detected_components

        # Environmental effects inference
        if infer_environmental:
            params.environmental_inference = self._infer_environmental(
                X_array, wavelengths, props
            )
            params.temperature_config = self._build_temperature_config(
                params.environmental_inference
            )
            params.moisture_config = self._build_moisture_config(
                params.environmental_inference
            )

        # Scattering inference
        if infer_scattering:
            params.scattering_inference = self._infer_scattering(
                X_array, wavelengths, props
            )
            params.particle_size_config = self._build_particle_size_config(
                params.scattering_inference
            )
            params.emsc_config = self._build_emsc_config(
                params.scattering_inference
            )

        self.fitted_params = params
        return params

    @staticmethod
    def _top_scores(scores: Dict[str, float], limit: int = 5) -> Dict[str, float]:
        """Return the ``limit`` highest-scoring entries, best first."""
        ranked = sorted(scores.items(), key=lambda item: -item[1])
        return dict(ranked[:limit])

    def _infer_instrument(self, props: SpectralProperties) -> InstrumentInference:
        """
        Infer instrument archetype from spectral properties.

        Candidate archetypes are scored from the wavelength range, then the
        scores are adjusted by SNR and effective resolution. The best-scoring
        archetype is reported, or ``"unknown"`` (confidence 0) when no rule
        matched.
        """
        wl_min, wl_max = props.wavelength_range
        resolution = props.effective_resolution
        snr = props.snr_estimate

        scores: Dict[str, float] = {}

        # Score based on wavelength range
        if wl_max <= 1100:
            # Short-wave NIR only - likely Si detector
            scores["scio"] = 0.6
            scores["linksquare"] = 0.5
        elif wl_min >= 1300 and wl_max <= 2600:
            # Extended range - likely MEMS FT-NIR
            scores["neospectra_micro"] = 0.6
            scores["siware_neoscanner"] = 0.5
        elif wl_max > 2400:
            # Full NIR range
            if snr > 30000:
                scores["foss_xds"] = 0.7
                scores["bruker_mpa"] = 0.6
                scores["metrohm_ds2500"] = 0.5
            elif snr > 10000:
                scores["unity_spectrastar"] = 0.5
                scores["buchi_nirmaster"] = 0.4
            else:
                scores["asd_fieldspec"] = 0.5
        elif 900 <= wl_min <= 1000 and 1600 <= wl_max <= 1800:
            # Standard InGaAs range
            if resolution < 5:
                scores["perten_da7200"] = 0.6
            else:
                scores["viavi_micronir"] = 0.6
                scores["tellspec"] = 0.5

        # Adjust scores based on SNR
        if snr > 50000:
            for archetype in ["bruker_mpa", "thermo_antaris", "foss_xds"]:
                scores[archetype] = scores.get(archetype, 0) + 0.2
        elif snr < 5000:
            for archetype in ["scio", "neospectra_micro", "innospectra"]:
                scores[archetype] = scores.get(archetype, 0) + 0.2

        # Adjust based on resolution
        if resolution < 2:
            for archetype in ["bruker_mpa", "thermo_antaris", "foss_xds"]:
                scores[archetype] = scores.get(archetype, 0) + 0.1
        elif resolution > 10:
            for archetype in ["scio", "viavi_micronir", "neospectra_micro"]:
                scores[archetype] = scores.get(archetype, 0) + 0.1

        # Find best match
        if scores:
            best_name = max(scores, key=scores.get)
            best_score = scores[best_name]
        else:
            best_name = "unknown"
            best_score = 0.0

        # Determine detector type from the upper wavelength limit
        if wl_max <= 1100:
            detector = "si"
        elif wl_max <= 1700:
            detector = "ingaas"
        elif wl_max <= 2600:
            detector = "ingaas_ext"
        else:
            detector = "pbs"

        return InstrumentInference(
            archetype_name=best_name,
            detector_type=detector,
            wavelength_range=(wl_min, wl_max),
            estimated_resolution=resolution,
            # Adjustments can push a score above 1, so cap the confidence.
            confidence=min(1.0, best_score),
            alternative_archetypes=self._top_scores(scores),
        )

    def _infer_measurement_mode(
        self,
        X: np.ndarray,
        wavelengths: np.ndarray,
        props: SpectralProperties,
    ) -> Tuple[str, float]:
        """
        Infer measurement mode from spectral characteristics.

        Only ``props`` is consulted; ``X`` and ``wavelengths`` are accepted
        but not used by the current rules.

        Returns:
            ``(mode, confidence)`` where confidence is the winning mode's
            share of the total score.
        """
        scores = {
            "transmittance": 0.0,
            "reflectance": 0.0,
            "transflectance": 0.0,
            "atr": 0.0,
        }

        # Check value range
        min_val, max_val = props.global_range

        # Transmittance/absorbance typically has values 0-3+ AU
        if 0 <= min_val < 0.5 and max_val < 4.0:
            scores["transmittance"] += 0.3

        # Reflectance data often has lower mean absorbance
        if props.global_mean < 1.0:
            scores["reflectance"] += 0.2

        # ATR shows characteristic wavelength-dependent baseline
        if props.baseline_convexity > 0.01:
            scores["atr"] += 0.4

        # Kubelka-Munk linearity suggests reflectance
        if props.kubelka_munk_linearity > 0.5:
            scores["reflectance"] += 0.3

        # Sample-to-sample scatter suggests powder/reflectance
        if props.sample_to_sample_offset_std > 0.1:
            scores["reflectance"] += 0.2

        # High baseline offset suggests transflectance (double-pass)
        if props.baseline_offset > 0.5:
            scores["transflectance"] += 0.2

        # Scatter baseline curvature suggests reflectance/powder
        if props.scatter_baseline_curvature > 0.01:
            scores["reflectance"] += 0.2

        # Find best; the floor on the denominator avoids 0/0 when no rule fired
        best_mode = max(scores, key=scores.get)
        confidence = scores[best_mode] / max(sum(scores.values()), 0.01)

        return best_mode, float(confidence)

    def _infer_domain(self, props: SpectralProperties) -> DomainInference:
        """
        Infer application domain from spectral features.

        The water, protein, carbohydrate and lipid band intensities are
        normalized to fractions and matched against threshold rules. Falls
        back to ``"unknown"`` when no rule matches.
        """
        scores: Dict[str, float] = {}
        detected_components: List[str] = []

        # Score based on band intensities
        water_intensity = props.water_band_intensity
        protein_intensity = props.protein_band_intensity
        carb_intensity = props.carbohydrate_band_intensity
        lipid_intensity = props.lipid_band_intensity

        # Normalize intensities (epsilon guards against an all-zero total)
        total = (
            water_intensity + protein_intensity + carb_intensity + lipid_intensity + 1e-10
        )
        water_frac = water_intensity / total
        protein_frac = protein_intensity / total
        carb_frac = carb_intensity / total
        lipid_frac = lipid_intensity / total

        # Components may be recorded more than once below; duplicates are
        # dropped (keeping first occurrence) before returning.

        # Agriculture domains
        if carb_frac > 0.3 and protein_frac > 0.15:
            scores["agriculture_grain"] = 0.6
            detected_components.extend(["starch", "protein"])

        if carb_frac > 0.25 and lipid_frac > 0.2:
            scores["agriculture_oilseeds"] = 0.5
            detected_components.append("lipid")

        # Food domains
        if water_frac > 0.4 and protein_frac > 0.1:
            scores["food_dairy"] = 0.5
            scores["food_meat"] = 0.4
            detected_components.append("water")

        if lipid_frac > 0.25:
            scores["food_chocolate"] = 0.4
            detected_components.append("lipid")

        # Pharmaceutical domains
        if carb_frac > 0.4 and water_frac < 0.15:
            scores["pharma_tablets"] = 0.5
            detected_components.append("starch")
            detected_components.append("cellulose")

        # Environmental domains
        if protein_frac < 0.1 and carb_frac > 0.2:
            scores["environmental_soil"] = 0.3

        # Beverage domains
        if water_frac > 0.5:
            scores["beverage_juice"] = 0.4
            scores["beverage_wine"] = 0.3

        # Biomedical
        if water_frac > 0.3 and lipid_frac > 0.15 and protein_frac > 0.15:
            scores["biomedical_tissue"] = 0.4

        # Default fallback
        if not scores:
            scores["unknown"] = 0.5

        # Find best
        best_domain = max(scores, key=scores.get)
        confidence = scores[best_domain]

        # Determine category from the keyword embedded in the domain name
        category = next(
            (
                label
                for keyword, label in self._DOMAIN_CATEGORIES
                if keyword in best_domain
            ),
            "unknown",
        )

        # Remove duplicates from detected components, preserving order
        detected_components = list(dict.fromkeys(detected_components))

        return DomainInference(
            domain_name=best_domain,
            category=category,
            confidence=confidence,
            detected_components=detected_components,
            alternative_domains=self._top_scores(scores),
        )

    def _infer_environmental(
        self,
        X: np.ndarray,
        wavelengths: np.ndarray,
        props: SpectralProperties,
    ) -> EnvironmentalInference:
        """
        Infer environmental effects from spectral patterns.

        Only ``props`` is consulted; ``X`` and ``wavelengths`` are accepted
        but not used by the current rules.
        """
        # Temperature effects
        has_temp = props.temperature_sensitivity_score > 0.3
        temp_variation = props.temperature_sensitivity_score * 10.0  # rough °C estimate

        # Moisture effects
        has_moisture = props.water_band_variation > 0.05
        moisture_variation = props.water_band_variation

        # Water band shift: first detected O-H band position relative to the
        # expected free water position (1410 nm)
        expected_free = 1410
        water_shift = 0.0
        if props.oh_band_positions is not None and len(props.oh_band_positions) > 0:
            water_shift = props.oh_band_positions[0] - expected_free

        return EnvironmentalInference(
            estimated_temperature_variation=temp_variation,
            has_temperature_effects=has_temp,
            estimated_moisture_variation=moisture_variation,
            has_moisture_effects=has_moisture,
            water_band_shift=water_shift,
        )

    def _infer_scattering(
        self,
        X: np.ndarray,
        wavelengths: np.ndarray,
        props: SpectralProperties,
    ) -> ScatteringInference:
        """
        Infer scattering effects from spectral patterns.

        Only ``props`` is consulted; ``X`` and ``wavelengths`` are accepted
        but not used by the current rules.
        """
        curvature = props.scatter_baseline_curvature

        # MSC/SNV indicators
        mult_scatter = props.sample_to_sample_slope_std
        add_scatter = props.sample_to_sample_offset_std

        # Check for scatter effects
        has_scatter = curvature > 0.005 or add_scatter > 0.05 or mult_scatter > 0.01

        # Estimate particle size from scattering intensity
        # Higher scatter curvature suggests smaller particles
        if curvature > 0.02:
            particle_size = 20.0  # Fine powder
        elif curvature > 0.01:
            particle_size = 50.0  # Medium
        else:
            particle_size = 100.0  # Coarse

        # SNV correctable if high offset or slope variation
        snv_correctable = add_scatter > 0.05 or mult_scatter > 0.05

        # MSC correctable if systematic variation with mean
        msc_correctable = mult_scatter > 0.03

        return ScatteringInference(
            has_scatter_effects=has_scatter,
            estimated_particle_size_um=particle_size,
            multiplicative_scatter_std=mult_scatter,
            additive_scatter_std=add_scatter,
            baseline_curvature=curvature,
            snv_correctable=snv_correctable,
            msc_correctable=msc_correctable,
        )

    def _build_temperature_config(
        self, env: Optional[EnvironmentalInference]
    ) -> Dict[str, Any]:
        """Build temperature configuration; empty when no effect was inferred."""
        if env is None or not env.has_temperature_effects:
            return {}

        return {
            "temperature_variation": env.estimated_temperature_variation,
            "enable_shift": True,
            "enable_intensity": True,
            "enable_broadening": True,
        }

    def _build_moisture_config(
        self, env: Optional[EnvironmentalInference]
    ) -> Dict[str, Any]:
        """Build moisture configuration; empty when no effect was inferred."""
        if env is None or not env.has_moisture_effects:
            return {}

        # Estimate free water fraction from band shift
        # Shift towards higher wavelength = more bound water
        free_fraction = max(0.1, min(0.9, 0.5 - env.water_band_shift / 50.0))

        return {
            "water_activity": 0.5,  # Default
            "moisture_content": env.estimated_moisture_variation / 10.0,
            "free_water_fraction": free_fraction,
        }

    def _build_particle_size_config(
        self, scatter: Optional[ScatteringInference]
    ) -> Dict[str, Any]:
        """Build particle size configuration; empty when no scatter was inferred."""
        if scatter is None or not scatter.has_scatter_effects:
            return {}

        return {
            "mean_size_um": scatter.estimated_particle_size_um,
            "std_size_um": scatter.estimated_particle_size_um * 0.3,
            "size_effect_strength": 1.0,
        }

    def _build_emsc_config(
        self, scatter: Optional[ScatteringInference]
    ) -> Dict[str, Any]:
        """Build EMSC configuration; empty only when ``scatter`` is None."""
        if scatter is None:
            return {}

        return {
            "multiplicative_scatter_std": scatter.multiplicative_scatter_std,
            "additive_scatter_std": scatter.additive_scatter_std,
            "polynomial_order": 2,
            "include_wavelength_terms": True,
        }

    def create_matched_generator(
        self,
        random_state: Optional[int] = None,
    ) -> "SyntheticNIRSGenerator":
        """
        Create a SyntheticNIRSGenerator configured from the fitted data.

        Only the fitted wavelength grid (start, end, step), the complexity
        level and ``random_state`` are passed to the generator constructor
        here; the remaining fitted and inferred values stay available on
        ``self.fitted_params``.

        Args:
            random_state: Random seed for reproducibility.

        Returns:
            Configured SyntheticNIRSGenerator instance.

        Raises:
            RuntimeError: If fit() hasn't been called.

        Example:
            >>> fitter = RealDataFitter()
            >>> params = fitter.fit(X_real, wavelengths=wavelengths)
            >>> generator = fitter.create_matched_generator(random_state=42)
            >>> X_synth, _, _ = generator.generate(1000)
        """
        if self.fitted_params is None:
            raise RuntimeError("Must call fit() before create_matched_generator()")

        # Imported here rather than at module level, as in the original code.
        from .generator import SyntheticNIRSGenerator

        params = self.fitted_params

        return SyntheticNIRSGenerator(
            wavelength_start=params.wavelength_start,
            wavelength_end=params.wavelength_end,
            wavelength_step=params.wavelength_step,
            complexity=params.complexity,
            random_state=random_state,
        )

    def fit_from_path(
        self,
        path: str,
        *,
        name: Optional[str] = None,
    ) -> FittedParameters:
        """
        Fit parameters from a dataset path.

        Loads data using DatasetConfigs and fits parameters on the first
        dataset found.

        Args:
            path: Path to dataset folder.
            name: Optional name override (defaults to the dataset's name).

        Returns:
            FittedParameters object.

        Raises:
            ValueError: If no dataset is found at ``path``.

        Example:
            >>> params = fitter.fit_from_path("sample_data/regression")
        """
        from nirs4all.data import DatasetConfigs

        dataset_config = DatasetConfigs(path)
        datasets = dataset_config.get_datasets()

        if not datasets:
            raise ValueError(f"No datasets found at {path}")

        dataset = datasets[0]
        X = dataset.x({}, layout="2d")

        # Try to get wavelengths; fit() falls back to column indices on None
        try:
            wavelengths = dataset.wavelengths
        except (AttributeError, TypeError):
            wavelengths = None

        return self.fit(X, wavelengths=wavelengths, name=name or dataset.name)

    @staticmethod
    def _bounded_score(value: float) -> float:
        """Clamp a partial similarity score to [0, 100]; NaN counts as 0."""
        value = float(value)
        if np.isnan(value):
            return 0.0
        return min(100.0, max(0.0, value))

    @classmethod
    def _overall_similarity_score(cls, metrics: Dict[str, Any]) -> float:
        """
        Average the available partial scores into a single 0-100 value.

        Uses ``mean_rel_diff``, ``std_rel_diff``, ``noise_ratio`` (skipped when
        infinite) and ``mean_spectrum_correlation`` when present. Every term
        is bounded to [0, 100], so a negative or undefined (NaN) correlation
        contributes 0 instead of dragging the result out of range or to NaN.
        Returns 0.0 when no term is available.
        """
        partial_scores = []

        if "mean_rel_diff" in metrics:
            partial_scores.append(
                cls._bounded_score(100 - abs(metrics["mean_rel_diff"]) * 100)
            )
        if "std_rel_diff" in metrics:
            partial_scores.append(
                cls._bounded_score(100 - abs(metrics["std_rel_diff"]) * 100)
            )
        if "noise_ratio" in metrics and metrics["noise_ratio"] != float("inf"):
            partial_scores.append(
                cls._bounded_score(100 - abs(1 - metrics["noise_ratio"]) * 100)
            )
        if "mean_spectrum_correlation" in metrics:
            partial_scores.append(
                cls._bounded_score(metrics["mean_spectrum_correlation"] * 100)
            )

        return float(np.mean(partial_scores)) if partial_scores else 0.0

    def evaluate_similarity(
        self,
        X_synthetic: np.ndarray,
        wavelengths: Optional[np.ndarray] = None,
    ) -> Dict[str, Any]:
        """
        Evaluate similarity between synthetic and source data.

        Computes various metrics comparing synthetic spectra to
        the original real data.

        Args:
            X_synthetic: Synthetic spectra matrix.
            wavelengths: Optional wavelength grid. When omitted, the source
                data's grid is reused (the synthetic data is assumed to share it).

        Returns:
            Dictionary with similarity metrics. ``overall_score`` is always
            within 0-100. Ratio metrics are ``inf`` when the real-data
            reference value is zero.

        Raises:
            RuntimeError: If fit() hasn't been called.

        Example:
            >>> params = fitter.fit(X_real)
            >>> X_synth, _, _ = generator.generate(1000)
            >>> metrics = fitter.evaluate_similarity(X_synth)
            >>> print(f"Similarity: {metrics['overall_score']:.1f}/100")
        """
        if self.source_properties is None:
            raise RuntimeError("Must call fit() before evaluate_similarity()")

        real_props = self.source_properties

        # Use source wavelengths if not provided (assume same wavelength grid)
        if wavelengths is None and real_props.wavelengths is not None:
            wavelengths = real_props.wavelengths

        # Compute synthetic properties
        synth_props = compute_spectral_properties(
            X_synthetic, wavelengths, "synthetic"
        )

        metrics: Dict[str, Any] = {}

        # Mean comparison
        if real_props.global_mean != 0:
            metrics["mean_rel_diff"] = (
                (synth_props.global_mean - real_props.global_mean)
                / abs(real_props.global_mean)
            )
        else:
            metrics["mean_rel_diff"] = synth_props.global_mean

        # Std comparison
        if real_props.global_std != 0:
            metrics["std_rel_diff"] = (
                (synth_props.global_std - real_props.global_std)
                / real_props.global_std
            )
        else:
            metrics["std_rel_diff"] = synth_props.global_std

        # Slope comparison
        metrics["slope_diff"] = synth_props.mean_slope - real_props.mean_slope
        if real_props.mean_slope != 0:
            metrics["slope_ratio"] = synth_props.mean_slope / real_props.mean_slope
        else:
            metrics["slope_ratio"] = float("inf")

        # Noise comparison
        if real_props.noise_estimate != 0:
            metrics["noise_ratio"] = (
                synth_props.noise_estimate / real_props.noise_estimate
            )
        else:
            metrics["noise_ratio"] = float("inf")

        # SNR comparison
        if real_props.snr_estimate != 0 and real_props.snr_estimate != float("inf"):
            metrics["snr_ratio"] = synth_props.snr_estimate / real_props.snr_estimate
        else:
            metrics["snr_ratio"] = float("inf")

        # PCA complexity
        metrics["pca_complexity_diff"] = (
            synth_props.pca_n_components_95 - real_props.pca_n_components_95
        )

        # Mean spectrum correlation (if wavelengths match)
        if (
            real_props.n_wavelengths == synth_props.n_wavelengths
            and real_props.mean_spectrum is not None
            and synth_props.mean_spectrum is not None
        ):
            corr = np.corrcoef(
                real_props.mean_spectrum,
                synth_props.mean_spectrum,
            )[0, 1]
            metrics["mean_spectrum_correlation"] = float(corr)

        # Slope distribution comparison
        if real_props.slopes is not None and synth_props.slopes is not None:
            ks_stat, ks_pval = stats.ks_2samp(real_props.slopes, synth_props.slopes)
            metrics["slope_ks_statistic"] = float(ks_stat)
            metrics["slope_ks_pvalue"] = float(ks_pval)

        # Overall similarity score (0-100)
        metrics["overall_score"] = self._overall_similarity_score(metrics)

        return metrics

    def get_tuning_recommendations(self) -> List[str]:
        """
        Get recommendations for tuning generation parameters.

        Based on the fitted parameters and source data, provides
        suggestions for manual tuning.

        Returns:
            List of recommendation strings. Contains a single reminder to
            call fit() when no data has been analyzed yet, and may be empty
            when no rule triggers.

        Example:
            >>> params = fitter.fit(X_real)
            >>> for rec in fitter.get_tuning_recommendations():
            ...     print(f"- {rec}")
        """
        if self.source_properties is None or self.fitted_params is None:
            return ["Call fit() first to analyze data."]

        recs: List[str] = []
        props = self.source_properties
        params = self.fitted_params

        # Noise recommendations
        if props.snr_estimate < 15:
            recs.append(
                f"High noise detected (SNR={props.snr_estimate:.1f}). "
                f"Using noise_base={params.noise_base:.4f}"
            )
        elif props.snr_estimate > 100:
            recs.append(
                f"Very low noise detected (SNR={props.snr_estimate:.1f}). "
                "Consider using 'simple' complexity for faster generation."
            )

        # Slope recommendations
        if abs(props.mean_slope) > 0.1:
            recs.append(
                f"Significant slope detected ({props.mean_slope:.3f}/1000nm). "
                "Ensure global_slope_mean is correctly set."
            )

        # Complexity recommendations
        if props.pca_n_components_95 > 10:
            recs.append(
                f"High complexity ({props.pca_n_components_95} PCA components for 95%). "
                "Consider using more spectral components."
            )
        elif props.pca_n_components_95 <= 3:
            recs.append(
                f"Low complexity ({props.pca_n_components_95} PCA components). "
                "Simple mode may be sufficient."
            )

        # Variation recommendations
        if params.path_length_std > 0.15:
            recs.append(
                "High sample-to-sample variation detected. "
                f"path_length_std set to {params.path_length_std:.3f}"
            )

        return recs
def fit_to_real_data(
    X: Union[np.ndarray, "SpectroDataset"],
    wavelengths: Optional[np.ndarray] = None,
    name: str = "source",
) -> FittedParameters:
    """
    Fit generator parameters to real data in a single call.

    Shorthand for creating a throwaway RealDataFitter and calling its
    ``fit`` method with default inference options.

    Args:
        X: Real spectra or SpectroDataset.
        wavelengths: Wavelength grid.
        name: Dataset name.

    Returns:
        FittedParameters object.

    Example:
        >>> params = fit_to_real_data(X_real, wavelengths)
        >>> generator = SyntheticNIRSGenerator(**params.to_generator_kwargs())
    """
    return RealDataFitter().fit(X, wavelengths=wavelengths, name=name)
def compare_datasets(
    X_synthetic: np.ndarray,
    X_real: np.ndarray,
    wavelengths: Optional[np.ndarray] = None,
) -> Dict[str, Any]:
    """
    Compare a synthetic dataset against a real one in a single call.

    Fits a fresh RealDataFitter on the real spectra (under the name
    ``"real"``) and then evaluates the synthetic spectra against it.

    Args:
        X_synthetic: Synthetic spectra.
        X_real: Real spectra.
        wavelengths: Wavelength grid, used for both datasets.

    Returns:
        Dictionary with comparison metrics.

    Example:
        >>> metrics = compare_datasets(X_synth, X_real)
        >>> print(f"Similarity: {metrics['overall_score']:.1f}/100")
    """
    reference = RealDataFitter()
    reference.fit(X_real, wavelengths=wavelengths, name="real")
    similarity = reference.evaluate_similarity(X_synthetic, wavelengths)
    return similarity