"""
Validation utilities for synthetic data generation.
This module provides functions to validate generated synthetic data
for correctness and expected properties, including spectral realism
scoring for comparing synthetic data against real NIRS spectra.
Phase 4 Features:
- Spectral realism scorecard with quantitative metrics
- Correlation length analysis
- Derivative statistics comparison
- Peak density analysis
- Baseline curvature metrics
- SNR distribution analysis
- Adversarial validation (classifier distinguishability)
References:
- Engel, J., et al. (2013). Breaking with trends in pre-processing?
TrAC Trends in Analytical Chemistry, 50, 96-106.
- Rinnan, Å., et al. (2009). Review of the most common pre-processing
techniques for near-infrared spectra. TrAC Trends in Analytical Chemistry.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from scipy import signal
from scipy.stats import ks_2samp, wasserstein_distance
from scipy.ndimage import gaussian_filter1d
[docs]
class ValidationError(Exception):
    """Signal that synthetic data failed a critical validation check."""
[docs]
def validate_spectra(
    X: np.ndarray,
    expected_shape: Optional[Tuple[int, int]] = None,
    check_finite: bool = True,
    check_positive: bool = False,
    value_range: Optional[Tuple[float, float]] = None,
) -> List[str]:
    """
    Validate a generated spectra matrix.

    Structural problems (wrong type, wrong dimensionality, shape mismatch,
    NaN/Inf values) raise ``ValidationError``. Softer findings (negative
    values, values outside ``value_range``) are returned as warning strings.

    Args:
        X: Spectra matrix to validate.
        expected_shape: Expected (n_samples, n_wavelengths) shape.
        check_finite: Whether to check for NaN/Inf values.
        check_positive: Whether to warn about negative values.
        value_range: Optional (min, max) expected range. Skipped for an
            empty matrix, which has no minimum or maximum.

    Returns:
        List of validation warning messages (empty if all OK).

    Raises:
        ValidationError: If critical validation fails.

    Example:
        >>> X = np.random.randn(100, 500)
        >>> warnings = validate_spectra(X, expected_shape=(100, 500))
        >>> if warnings:
        ...     print("Warnings:", warnings)
    """
    warnings: List[str] = []

    # Critical structural checks come first: everything below relies on
    # X being a 2D numpy array.
    if not isinstance(X, np.ndarray):
        raise ValidationError(f"Expected numpy array, got {type(X).__name__}")
    if X.ndim != 2:
        raise ValidationError(f"Expected 2D array, got {X.ndim}D")
    if expected_shape is not None and X.shape != expected_shape:
        raise ValidationError(
            f"Shape mismatch: expected {expected_shape}, got {X.shape}"
        )

    if check_finite:
        n_nan = np.isnan(X).sum()
        n_inf = np.isinf(X).sum()
        if n_nan > 0:
            raise ValidationError(f"Found {n_nan} NaN values in spectra")
        if n_inf > 0:
            raise ValidationError(f"Found {n_inf} Inf values in spectra")

    if check_positive:
        n_negative = (X < 0).sum()
        if n_negative > 0:
            warnings.append(
                f"Found {n_negative} negative values ({100*n_negative/X.size:.2f}%)"
            )

    # min()/max() of a zero-size array raise ValueError, so an empty
    # matrix simply has nothing to range-check.
    if value_range is not None and X.size > 0:
        min_val, max_val = value_range
        observed_min = X.min()
        observed_max = X.max()
        if observed_min < min_val:
            warnings.append(
                f"Minimum value {observed_min:.4f} below expected {min_val}"
            )
        if observed_max > max_val:
            warnings.append(
                f"Maximum value {observed_max:.4f} above expected {max_val}"
            )

    return warnings
[docs]
def validate_concentrations(
    C: np.ndarray,
    n_samples: Optional[int] = None,
    n_components: Optional[int] = None,
    check_normalized: bool = False,
    tolerance: float = 0.01,
) -> List[str]:
    """
    Validate a concentration matrix.

    Args:
        C: Concentration matrix to validate.
        n_samples: Expected number of samples (rows).
        n_components: Expected number of components (columns).
        check_normalized: Whether each row should sum to 1.
        tolerance: Maximum allowed absolute deviation of a row sum from 1.

    Returns:
        List of validation warning messages.

    Raises:
        ValidationError: If C is not a 2D numpy array or its shape does
            not match ``n_samples`` / ``n_components``.
    """
    warnings: List[str] = []

    if not isinstance(C, np.ndarray):
        raise ValidationError(f"Expected numpy array, got {type(C).__name__}")
    if C.ndim != 2:
        raise ValidationError(f"Expected 2D concentration matrix, got {C.ndim}D")
    if n_samples is not None and C.shape[0] != n_samples:
        raise ValidationError(
            f"Expected {n_samples} samples, got {C.shape[0]}"
        )
    if n_components is not None and C.shape[1] != n_components:
        raise ValidationError(
            f"Expected {n_components} components, got {C.shape[1]}"
        )

    # Negative concentrations are reported but are not fatal.
    n_negative = (C < 0).sum()
    if n_negative > 0:
        warnings.append(f"Found {n_negative} negative concentration values")

    if check_normalized:
        deviations = np.abs(C.sum(axis=1) - 1.0)
        # With zero rows there is nothing to normalise, and max() on an
        # empty array would raise ValueError.
        if deviations.size > 0:
            worst = deviations.max()
            if worst > tolerance:
                warnings.append(
                    f"Concentrations not normalized: max deviation = {worst:.4f}"
                )

    return warnings
[docs]
def validate_wavelengths(
    wavelengths: np.ndarray,
    expected_range: Optional[Tuple[float, float]] = None,
    check_monotonic: bool = True,
    check_uniform: bool = True,
) -> List[str]:
    """
    Validate a wavelength array.

    Args:
        wavelengths: Wavelength array to validate.
        expected_range: Optional (min, max) expected range in nm.
        check_monotonic: Whether to require strictly increasing values.
        check_uniform: Whether to warn when the spacing varies by more than
            1% (standard deviation of the steps relative to the magnitude
            of the mean step).

    Returns:
        List of validation warning messages.

    Raises:
        ValidationError: If the input is not a 1D numpy array with at least
            two points, or (with ``check_monotonic``) is not strictly
            increasing.
    """
    warnings: List[str] = []

    if not isinstance(wavelengths, np.ndarray):
        raise ValidationError(f"Expected numpy array, got {type(wavelengths).__name__}")
    if wavelengths.ndim != 1:
        raise ValidationError(f"Expected 1D wavelength array, got {wavelengths.ndim}D")
    if len(wavelengths) < 2:
        raise ValidationError(
            f"Wavelength array too short: {len(wavelengths)} points"
        )

    if expected_range is not None:
        min_wl, max_wl = expected_range
        if wavelengths.min() < min_wl or wavelengths.max() > max_wl:
            warnings.append(
                f"Wavelength range [{wavelengths.min():.1f}, {wavelengths.max():.1f}] "
                f"outside expected [{min_wl}, {max_wl}]"
            )

    steps = np.diff(wavelengths)

    if check_monotonic and not np.all(steps > 0):
        raise ValidationError("Wavelengths must be monotonically increasing")

    if check_uniform:
        # Compare against the magnitude of the mean step. Dividing by the
        # signed mean made the ratio negative for decreasing grids (only
        # reachable with check_monotonic=False), so they were never flagged.
        # The multiplicative form also avoids 0/0 when the mean step is zero.
        if steps.std() > 0.01 * abs(steps.mean()):  # 1% tolerance
            warnings.append("Wavelength spacing is not uniform")

    return warnings
[docs]
def validate_synthetic_output(
    X: np.ndarray,
    C: np.ndarray,
    E: np.ndarray,
    wavelengths: Optional[np.ndarray] = None,
) -> List[str]:
    """
    Validate complete synthetic generation output.

    The spectra and concentration matrices are validated individually
    before their shapes are cross-checked, so malformed input (wrong type
    or dimensionality) is reported as ``ValidationError`` rather than as an
    ``AttributeError``/``ValueError`` from reading ``.shape``.

    Args:
        X: Generated spectra (n_samples, n_wavelengths).
        C: Concentration matrix (n_samples, n_components).
        E: Component spectra (n_components, n_wavelengths).
        wavelengths: Optional wavelength array.

    Returns:
        List of all validation warnings (spectra first, then
        concentrations, then wavelengths).

    Raises:
        ValidationError: If critical validation fails.

    Example:
        >>> from nirs4all.data.synthetic import SyntheticNIRSGenerator
        >>> gen = SyntheticNIRSGenerator(random_state=42)
        >>> X, C, E = gen.generate(100)
        >>> warnings = validate_synthetic_output(X, C, E, gen.wavelengths)
    """
    all_warnings: List[str] = []

    # Validate X before reading its shape: validate_spectra guarantees a
    # 2D numpy array or raises ValidationError.
    all_warnings.extend(validate_spectra(X))
    n_samples, n_wavelengths = X.shape

    # Same for C; its row count must match the number of spectra.
    all_warnings.extend(validate_concentrations(C, n_samples=n_samples))
    n_components = C.shape[1]

    # np.shape also copes with array-likes that have no .shape attribute.
    component_shape = np.shape(E)
    if component_shape != (n_components, n_wavelengths):
        raise ValidationError(
            f"Component spectra shape mismatch: expected "
            f"({n_components}, {n_wavelengths}), got {component_shape}"
        )

    if wavelengths is not None:
        all_warnings.extend(validate_wavelengths(wavelengths))
        if len(wavelengths) != n_wavelengths:
            raise ValidationError(
                f"Wavelength array length {len(wavelengths)} does not match "
                f"spectra width {n_wavelengths}"
            )

    return all_warnings
# ============================================================================
# Phase 4: Spectral Realism Scorecard
# ============================================================================
[docs]
class RealismMetric(str, Enum):
    """Identifiers of the individual checks reported by the realism scorecard.

    Members are also plain strings, so they compare equal to their values.
    """

    # Autocorrelation decay length of each spectrum.
    CORRELATION_LENGTH = "correlation_length"
    # Per-spectrum derivative statistics.
    DERIVATIVE_STATISTICS = "derivative_statistics"
    # Number of peaks per wavelength window.
    PEAK_DENSITY = "peak_density"
    # Residual spread around a low-order polynomial fit.
    BASELINE_CURVATURE = "baseline_curvature"
    # Signal-to-noise ratio estimates.
    SNR_DISTRIBUTION = "snr_distribution"
    # AUC of a real-vs-synthetic classifier.
    ADVERSARIAL_AUC = "adversarial_auc"
[docs]
@dataclass
class MetricResult:
    """
    Outcome of evaluating one realism metric.

    Attributes:
        metric: Which metric this result belongs to.
        value: The computed metric value.
        threshold: The threshold the value was compared against.
        passed: Whether the metric met its threshold.
        details: Extra information recorded during the computation.
    """

    metric: RealismMetric
    value: float
    threshold: float
    passed: bool
    details: Dict[str, Any] = field(default_factory=dict)

    def __repr__(self) -> str:
        # One-line form used by the scorecard summary: a pass/fail mark
        # followed by the metric name, value and threshold.
        mark = "✗"
        if self.passed:
            mark = "✓"
        name = self.metric.value
        return f"{mark} {name}: {self.value:.4f} (threshold: {self.threshold:.4f})"
[docs]
@dataclass
class SpectralRealismScore:
    """
    Complete spectral realism assessment results.

    Holds the outcome of comparing synthetic spectra against real spectra
    using multiple quantitative metrics.

    Attributes:
        correlation_length_overlap: Distribution overlap for autocorrelation decay [0-1].
        derivative_ks_pvalue: p-value from KS test on derivative distributions.
        peak_density_ratio: Ratio of synthetic to real peak densities.
        baseline_curvature_overlap: Distribution overlap for baseline curvature [0-1].
        snr_magnitude_match: Whether SNR is within one order of magnitude.
        adversarial_auc: AUC of classifier trying to distinguish real from synthetic.
        overall_pass: Whether all critical metrics pass.
        metric_results: Individual metric results with details.
        warnings: Any warnings from the analysis.

    Example:
        >>> score = compute_spectral_realism_scorecard(real_spectra, synthetic_spectra, wavelengths)
        >>> print(f"Overall pass: {score.overall_pass}")
        >>> print(f"Adversarial AUC: {score.adversarial_auc:.3f}")
        >>> for metric in score.metric_results:
        ...     print(metric)
    """

    correlation_length_overlap: float
    derivative_ks_pvalue: float
    peak_density_ratio: float
    baseline_curvature_overlap: float
    snr_magnitude_match: bool
    adversarial_auc: float
    overall_pass: bool
    metric_results: List[MetricResult] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the headline values to a dictionary for serialization.

        Values are coerced to builtin ``float``/``bool`` because the fields
        may hold NumPy scalars, and ``json`` rejects ``numpy.bool_``. The
        warnings list is copied so callers cannot mutate this object
        through the returned dictionary. ``metric_results`` is not
        included.
        """
        return {
            "correlation_length_overlap": float(self.correlation_length_overlap),
            "derivative_ks_pvalue": float(self.derivative_ks_pvalue),
            "peak_density_ratio": float(self.peak_density_ratio),
            "baseline_curvature_overlap": float(self.baseline_curvature_overlap),
            "snr_magnitude_match": bool(self.snr_magnitude_match),
            "adversarial_auc": float(self.adversarial_auc),
            "overall_pass": bool(self.overall_pass),
            "warnings": list(self.warnings),
        }

    def summary(self) -> str:
        """Return a human-readable summary of the score."""
        rule = "=" * 60
        lines = [rule, "Spectral Realism Scorecard", rule]
        # One line per recorded metric, using MetricResult's string form.
        lines.extend(str(result) for result in self.metric_results)
        lines.append("-" * 60)
        overall_status = "PASS ✓" if self.overall_pass else "FAIL ✗"
        lines.append(f"Overall: {overall_status}")
        if self.warnings:
            lines.append("\nWarnings:")
            lines.extend(f" - {w}" for w in self.warnings)
        lines.append(rule)
        return "\n".join(lines)
