# Source code for nirs4all.data.synthetic.reconstruction.validation

"""
Validation and diagnostics for reconstruction quality.

Provides tools to evaluate:
1. Reconstruction quality (residual analysis)
2. Synthetic vs real data comparison (PCA, statistics)
3. Parameter plausibility checks
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from scipy import stats


# =============================================================================
# Validation Result
# =============================================================================


@dataclass
class ValidationResult:
    """
    Result of reconstruction validation.

    Attributes:
        reconstruction_metrics: Aggregate reconstruction-quality metrics.
        synthetic_metrics: Synthetic vs real comparison metrics.
        parameter_metrics: Parameter plausibility metrics.
        overall_score: Combined quality score (0-100).
        passed: Overall pass/fail flag; rendered as PASSED or
            NEEDS REVIEW by ``summary()``.
        warnings: List of warning messages.
    """

    reconstruction_metrics: Dict[str, Any] = field(default_factory=dict)
    synthetic_metrics: Dict[str, Any] = field(default_factory=dict)
    parameter_metrics: Dict[str, Any] = field(default_factory=dict)
    overall_score: float = 0.0
    passed: bool = False
    warnings: List[str] = field(default_factory=list)

    @staticmethod
    def _format_metrics(metrics: Dict[str, Any]) -> List[str]:
        """Render one metric per line; floats get 4 decimals, others use str()."""
        return [
            f"  {k}: {v:.4f}" if isinstance(v, float) else f"  {k}: {v}"
            for k, v in metrics.items()
        ]

    def summary(self) -> str:
        """
        Generate a human-readable summary.

        Sections (reconstruction, synthetic comparison, parameter
        plausibility, warnings) are included only when non-empty.

        Returns:
            Multi-line report string framed by 70-character rules.
        """
        rule = "=" * 70
        lines = [
            rule,
            "Reconstruction Validation Summary",
            rule,
            f"Overall Score: {self.overall_score:.1f}/100",
            f"Status: {'PASSED' if self.passed else 'NEEDS REVIEW'}",
            "",
        ]

        if self.reconstruction_metrics:
            lines.append("Reconstruction Quality:")
            lines.extend(self._format_metrics(self.reconstruction_metrics))

        if self.synthetic_metrics:
            lines.append("")
            lines.append("Synthetic vs Real Comparison:")
            lines.extend(self._format_metrics(self.synthetic_metrics))

        # Previously omitted: parameter metrics are part of the result and
        # belong in the report alongside the other sections.
        if self.parameter_metrics:
            lines.append("")
            lines.append("Parameter Plausibility:")
            lines.extend(self._format_metrics(self.parameter_metrics))

        if self.warnings:
            lines.append("")
            lines.append("Warnings:")
            lines.extend(f"  - {w}" for w in self.warnings)

        lines.append(rule)
        return "\n".join(lines)
# ============================================================================= # Reconstruction Validator # =============================================================================
@dataclass
class ReconstructionValidator:
    """
    Validate reconstruction quality and synthetic realism.

    Checks:
    1. Residuals should be structureless (no systematic patterns)
    2. Synthetic should match real in PCA space
    3. Per-wavelength statistics should be similar
    4. Parameters should be physically plausible

    Attributes:
        r2_threshold: Minimum acceptable mean R² for reconstruction.
        residual_autocorr_threshold: Max mean lag-1 autocorrelation of
            residuals; exceeding it adds a warning in ``validate``.
        pca_distance_threshold: Max Mahalanobis distance in PCA space.
            NOTE(review): not consulted by any method of this class yet.
        concentration_max: Max plausible concentration value.
        path_length_bounds: Inclusive (low, high) plausible path-length range.
        mean_corr_threshold: Minimum real/synthetic mean-spectrum correlation
            for the corresponding check in ``validate`` to pass.
        discriminator_threshold: Maximum cross-validated discriminator score
            for the "indistinguishable" check in ``validate`` to pass.
    """

    r2_threshold: float = 0.90
    residual_autocorr_threshold: float = 0.3
    pca_distance_threshold: float = 3.0
    concentration_max: float = 10.0
    path_length_bounds: Tuple[float, float] = (0.3, 3.0)
    # New fields are appended last so positional construction stays compatible.
    mean_corr_threshold: float = 0.95
    discriminator_threshold: float = 0.7

    # Un-annotated class constant (not a dataclass field): how many residual
    # vectors the autocorrelation / FFT diagnostics inspect.
    _MAX_RESIDUAL_SAMPLES = 50

    def validate_reconstruction(
        self,
        inversion_results: List["InversionResult"],
    ) -> Dict[str, Any]:
        """
        Validate reconstruction quality.

        Args:
            inversion_results: List of inversion results; each item is read
                for ``r_squared``, ``rmse`` and ``residuals`` (may be None).

        Returns:
            Dict of reconstruction metrics. Empty when ``inversion_results``
            is empty (statistics of an empty set are undefined).
        """
        metrics: Dict[str, Any] = {}
        if not inversion_results:
            return metrics

        # R² statistics
        r2_values = np.array([r.r_squared for r in inversion_results])
        metrics["mean_r2"] = float(np.mean(r2_values))
        metrics["min_r2"] = float(np.min(r2_values))
        metrics["r2_above_threshold"] = float(np.mean(r2_values >= self.r2_threshold))

        # RMSE statistics
        rmse_values = np.array([r.rmse for r in inversion_results])
        metrics["mean_rmse"] = float(np.mean(rmse_values))
        metrics["median_rmse"] = float(np.median(rmse_values))

        # Residual analysis (check for systematic patterns)
        all_residuals = [r.residuals for r in inversion_results if r.residuals is not None]
        if not all_residuals:
            return metrics

        residuals = np.array(all_residuals)

        # Mean residual (should be near zero)
        mean_residual = np.mean(residuals, axis=0)
        metrics["mean_residual_magnitude"] = float(np.mean(np.abs(mean_residual)))

        sample = residuals[: self._MAX_RESIDUAL_SAMPLES]

        # Lag-1 autocorrelation (should be low). Constant residuals yield NaN,
        # which is filtered out, so silence the corresponding divide warnings.
        autocorrs = []
        with np.errstate(divide="ignore", invalid="ignore"):
            for res in sample:
                ac = np.corrcoef(res[:-1], res[1:])[0, 1]
                if not np.isnan(ac):
                    autocorrs.append(ac)
        if autocorrs:
            metrics["residual_autocorr"] = float(np.mean(autocorrs))

        # Oscillatory patterns: share of spectral energy in the upper half
        # of the rFFT bins.
        fft_energy = []
        for res in sample:
            power = np.abs(np.fft.rfft(res)) ** 2
            high_freq = np.sum(power[len(power) // 2:])
            total = np.sum(power) + 1e-10
            fft_energy.append(high_freq / total)
        metrics["high_freq_residual_energy"] = float(np.mean(fft_energy))

        return metrics

    def validate_synthetic(
        self,
        X_real: np.ndarray,
        X_synth: np.ndarray,
    ) -> Dict[str, Any]:
        """
        Validate synthetic vs real data.

        Args:
            X_real: Real data matrix (samples x features).
            X_synth: Synthetic data matrix with the same feature count.

        Returns:
            Dict of comparison metrics. PCA / discriminator steps need
            scikit-learn; when they cannot run, ``pca_error`` /
            ``discriminator_error`` hold the reason instead.
        """
        metrics: Dict[str, Any] = {}

        # Basic statistics comparison
        real_mean = X_real.mean(axis=0)
        synth_mean = X_synth.mean(axis=0)
        real_std = X_real.std(axis=0)
        synth_std = X_synth.std(axis=0)

        metrics["mean_spectrum_correlation"] = float(np.corrcoef(real_mean, synth_mean)[0, 1])
        metrics["std_spectrum_correlation"] = float(np.corrcoef(real_std, synth_std)[0, 1])
        metrics["relative_mean_difference"] = float(
            np.mean(np.abs(real_mean - synth_mean) / (np.abs(real_mean) + 1e-10))
        )
        metrics["relative_std_difference"] = float(
            np.mean(np.abs(real_std - synth_std) / (real_std + 1e-10))
        )

        # PCA comparison (best effort: sklearn is optional)
        pca = None
        n_comp = 0
        try:
            from sklearn.decomposition import PCA

            n_comp = min(10, X_real.shape[0] - 1, X_synth.shape[0] - 1, X_real.shape[1])
            pca = PCA(n_components=n_comp)
            scores_real = pca.fit_transform(X_real)
            scores_synth = pca.transform(X_synth)

            for i in range(min(3, n_comp)):
                ks_stat, ks_pvalue = stats.ks_2samp(scores_real[:, i], scores_synth[:, i])
                metrics[f"pca_score_{i+1}_ks_stat"] = float(ks_stat)
                metrics[f"pca_score_{i+1}_ks_pvalue"] = float(ks_pvalue)

            metrics["pca_variance_real"] = pca.explained_variance_ratio_[:3].tolist()
        except Exception as e:
            metrics["pca_error"] = str(e)

        # Simple discriminator (can synthetic be distinguished from real?)
        try:
            from sklearn.linear_model import LogisticRegression
            from sklearn.model_selection import cross_val_score

            # Stratified folds need enough members in BOTH classes.
            n_splits = min(5, len(X_real), len(X_synth))
            if n_splits < 2:
                metrics["discriminator_error"] = (
                    "Need at least 2 real and 2 synthetic samples for cross-validation"
                )
                return metrics

            X_combined = np.vstack([X_real, X_synth])
            y = np.concatenate([np.zeros(len(X_real)), np.ones(len(X_synth))])

            # Use PCA features to avoid overfitting; otherwise subsample wavelengths.
            if pca is not None and "pca_error" not in metrics:
                features = pca.transform(X_combined)[:, :min(5, n_comp)]
            else:
                features = X_combined[:, ::max(1, X_combined.shape[1] // 20)]

            clf = LogisticRegression(max_iter=1000, random_state=42)
            # Balanced accuracy keeps 0.5 as the chance level even when the
            # real and synthetic sets differ in size.
            cv_scores = cross_val_score(
                clf, features, y, cv=n_splits, scoring="balanced_accuracy"
            )
            # Close to 0.5 = good (indistinguishable)
            metrics["discriminator_accuracy"] = float(np.mean(cv_scores))
        except Exception as e:
            metrics["discriminator_error"] = str(e)

        return metrics

    def validate_parameters(
        self,
        inversion_results: List["InversionResult"],
    ) -> Dict[str, Any]:
        """
        Validate parameter plausibility.

        Args:
            inversion_results: List of inversion results; each item is read
                for ``concentrations``, ``path_length`` and
                ``wl_shift_residual``.

        Returns:
            Dict of parameter metrics plus a ``"warnings"`` list of
            human-readable messages (``validate`` pops it off).
        """
        metrics: Dict[str, Any] = {}
        warnings: List[str] = []

        if not inversion_results:
            warnings.append("No inversion results supplied; parameter checks skipped")
            metrics["warnings"] = warnings
            return metrics

        # Concentration statistics
        concentrations = np.array([r.concentrations for r in inversion_results])
        conc_max = float(np.max(concentrations))
        negative_frac = float(np.mean(concentrations < 0))
        metrics["concentration_mean"] = float(np.mean(concentrations))
        metrics["concentration_max"] = conc_max
        metrics["concentration_negative_frac"] = negative_frac

        if conc_max > self.concentration_max:
            warnings.append(f"Max concentration ({conc_max:.2f}) exceeds threshold")
        if negative_frac > 0.1:
            warnings.append(f"Significant negative concentrations ({negative_frac*100:.1f}%)")

        # Path length statistics (bounds are inclusive)
        path_lengths = np.array([r.path_length for r in inversion_results])
        metrics["path_length_mean"] = float(np.mean(path_lengths))
        metrics["path_length_std"] = float(np.std(path_lengths))

        low, high = self.path_length_bounds
        out_of_bounds_frac = float(np.mean((path_lengths < low) | (path_lengths > high)))
        metrics["path_length_out_of_bounds_frac"] = out_of_bounds_frac
        if out_of_bounds_frac > 0.1:
            warnings.append(f"Path lengths out of bounds ({out_of_bounds_frac*100:.1f}%)")

        # Wavelength shift statistics
        wl_shifts = np.array([r.wl_shift_residual for r in inversion_results])
        metrics["wl_shift_mean"] = float(np.mean(wl_shifts))
        metrics["wl_shift_std"] = float(np.std(wl_shifts))

        metrics["warnings"] = warnings
        return metrics

    def validate(
        self,
        inversion_results: List["InversionResult"],
        X_real: np.ndarray,
        X_synth: np.ndarray,
    ) -> ValidationResult:
        """
        Run full validation.

        Scoring: 25 pts for mean R², 25 for mean-spectrum correlation, 25 for
        the discriminator, and 12.5 each for the two parameter half-checks.
        The result passes when at least 75% of checks pass.

        Args:
            inversion_results: Inversion results.
            X_real: Real data.
            X_synth: Synthetic data.

        Returns:
            ValidationResult.
        """
        recon_metrics = self.validate_reconstruction(inversion_results)
        synth_metrics = self.validate_synthetic(X_real, X_synth)
        param_metrics = self.validate_parameters(inversion_results)
        warnings = list(param_metrics.pop("warnings", []))

        score = 0.0
        checks_passed = 0.0
        total_checks = 0

        # Reconstruction check
        total_checks += 1
        if recon_metrics.get("mean_r2", 0) >= self.r2_threshold:
            score += 25
            checks_passed += 1

        autocorr = recon_metrics.get("residual_autocorr")
        if autocorr is not None and autocorr > self.residual_autocorr_threshold:
            warnings.append(
                f"Residual autocorrelation ({autocorr:.2f}) exceeds threshold "
                f"({self.residual_autocorr_threshold:.2f}); residuals may contain "
                f"systematic structure"
            )

        # Synthetic checks
        total_checks += 1
        if synth_metrics.get("mean_spectrum_correlation", 0) >= self.mean_corr_threshold:
            score += 25
            checks_passed += 1

        total_checks += 1
        if synth_metrics.get("discriminator_accuracy", 1.0) <= self.discriminator_threshold:
            score += 25
            checks_passed += 1

        # Surface failures of the optional steps instead of silently scoring 0.
        if "pca_error" in synth_metrics:
            warnings.append(f"PCA comparison could not run: {synth_metrics['pca_error']}")
        if "discriminator_error" in synth_metrics:
            warnings.append(
                f"Discriminator check could not run: {synth_metrics['discriminator_error']}"
            )

        # Parameter checks: two half-checks counted as one check overall.
        if param_metrics.get("concentration_negative_frac", 1.0) <= 0.05:
            score += 12.5
            checks_passed += 0.5
        if param_metrics.get("path_length_out_of_bounds_frac", 1.0) <= 0.1:
            score += 12.5
            checks_passed += 0.5
        total_checks += 1

        passed = checks_passed >= total_checks * 0.75

        return ValidationResult(
            reconstruction_metrics=recon_metrics,
            synthetic_metrics=synth_metrics,
            parameter_metrics=param_metrics,
            overall_score=score,
            passed=passed,
            warnings=warnings,
        )
# ============================================================================= # Diagnostic Plots (Data Generation) # =============================================================================
def compute_diagnostic_data(
    X_real: np.ndarray,
    X_synth: np.ndarray,
    inversion_results: Optional[List["InversionResult"]] = None,
    wavelengths: Optional[np.ndarray] = None,
) -> Dict[str, Any]:
    """
    Compute data for diagnostic plots.

    Args:
        X_real: Real data matrix; statistics are taken over axis 0.
        X_synth: Synthetic data matrix with the same number of columns.
        inversion_results: Optional inversion results; items whose
            ``residuals`` is not None contribute to the residual summary.
        wavelengths: Wavelength grid. Defaults to the column indices of
            ``X_real``.

    Returns:
        Dict with ``wavelengths``, ``real_mean``, ``real_std``,
        ``synth_mean`` and ``synth_std``; ``residual_mean`` /
        ``residual_std`` when residuals are available; and either the PCA
        arrays (``real_pca_scores``, ``synth_pca_scores``,
        ``pca_explained_variance``, ``pca_components``) or ``pca_error``
        describing why PCA could not be computed.
    """
    if wavelengths is None:
        wavelengths = np.arange(X_real.shape[1])

    # Mean and std spectra
    data: Dict[str, Any] = {
        "wavelengths": wavelengths,
        "real_mean": X_real.mean(axis=0),
        "real_std": X_real.std(axis=0),
        "synth_mean": X_synth.mean(axis=0),
        "synth_std": X_synth.std(axis=0),
    }

    # Residuals
    if inversion_results:
        residual_list = [r.residuals for r in inversion_results if r.residuals is not None]
        if residual_list:
            residuals = np.array(residual_list)
            data["residual_mean"] = residuals.mean(axis=0)
            data["residual_std"] = residuals.std(axis=0)

    # PCA (best effort: sklearn is optional and small inputs may be rejected).
    try:
        from sklearn.decomposition import PCA

        n_comp = min(5, X_real.shape[0] - 1, X_synth.shape[0] - 1, X_real.shape[1])
        pca = PCA(n_components=n_comp)
        real_scores = pca.fit_transform(X_real)
        synth_scores = pca.transform(X_synth)
    except Exception as e:
        # Record the reason (as validate_synthetic does) rather than
        # silently dropping the PCA plot data.
        data["pca_error"] = str(e)
    else:
        # Assign all PCA outputs together so the dict is never half-filled.
        data["real_pca_scores"] = real_scores
        data["synth_pca_scores"] = synth_scores
        data["pca_explained_variance"] = pca.explained_variance_ratio_
        data["pca_components"] = pca.components_

    return data