# ============================================================================
# Realism Metric Functions
# ============================================================================
[docs]
def compute_correlation_length(
    spectra: np.ndarray,
    max_lag: int = 50,
) -> np.ndarray:
    """
    Compute the correlation length of every spectrum.

    The correlation length is the first lag at which the normalized
    autocorrelation of the mean-centred spectrum drops below 1/e.

    Args:
        spectra: Array of shape (n_samples, n_wavelengths).
        max_lag: Maximum lag to inspect; capped at a quarter of the number
            of wavelengths.

    Returns:
        Array with one correlation length per spectrum. Spectra that are
        essentially constant get 0; spectra whose autocorrelation never
        drops below 1/e within the inspected lags get the (capped) maximum
        lag.

    Example:
        >>> X = np.random.randn(100, 500)
        >>> lengths = compute_correlation_length(X)
        >>> print(f"Mean correlation length: {lengths.mean():.2f}")
    """
    n_samples, n_wavelengths = spectra.shape
    lag_limit = min(max_lag, n_wavelengths // 4)
    decay_level = 1.0 / np.e

    lengths = np.zeros(n_samples)
    for row, raw in enumerate(spectra):
        centered = raw - raw.mean()
        if np.std(centered) < 1e-10:
            # Flat spectrum: leave the pre-filled 0.
            continue

        # Autocorrelation via the power spectrum; padding to twice the
        # length keeps the correlation linear rather than circular.
        transformed = np.fft.fft(centered, n=2 * n_wavelengths)
        autocorr = np.fft.ifft(transformed * np.conj(transformed)).real
        normalized = autocorr[:lag_limit] / autocorr[0]

        crossings = np.flatnonzero(normalized < decay_level)
        lengths[row] = crossings[0] if crossings.size > 0 else lag_limit

    return lengths
[docs]
def compute_derivative_statistics(
    spectra: np.ndarray,
    wavelengths: Optional[np.ndarray] = None,
    order: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Summarise the first or second derivative of each spectrum.

    Derivatives are taken along the wavelength axis with ``np.gradient``;
    the second derivative is the gradient of the gradient. When
    ``wavelengths`` is given, its mean step is used as the sample spacing;
    otherwise unit spacing is assumed.

    Args:
        spectra: Array of shape (n_samples, n_wavelengths).
        wavelengths: Wavelength array for derivative scaling.
        order: Derivative order (1 or 2).

    Returns:
        Tuple of (mean_derivatives, std_derivatives), one value per sample.

    Raises:
        ValueError: If ``order`` is neither 1 nor 2.

    Example:
        >>> X = np.random.randn(100, 500)
        >>> means, stds = compute_derivative_statistics(X, order=1)
    """
    if order == 1:
        n_passes = 1
    elif order == 2:
        n_passes = 2
    else:
        raise ValueError(f"Order must be 1 or 2, got {order}")

    # Positional spacing argument for np.gradient; empty means unit spacing.
    spacing = () if wavelengths is None else (np.diff(wavelengths).mean(),)

    derivatives = spectra
    for _ in range(n_passes):
        derivatives = np.gradient(derivatives, *spacing, axis=1)

    return derivatives.mean(axis=1), derivatives.std(axis=1)
[docs]
def compute_peak_density(
    spectra: np.ndarray,
    wavelengths: np.ndarray,
    window_nm: float = 100.0,
    prominence_threshold: float = 0.01,
) -> np.ndarray:
    """
    Compute the peak density (peaks per ``window_nm``) of each spectrum.

    Peaks are located with ``scipy.signal.find_peaks`` using a minimum
    prominence proportional to the spectrum's own value range.

    Args:
        spectra: Array of shape (n_samples, n_wavelengths).
        wavelengths: Wavelength array in nm.
        window_nm: Window size for density calculation (default 100 nm).
        prominence_threshold: Minimum peak prominence as fraction of spectrum range.

    Returns:
        Array of peak densities (peaks per window_nm) for each spectrum.

    Example:
        >>> X = np.random.randn(100, 500)
        >>> wl = np.linspace(1000, 2500, 500)
        >>> densities = compute_peak_density(X, wl)
    """
    covered_nm = wavelengths.max() - wavelengths.min()

    def _density(spectrum: np.ndarray) -> float:
        # Scale the prominence requirement to this spectrum's amplitude.
        min_prominence = prominence_threshold * (spectrum.max() - spectrum.min())
        peak_positions, _ = signal.find_peaks(spectrum, prominence=min_prominence)
        # Rescale the raw count from the covered range to one window.
        return len(peak_positions) * (window_nm / covered_nm)

    return np.array([_density(spectrum) for spectrum in spectra], dtype=float)
[docs]
def compute_baseline_curvature(
    spectra: np.ndarray,
    polynomial_degree: int = 3,
) -> np.ndarray:
    """
    Measure how far each spectrum departs from a polynomial baseline.

    A polynomial of the given degree is least-squares fitted to each
    spectrum against the sample index, and the standard deviation of the
    fit residuals is reported.

    Args:
        spectra: Array of shape (n_samples, n_wavelengths).
        polynomial_degree: Degree of polynomial to fit.

    Returns:
        Array of residual standard deviations for each spectrum.

    Example:
        >>> X = np.random.randn(100, 500)
        >>> curvatures = compute_baseline_curvature(X)
    """
    _, n_wavelengths = spectra.shape
    positions = np.arange(n_wavelengths)

    def _residual_spread(spectrum: np.ndarray) -> float:
        baseline = np.polyval(
            np.polyfit(positions, spectrum, polynomial_degree), positions
        )
        return np.std(spectrum - baseline)

    return np.array([_residual_spread(spectrum) for spectrum in spectra], dtype=float)
[docs]
def compute_snr(
    spectra: np.ndarray,
    noise_region_fraction: float = 0.1,
) -> np.ndarray:
    """
    Estimate the signal-to-noise ratio of each spectrum.

    Each spectrum is split into a Gaussian-smoothed part (sigma of 5
    samples), treated as signal, and the remainder, treated as noise. The
    SNR is the ratio of their variances.

    Args:
        spectra: Array of shape (n_samples, n_wavelengths).
        noise_region_fraction: Accepted as part of the signature but not
            used by the current implementation.

    Returns:
        Array of SNR estimates for each spectrum. Spectra whose noise
        variance is not above 1e-10 are assigned 1e6.

    Example:
        >>> X = np.random.randn(100, 500) + np.sin(np.linspace(0, 10, 500))
        >>> snr = compute_snr(X)
    """
    # Unpacking enforces a 2D input, as the per-row processing expects.
    _n_samples, _n_wavelengths = spectra.shape

    def _row_snr(spectrum: np.ndarray) -> float:
        trend = gaussian_filter1d(spectrum, sigma=5)
        noise_power = np.var(spectrum - trend)
        if noise_power > 1e-10:
            return np.var(trend) / noise_power
        # Essentially noise-free: report a very high SNR instead of
        # dividing by (almost) zero.
        return 1e6

    return np.array([_row_snr(spectrum) for spectrum in spectra], dtype=float)
[docs]
def compute_distribution_overlap(
    dist1: np.ndarray,
    dist2: np.ndarray,
    n_bins: int = 50,
) -> float:
    """
    Compute overlap between two distributions using histogram intersection.

    Both samples are histogrammed on a common set of equal-width bins
    spanning their joint range. The overlap is the sum over bins of the
    smaller of the two bin fractions.

    Args:
        dist1: First distribution samples (any array-like of numbers).
        dist2: Second distribution samples (any array-like of numbers).
        n_bins: Number of histogram bins.

    Returns:
        Overlap coefficient in [0, 1], where 1 means identical histograms.
        Returns 0.0 if either sample has no finite values.

    Example:
        >>> x1 = np.random.randn(1000)
        >>> x2 = np.random.randn(1000) + 0.5
        >>> overlap = compute_distribution_overlap(x1, x2)
    """
    # Accept lists and other array-likes, and drop NaN/Inf up front so the
    # checks below never see non-finite values.
    sample1 = np.asarray(dist1, dtype=float).ravel()
    sample2 = np.asarray(dist2, dtype=float).ravel()
    sample1 = sample1[np.isfinite(sample1)]
    sample2 = sample2[np.isfinite(sample2)]

    if sample1.size == 0 or sample2.size == 0:
        return 0.0

    # Two essentially constant samples overlap fully when they sit at the
    # same value and not at all otherwise.
    if np.std(sample1) < 1e-10 and np.std(sample2) < 1e-10:
        same_value = np.abs(np.mean(sample1) - np.mean(sample2)) < 1e-10
        return 1.0 if same_value else 0.0

    lowest = min(sample1.min(), sample2.min())
    highest = max(sample1.max(), sample2.max())
    if highest - lowest < 1e-10:
        # Degenerate joint range: every value falls in the same place.
        return 1.0

    edges = np.linspace(lowest, highest, n_bins + 1)

    # Every value lies inside [lowest, highest], so each histogram has a
    # positive total and the fractions below are well defined.
    counts1, _ = np.histogram(sample1, bins=edges)
    counts2, _ = np.histogram(sample2, bins=edges)
    fractions1 = counts1 / counts1.sum()
    fractions2 = counts2 / counts2.sum()

    return float(np.minimum(fractions1, fractions2).sum())
[docs]
def compute_adversarial_validation_auc(
    real_spectra: np.ndarray,
    synthetic_spectra: np.ndarray,
    cv_folds: int = 5,
    random_state: Optional[int] = None,
) -> Tuple[float, float]:
    """
    Train classifier to distinguish real vs. synthetic spectra.

    A lower AUC indicates that synthetic data is more realistic
    (harder to distinguish from real data).

    Feature standardization is part of the cross-validated pipeline, so the
    scaler is fitted on each training fold only and never sees the
    held-out rows.

    Args:
        real_spectra: Real spectra array (n_real, n_wavelengths).
        synthetic_spectra: Synthetic spectra array (n_synthetic, n_wavelengths).
        cv_folds: Number of cross-validation folds.
        random_state: Random state passed to the classifier.

    Returns:
        Tuple of (mean_auc, std_auc) across folds.

    Target:
        AUC < 0.6: Excellent (nearly indistinguishable)
        AUC < 0.7: Good (hard to distinguish)
        AUC < 0.8: Acceptable (some differences)
        AUC >= 0.8: Poor (clearly distinguishable)

    Example:
        >>> real = np.random.randn(100, 500)
        >>> synthetic = np.random.randn(100, 500) + 0.1
        >>> mean_auc, std_auc = compute_adversarial_validation_auc(real, synthetic)
        >>> print(f"AUC: {mean_auc:.3f} ± {std_auc:.3f}")
    """
    # Lazy import to avoid sklearn dependency at module level
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_score
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler

    # Stack both sets into one labelled classification problem.
    X = np.vstack([real_spectra, synthetic_spectra])
    y = np.concatenate([
        np.ones(len(real_spectra)),  # Real = 1
        np.zeros(len(synthetic_spectra)),  # Synthetic = 0
    ])

    # Scaling lives inside the pipeline so that cross_val_score refits it
    # per fold; fitting it once on all rows would leak test-fold statistics.
    model = make_pipeline(
        StandardScaler(),
        LogisticRegression(
            max_iter=1000,
            C=0.1,  # Regularization to prevent overfitting
            random_state=random_state,
            solver='lbfgs',
        ),
    )

    auc_scores = cross_val_score(
        model, X, y,
        cv=cv_folds,
        scoring='roc_auc',
    )
    return float(auc_scores.mean()), float(auc_scores.std())
[docs]
def compute_spectral_realism_scorecard(
    real_spectra: np.ndarray,
    synthetic_spectra: np.ndarray,
    wavelengths: Optional[np.ndarray] = None,
    thresholds: Optional[Dict[str, float]] = None,
    include_adversarial: bool = True,
    random_state: Optional[int] = None,
) -> SpectralRealismScore:
    """
    Compute comprehensive spectral realism scorecard.

    Several quantitative metrics are computed to assess whether synthetic
    spectra are realistic compared to real data. Each metric is evaluated
    independently; if one raises, the failure is recorded in the returned
    warnings and the remaining metrics are still computed.

    The overall verdict combines the pass flag of every metric. A metric
    whose computation failed counts as not passed for correlation length,
    derivative statistics, peak density and baseline curvature. A failed
    SNR or adversarial computation defaults to passed.

    Args:
        real_spectra: Real spectra array (n_real, n_wavelengths).
        synthetic_spectra: Synthetic spectra array (n_synthetic, n_wavelengths).
        wavelengths: Wavelength array in nm. If None, uses indices.
        thresholds: Custom thresholds for metrics. Defaults:
            - correlation_length_overlap: 0.7
            - derivative_ks_pvalue: 0.05
            - peak_density_ratio_min: 0.5
            - peak_density_ratio_max: 2.0
            - baseline_curvature_overlap: 0.6
            - snr_order_of_magnitude: 1.0 (log10 difference)
            - adversarial_auc: 0.7
        include_adversarial: Whether to compute adversarial AUC (slower).
        random_state: Random state for adversarial validation.

    Returns:
        SpectralRealismScore with all metrics and pass/fail status.

    Raises:
        ValidationError: If the two spectra sets have different widths.

    Example:
        >>> from nirs4all.data.synthetic import SyntheticNIRSGenerator
        >>> gen = SyntheticNIRSGenerator(random_state=42)
        >>> X_synth, _, _ = gen.generate(200)
        >>> # X_real would be loaded from real data
        >>> X_real = np.random.randn(200, X_synth.shape[1])  # Placeholder
        >>> score = compute_spectral_realism_scorecard(X_real, X_synth, gen.wavelengths)
        >>> print(score.summary())
    """
    # Caller-supplied thresholds override the defaults key by key.
    default_thresholds = {
        "correlation_length_overlap": 0.7,
        "derivative_ks_pvalue": 0.05,
        "peak_density_ratio_min": 0.5,
        "peak_density_ratio_max": 2.0,
        "baseline_curvature_overlap": 0.6,
        "snr_order_of_magnitude": 1.0,
        "adversarial_auc": 0.7,
    }
    if thresholds is not None:
        default_thresholds.update(thresholds)
    thresholds = default_thresholds

    # Validate input dimensions
    if real_spectra.shape[1] != synthetic_spectra.shape[1]:
        raise ValidationError(
            f"Spectral dimension mismatch: real={real_spectra.shape[1]}, "
            f"synthetic={synthetic_spectra.shape[1]}"
        )

    # Create wavelengths if not provided
    if wavelengths is None:
        wavelengths = np.arange(real_spectra.shape[1], dtype=float)

    metric_results: List[MetricResult] = []
    warnings: List[str] = []

    # 1. Correlation Length Overlap
    try:
        real_corr_len = compute_correlation_length(real_spectra)
        synth_corr_len = compute_correlation_length(synthetic_spectra)
        corr_overlap = compute_distribution_overlap(real_corr_len, synth_corr_len)
        corr_passed = bool(corr_overlap >= thresholds["correlation_length_overlap"])
        metric_results.append(MetricResult(
            metric=RealismMetric.CORRELATION_LENGTH,
            value=corr_overlap,
            threshold=thresholds["correlation_length_overlap"],
            passed=corr_passed,
            details={
                "real_mean": float(real_corr_len.mean()),
                "real_std": float(real_corr_len.std()),
                "synthetic_mean": float(synth_corr_len.mean()),
                "synthetic_std": float(synth_corr_len.std()),
            }
        ))
    except Exception as e:
        warnings.append(f"Correlation length computation failed: {e}")
        corr_overlap = 0.0
        corr_passed = False

    # 2. Derivative Statistics (KS test)
    try:
        _, real_deriv_stds = compute_derivative_statistics(
            real_spectra, wavelengths, order=1
        )
        _, synth_deriv_stds = compute_derivative_statistics(
            synthetic_spectra, wavelengths, order=1
        )
        # KS test on derivative standard deviations
        ks_stat, ks_pvalue = ks_2samp(real_deriv_stds, synth_deriv_stds)
        ks_pvalue = float(ks_pvalue)
        deriv_passed = bool(ks_pvalue >= thresholds["derivative_ks_pvalue"])
        metric_results.append(MetricResult(
            metric=RealismMetric.DERIVATIVE_STATISTICS,
            value=ks_pvalue,
            threshold=thresholds["derivative_ks_pvalue"],
            passed=deriv_passed,
            details={
                "ks_statistic": float(ks_stat),
                "real_mean_std": float(real_deriv_stds.mean()),
                "synthetic_mean_std": float(synth_deriv_stds.mean()),
            }
        ))
    except Exception as e:
        warnings.append(f"Derivative statistics computation failed: {e}")
        ks_pvalue = 0.0
        deriv_passed = False

    # 3. Peak Density Ratio
    try:
        real_peak_density = compute_peak_density(real_spectra, wavelengths)
        synth_peak_density = compute_peak_density(synthetic_spectra, wavelengths)
        # Ratio of means
        real_mean = real_peak_density.mean()
        synth_mean = synth_peak_density.mean()
        if real_mean > 0:
            peak_ratio = float(synth_mean / real_mean)
        else:
            # No peaks in the real data: a peak-free synthetic set is a
            # perfect match, anything else is infinitely off.
            peak_ratio = 1.0 if synth_mean == 0 else float('inf')
        peak_passed = bool(
            thresholds["peak_density_ratio_min"] <= peak_ratio <=
            thresholds["peak_density_ratio_max"]
        )
        # Only the upper bound is recorded as the result's threshold; the
        # pass flag itself uses both bounds.
        metric_results.append(MetricResult(
            metric=RealismMetric.PEAK_DENSITY,
            value=peak_ratio,
            threshold=thresholds["peak_density_ratio_max"],
            passed=peak_passed,
            details={
                "real_mean": float(real_mean),
                "synthetic_mean": float(synth_mean),
                "ratio": float(peak_ratio),
            }
        ))
    except Exception as e:
        warnings.append(f"Peak density computation failed: {e}")
        peak_ratio = 0.0
        peak_passed = False

    # 4. Baseline Curvature Overlap
    try:
        real_curvature = compute_baseline_curvature(real_spectra)
        synth_curvature = compute_baseline_curvature(synthetic_spectra)
        curvature_overlap = compute_distribution_overlap(real_curvature, synth_curvature)
        curvature_passed = bool(
            curvature_overlap >= thresholds["baseline_curvature_overlap"]
        )
        metric_results.append(MetricResult(
            metric=RealismMetric.BASELINE_CURVATURE,
            value=curvature_overlap,
            threshold=thresholds["baseline_curvature_overlap"],
            passed=curvature_passed,
            details={
                "real_mean": float(real_curvature.mean()),
                "synthetic_mean": float(synth_curvature.mean()),
            }
        ))
    except Exception as e:
        warnings.append(f"Baseline curvature computation failed: {e}")
        curvature_overlap = 0.0
        curvature_passed = False

    # 5. SNR Distribution
    try:
        real_snr = compute_snr(real_spectra)
        synth_snr = compute_snr(synthetic_spectra)
        # Compare in log scale; the small offset guards against log10(0).
        real_log_snr = np.log10(real_snr + 1e-10)
        synth_log_snr = np.log10(synth_snr + 1e-10)
        log_snr_diff = float(np.abs(real_log_snr.mean() - synth_log_snr.mean()))
        snr_match = bool(log_snr_diff <= thresholds["snr_order_of_magnitude"])
        metric_results.append(MetricResult(
            metric=RealismMetric.SNR_DISTRIBUTION,
            value=log_snr_diff,
            threshold=thresholds["snr_order_of_magnitude"],
            passed=snr_match,
            details={
                "real_mean_snr": float(real_snr.mean()),
                "synthetic_mean_snr": float(synth_snr.mean()),
                "log_difference": float(log_snr_diff),
            }
        ))
    except Exception as e:
        warnings.append(f"SNR computation failed: {e}")
        snr_match = True  # Default to pass on failure
        log_snr_diff = 0.0

    # 6. Adversarial Validation AUC
    adversarial_auc = 0.5  # Default (random guess)
    adversarial_passed = True
    if include_adversarial:
        try:
            mean_auc, std_auc = compute_adversarial_validation_auc(
                real_spectra, synthetic_spectra,
                cv_folds=5,
                random_state=random_state,
            )
            adversarial_auc = mean_auc
            adversarial_passed = bool(adversarial_auc <= thresholds["adversarial_auc"])
            metric_results.append(MetricResult(
                metric=RealismMetric.ADVERSARIAL_AUC,
                value=adversarial_auc,
                threshold=thresholds["adversarial_auc"],
                passed=adversarial_passed,
                details={
                    "mean_auc": float(mean_auc),
                    "std_auc": float(std_auc),
                    "target": "lower is better",
                }
            ))
        except Exception as e:
            warnings.append(f"Adversarial validation failed: {e}")

    # Combine the explicit per-metric flags rather than the recorded
    # results: a metric that raised has no MetricResult, so iterating over
    # metric_results would silently ignore it (and report PASS when every
    # metric failed).
    overall_pass = all((
        corr_passed,
        deriv_passed,
        peak_passed,
        curvature_passed,
        snr_match,
        adversarial_passed,
    ))

    return SpectralRealismScore(
        correlation_length_overlap=corr_overlap,
        derivative_ks_pvalue=ks_pvalue,
        peak_density_ratio=peak_ratio,
        baseline_curvature_overlap=curvature_overlap,
        snr_magnitude_match=snr_match,
        adversarial_auc=adversarial_auc,
        overall_pass=overall_pass,
        metric_results=metric_results,
        warnings=warnings,
    )
# ============================================================================
# Benchmark Dataset Comparison Utilities
# ============================================================================
[docs]
@dataclass
class DatasetComparisonResult:
    """
    Outcome of comparing synthetic data with one benchmark dataset.

    Attributes:
        dataset_name: Name of the benchmark dataset.
        n_real_samples: Number of samples in real dataset.
        n_synthetic_samples: Number of synthetic samples used.
        realism_score: The spectral realism score.
        tstr_r2: Train-on-Synthetic, Test-on-Real R² (if applicable).
        trts_r2: Train-on-Real, Test-on-Synthetic R² (if applicable).
    """

    dataset_name: str
    n_real_samples: int
    n_synthetic_samples: int
    realism_score: SpectralRealismScore
    tstr_r2: Optional[float] = None
    trts_r2: Optional[float] = None

    def summary(self) -> str:
        """Return a human-readable summary."""
        header = f"Dataset: {self.dataset_name}"
        counts = f"Samples: {self.n_real_samples} real, {self.n_synthetic_samples} synthetic"
        report = "\n".join([header, counts, "", self.realism_score.summary()])

        # The transfer scores are optional; the first is set off from the
        # scorecard by a blank line.
        if self.tstr_r2 is not None:
            report += "\n" + f"\nTSTR R²: {self.tstr_r2:.4f}"
        if self.trts_r2 is not None:
            report += "\n" + f"TRTS R²: {self.trts_r2:.4f}"
        return report
[docs]
def validate_against_benchmark(
    synthetic_spectra: np.ndarray,
    benchmark_spectra: np.ndarray,
    benchmark_name: str,
    wavelengths: Optional[np.ndarray] = None,
    synthetic_targets: Optional[np.ndarray] = None,
    benchmark_targets: Optional[np.ndarray] = None,
    random_state: Optional[int] = None,
) -> DatasetComparisonResult:
    """
    Validate synthetic data against a benchmark dataset.

    Always computes the spectral realism scorecard. When both target
    arrays are supplied, it additionally fits PLS regressions in both
    directions (train on synthetic / test on real, and the reverse) and
    reports their R² scores. These transfer evaluations are best-effort:
    if one fails, its score stays ``None`` and the reason is appended to
    the realism score's warnings.

    Args:
        synthetic_spectra: Synthetic spectra (n_synth, n_wavelengths).
        benchmark_spectra: Real benchmark spectra (n_bench, n_wavelengths).
        benchmark_name: Name of the benchmark dataset.
        wavelengths: Wavelength array.
        synthetic_targets: Optional targets for TSTR/TRTS evaluation.
        benchmark_targets: Optional targets for TSTR/TRTS evaluation.
        random_state: Random state for reproducibility.

    Returns:
        DatasetComparisonResult with realism score and optional TSTR/TRTS.

    Example:
        >>> result = validate_against_benchmark(
        ...     synthetic_spectra=X_synth,
        ...     benchmark_spectra=X_real,
        ...     benchmark_name="Corn",
        ... )
        >>> print(result.summary())
    """
    realism_score = compute_spectral_realism_scorecard(
        real_spectra=benchmark_spectra,
        synthetic_spectra=synthetic_spectra,
        wavelengths=wavelengths,
        random_state=random_state,
    )

    tstr_r2: Optional[float] = None
    trts_r2: Optional[float] = None

    def _transfer_r2(label, train_X, train_y, test_X, test_y):
        """Fit PLS on the training set and score it on the other set."""
        try:
            # Lazy import, inside the try so a missing sklearn is reported
            # like any other failure instead of aborting the validation.
            from sklearn.cross_decomposition import PLSRegression
            from sklearn.metrics import r2_score

            # Ensure targets are 2D (n_samples, n_targets).
            if train_y.ndim == 1:
                train_y = train_y.reshape(-1, 1)
            if test_y.ndim == 1:
                test_y = test_y.reshape(-1, 1)

            # Size the model from the set it is trained on, so a small
            # training set does not receive more components than it can
            # support.
            n_components = max(
                1, min(10, train_X.shape[1] // 10, len(train_X) // 2)
            )
            model = PLSRegression(n_components=n_components)
            model.fit(train_X, train_y)
            return float(r2_score(test_y, model.predict(test_X)))
        except Exception as exc:
            # Best-effort: keep the score as None but record why.
            realism_score.warnings.append(f"{label} evaluation failed: {exc}")
            return None

    if synthetic_targets is not None and benchmark_targets is not None:
        # TSTR: Train on Synthetic, Test on Real
        tstr_r2 = _transfer_r2(
            "TSTR",
            synthetic_spectra, synthetic_targets,
            benchmark_spectra, benchmark_targets,
        )
        # TRTS: Train on Real, Test on Synthetic
        trts_r2 = _transfer_r2(
            "TRTS",
            benchmark_spectra, benchmark_targets,
            synthetic_spectra, synthetic_targets,
        )

    return DatasetComparisonResult(
        dataset_name=benchmark_name,
        n_real_samples=len(benchmark_spectra),
        n_synthetic_samples=len(synthetic_spectra),
        realism_score=realism_score,
        tstr_r2=tstr_r2,
        trts_r2=trts_r2,
    )
# ============================================================================
# Quick Validation Functions
# ============================================================================
[docs]
def quick_realism_check(
    synthetic_spectra: np.ndarray,
    wavelengths: Optional[np.ndarray] = None,
    expected_snr_range: Tuple[float, float] = (10, 1000),
    expected_peak_density: Tuple[float, float] = (0.5, 10.0),
) -> Tuple[bool, List[str]]:
    """
    Run reference-free sanity checks on synthetic spectra.

    Checks for non-finite values, mean SNR within the expected range, mean
    peak density within the expected range (only when wavelengths are
    given), and that the variance is not flat across wavelengths. A check
    that raises is reported as an issue rather than propagated.

    Args:
        synthetic_spectra: Synthetic spectra to check.
        wavelengths: Wavelength array.
        expected_snr_range: Expected SNR range (min, max).
        expected_peak_density: Expected peak density range (peaks per 100 nm).

    Returns:
        Tuple of (passed, list_of_issues); ``passed`` is True only when no
        issue was found.

    Example:
        >>> X = generator.generate(100)[0]
        >>> passed, issues = quick_realism_check(X, wavelengths)
        >>> if not passed:
        ...     print("Issues:", issues)
    """
    issues: List[str] = []

    # Non-finite values
    if np.isnan(synthetic_spectra).any():
        issues.append("Spectra contain NaN values")
    if np.isinf(synthetic_spectra).any():
        issues.append("Spectra contain Inf values")

    # Mean signal-to-noise ratio
    try:
        mean_snr = compute_snr(synthetic_spectra).mean()
        if mean_snr < expected_snr_range[0]:
            issues.append(f"SNR too low: {mean_snr:.1f} < {expected_snr_range[0]}")
        if mean_snr > expected_snr_range[1]:
            issues.append(f"SNR unrealistically high: {mean_snr:.1f} > {expected_snr_range[1]}")
    except Exception as e:
        issues.append(f"SNR check failed: {e}")

    # Mean peak density needs a wavelength axis
    if wavelengths is not None:
        try:
            mean_density = compute_peak_density(synthetic_spectra, wavelengths).mean()
            if mean_density < expected_peak_density[0]:
                issues.append(f"Peak density too low: {mean_density:.2f}")
            if mean_density > expected_peak_density[1]:
                issues.append(f"Peak density too high: {mean_density:.2f}")
        except Exception as e:
            issues.append(f"Peak density check failed: {e}")

    # Variance structure: the across-sample variance should differ between
    # wavelengths. The coefficient of variation measures that spread; the
    # small offset avoids dividing by zero.
    try:
        per_wavelength_var = np.var(synthetic_spectra, axis=0)
        spread = np.std(per_wavelength_var) / (np.mean(per_wavelength_var) + 1e-10)
        if spread < 0.1:
            issues.append("Variance across wavelengths too uniform (unrealistic)")
    except Exception as e:
        issues.append(f"Variance check failed: {e}")

    return not issues, issues