"""
Conditional prior sampling for synthetic NIRS data generation.
This module provides structured prior sampling where configuration
parameters are sampled conditionally based on domain, instrument type,
and other hierarchical dependencies.
Phase 4 Features:
- Domain-weighted sampling
- Conditional instrument selection given domain
- Conditional measurement mode given instrument
- Matrix type conditioning on domain
- Component set selection based on domain
- Full configuration sampling from prior
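Generative DAG (the indented arrows branch from Domain):
    Domain → Instrument Category → Wavelength Range, Resolution, Mode, Noise
           → Matrix Type → Particle Size, Scattering, Water Activity
           → Component Set → Concentration Distributions
           → Target Type
Note: PriorSampler in this module samples only part of this DAG. Scattering,
water activity and concentration distributions are not drawn here, and the
target configuration is sampled independently of the domain.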
References:
- Workman Jr, J., & Weyer, L. (2012). Practical Guide and Spectral Atlas
for Interpretive Near-Infrared Spectroscopy. CRC Press.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from .domains import (
DomainCategory,
APPLICATION_DOMAINS,
get_domain_config,
list_domains,
)
from .instruments import (
InstrumentCategory,
INSTRUMENT_ARCHETYPES,
get_instrument_archetype,
list_instrument_archetypes,
)
from .measurement_modes import MeasurementMode
# ============================================================================
# Matrix Types
# ============================================================================
[docs]
class MatrixType(str, Enum):
    """
    Physical form of the sample matrix, which affects spectral properties.

    Members mix in ``str``, so each compares equal to its lowercase value
    (``MatrixType.LIQUID == "liquid"``). ``MatrixType(value)`` converts a
    ``matrix_type`` string drawn by the prior sampler back into a member.
    """

    # NOTE: declaration order is the iteration order of the enum; keep it
    # stable if callers enumerate the members.
    LIQUID = "liquid"
    POWDER = "powder"
    SOLID = "solid"
    PASTE = "paste"
    EMULSION = "emulsion"
    GEL = "gel"
    TISSUE = "tissue"
    SLURRY = "slurry"
    FILM = "film"
    GRANULAR = "granular"
# ============================================================================
# Prior Configuration
# ============================================================================
[docs]
@dataclass
class NIRSPriorConfig:
    """
    Configuration for NIRS data generation with conditional sampling.

    This class defines the prior distributions and conditional dependencies
    for sampling complete generation configurations. Weight tables need not
    sum to 1: the sampler normalizes them when drawing. Every mutable default
    is built with ``default_factory``, so each instance owns its own tables.

    Attributes:
        domain_weights: Prior weights P(domain) for each domain key.
        instrument_given_domain: P(instrument_category | domain).
        mode_given_category: P(measurement_mode | instrument_category).
        matrix_given_domain: P(matrix_type | domain).
        temperature_range: (min, max) temperature in Celsius.
        particle_size_range: (min, max) particle size in microns.
        noise_level_range: (min, max) noise level multiplier.
        n_samples_range: (min, max) number of samples per generated dataset.
        target_type_weights: Weights for "regression" vs "classification".
        n_targets_range: (min, max) number of regression targets.
        n_classes_range: (min, max) number of classes for classification.

    Example:
        >>> config = NIRSPriorConfig()
        >>> sampler = PriorSampler(config, random_state=42)
        >>> sample = sampler.sample()
        >>> print(sample["domain"], sample["instrument"])
    """

    # Domain prior weights
    domain_weights: Dict[str, float] = field(default_factory=lambda: {
        "grain": 0.15,
        "forage": 0.08,
        "oilseeds": 0.07,
        "fruit": 0.05,
        "dairy": 0.10,
        "meat": 0.05,
        "beverages": 0.05,
        "baking": 0.03,
        "tablets": 0.10,
        "powders": 0.05,
        "liquids": 0.05,
        "fuel": 0.05,
        "polymers": 0.04,
        "lubricants": 0.02,
        "water_quality": 0.03,
        "soil": 0.03,
        "tissue": 0.02,
        "blood": 0.02,
        "textiles": 0.01,
    })

    # P(instrument_category | domain)
    instrument_given_domain: Dict[str, Dict[str, float]] = field(default_factory=lambda: {
        # Agriculture domains prefer robust instruments
        "grain": {
            "benchtop": 0.5, "handheld": 0.2, "process": 0.2,
            "embedded": 0.05, "filter": 0.05
        },
        "forage": {
            "benchtop": 0.3, "handheld": 0.4, "process": 0.2,
            "embedded": 0.05, "filter": 0.05
        },
        "oilseeds": {
            "benchtop": 0.5, "handheld": 0.1, "process": 0.3,
            "embedded": 0.05, "filter": 0.05
        },
        "fruit": {
            "benchtop": 0.3, "handheld": 0.5, "process": 0.1,
            "embedded": 0.05, "filter": 0.05
        },
        # Food domains
        "dairy": {
            "benchtop": 0.4, "process": 0.4, "handheld": 0.1,
            "ft_nir": 0.1
        },
        "meat": {
            "benchtop": 0.4, "handheld": 0.3, "process": 0.2,
            "filter": 0.1
        },
        "beverages": {
            "benchtop": 0.3, "process": 0.4, "handheld": 0.2,
            "ft_nir": 0.1
        },
        "baking": {
            "benchtop": 0.5, "process": 0.3, "handheld": 0.1,
            "filter": 0.1
        },
        # Pharmaceutical domains prefer high-precision
        "tablets": {
            "benchtop": 0.5, "ft_nir": 0.3, "process": 0.1,
            "handheld": 0.1
        },
        "powders": {
            "benchtop": 0.4, "ft_nir": 0.4, "process": 0.1,
            "handheld": 0.1
        },
        "liquids": {
            "benchtop": 0.4, "ft_nir": 0.3, "process": 0.2,
            "diode_array": 0.1
        },
        # Petrochemical
        "fuel": {
            "benchtop": 0.3, "process": 0.5, "ft_nir": 0.1,
            "handheld": 0.1
        },
        "polymers": {
            "benchtop": 0.4, "ft_nir": 0.3, "process": 0.2,
            "handheld": 0.1
        },
        "lubricants": {
            "benchtop": 0.3, "process": 0.5, "ft_nir": 0.1,
            "handheld": 0.1
        },
        # Environmental
        "water_quality": {
            "benchtop": 0.3, "handheld": 0.4, "process": 0.2,
            "embedded": 0.1
        },
        "soil": {
            "benchtop": 0.3, "handheld": 0.5, "process": 0.1,
            "embedded": 0.1
        },
        # Biomedical
        "tissue": {
            "benchtop": 0.4, "ft_nir": 0.3, "handheld": 0.2,
            "embedded": 0.1
        },
        "blood": {
            "benchtop": 0.5, "ft_nir": 0.3, "process": 0.1,
            "embedded": 0.1
        },
        # Industrial
        "textiles": {
            "benchtop": 0.4, "process": 0.3, "handheld": 0.2,
            "filter": 0.1
        },
    })

    # P(measurement_mode | instrument_category)
    mode_given_category: Dict[str, Dict[str, float]] = field(default_factory=lambda: {
        "benchtop": {
            "reflectance": 0.5, "transmittance": 0.3,
            "transflectance": 0.15, "atr": 0.05
        },
        "handheld": {
            "reflectance": 0.7, "transmittance": 0.1,
            "transflectance": 0.15, "atr": 0.05
        },
        "process": {
            "reflectance": 0.4, "transmittance": 0.3,
            "transflectance": 0.25, "atr": 0.05
        },
        "embedded": {
            "reflectance": 0.6, "transmittance": 0.2,
            "transflectance": 0.15, "atr": 0.05
        },
        "ft_nir": {
            "reflectance": 0.4, "transmittance": 0.35,
            "transflectance": 0.1, "atr": 0.15
        },
        "filter": {
            "reflectance": 0.6, "transmittance": 0.3,
            "transflectance": 0.1, "atr": 0.0
        },
        "diode_array": {
            "reflectance": 0.5, "transmittance": 0.4,
            "transflectance": 0.1, "atr": 0.0
        },
    })

    # P(matrix_type | domain)
    matrix_given_domain: Dict[str, Dict[str, float]] = field(default_factory=lambda: {
        "grain": {"granular": 0.7, "powder": 0.2, "solid": 0.1},
        "forage": {"solid": 0.5, "powder": 0.3, "granular": 0.2},
        "oilseeds": {"granular": 0.6, "solid": 0.3, "powder": 0.1},
        "fruit": {"solid": 0.6, "paste": 0.2, "liquid": 0.2},
        "dairy": {"liquid": 0.5, "emulsion": 0.3, "powder": 0.2},
        "meat": {"solid": 0.6, "paste": 0.3, "emulsion": 0.1},
        "beverages": {"liquid": 0.9, "emulsion": 0.1},
        "baking": {"powder": 0.5, "paste": 0.3, "solid": 0.2},
        "tablets": {"solid": 0.7, "powder": 0.3},
        "powders": {"powder": 0.9, "granular": 0.1},
        "liquids": {"liquid": 0.9, "gel": 0.1},
        "fuel": {"liquid": 0.9, "gel": 0.1},
        "polymers": {"solid": 0.6, "film": 0.3, "powder": 0.1},
        "lubricants": {"liquid": 0.7, "gel": 0.3},
        "water_quality": {"liquid": 1.0},
        "soil": {"powder": 0.5, "granular": 0.4, "slurry": 0.1},
        "tissue": {"tissue": 0.8, "solid": 0.2},
        "blood": {"liquid": 0.9, "gel": 0.1},
        "textiles": {"solid": 0.8, "film": 0.2},
    })

    # Continuous parameter ranges
    temperature_range: Tuple[float, float] = (15.0, 40.0)
    particle_size_range: Tuple[float, float] = (5.0, 200.0)
    noise_level_range: Tuple[float, float] = (0.5, 2.0)
    n_samples_range: Tuple[int, int] = (100, 2000)

    # Target configuration
    target_type_weights: Dict[str, float] = field(default_factory=lambda: {
        "regression": 0.7,
        "classification": 0.3,
    })
    n_targets_range: Tuple[int, int] = (1, 5)
    n_classes_range: Tuple[int, int] = (2, 5)

    def get_domain_weight(self, domain: str) -> float:
        """Return the raw prior weight for ``domain`` (0.0 if it is unknown)."""
        return self.domain_weights.get(domain, 0.0)

    def normalize_weights(self, weights: Dict[str, float]) -> Dict[str, float]:
        """
        Return a new dict whose values are ``weights`` scaled to sum to 1.

        Args:
            weights: Mapping of category -> weight. It is never modified.

        Returns:
            A fresh dict. When the weights sum to 0, no scaling is possible,
            so an unscaled *copy* is returned. Returning a copy instead of the
            input object keeps callers from accidentally aliasing (and later
            mutating) a table owned by this config.
        """
        total = sum(weights.values())
        if total == 0:
            return dict(weights)
        return {k: v / total for k, v in weights.items()}
# ============================================================================
# Prior Sampler
# ============================================================================
[docs]
class PriorSampler:
    """
    Sample complete generation configurations from prior distributions.

    Sampling is hierarchical: lower-level choices are drawn conditionally on
    higher-level ones (domain -> instrument category -> instrument and
    measurement mode; domain -> matrix type -> particle size; domain ->
    components). Every categorical draw is returned as a plain Python
    ``str`` (never ``numpy.str_``), so configurations serialize cleanly.

    Args:
        config: Prior configuration. A fresh ``NIRSPriorConfig()`` is used
            when ``None``.
        random_state: Seed passed to ``numpy.random.default_rng``.

    Example:
        >>> config = NIRSPriorConfig()
        >>> sampler = PriorSampler(config, random_state=42)
        >>>
        >>> # Sample a single configuration
        >>> sample = sampler.sample()
        >>> print(sample)
        >>>
        >>> # Sample multiple configurations
        >>> samples = sampler.sample_batch(10)
    """

    def __init__(
        self,
        config: Optional[NIRSPriorConfig] = None,
        random_state: Optional[int] = None,
    ):
        self.config = config if config is not None else NIRSPriorConfig()
        self.rng = np.random.default_rng(random_state)

    # ------------------------------------------------------------------
    # Low-level helpers
    # ------------------------------------------------------------------

    def _sample_categorical(
        self,
        weights: Dict[str, float],
    ) -> str:
        """
        Draw one key from a categorical distribution given by ``weights``.

        Args:
            weights: Mapping of category -> non-negative weight. The weights
                are normalized here, so they need not sum to 1.

        Returns:
            The sampled category as a plain ``str``.

        Raises:
            ValueError: If ``weights`` is empty. ``numpy`` also raises
                ``ValueError`` if the normalized weights are not a valid
                probability vector (e.g. negative entries).
        """
        categories = list(weights.keys())
        if not categories:
            raise ValueError("Cannot sample from an empty set of categories.")
        total = sum(weights.values())
        if total == 0:
            # All weights are zero: fall back to a uniform draw. ``str()``
            # turns numpy's ``np.str_`` result back into a plain string.
            return str(self.rng.choice(categories))
        probs = np.array([weights[c] / total for c in categories])
        idx = self.rng.choice(len(categories), p=probs)
        return categories[int(idx)]

    def _instrument_category_weights(self, domain: str) -> Dict[str, float]:
        """Return the P(instrument_category | domain) weights used for ``domain``."""
        if domain in self.config.instrument_given_domain:
            return self.config.instrument_given_domain[domain]
        # Unknown domain: uniform over these five categories only
        # ("filter" and "diode_array" are not part of this fallback).
        return {cat: 1.0 for cat in (
            "benchtop", "handheld", "process", "embedded", "ft_nir"
        )}

    # ------------------------------------------------------------------
    # Node samplers (one per node of the generative DAG)
    # ------------------------------------------------------------------

    def sample_domain(self) -> str:
        """Sample a domain from the prior P(domain)."""
        return self._sample_categorical(self.config.domain_weights)

    def sample_instrument_category(self, domain: str) -> str:
        """Sample an instrument category from P(instrument_category | domain)."""
        return self._sample_categorical(self._instrument_category_weights(domain))

    def sample_domain_given_instrument_category(self, instrument_category: str) -> str:
        """
        Sample a domain from P(domain | instrument_category) via Bayes' rule.

        The posterior is proportional to P(domain) * P(category | domain),
        using the same normalized conditional that
        ``sample_instrument_category`` draws from. When no domain in the
        prior gives the category positive probability, the unconditional
        domain prior is used instead.

        Args:
            instrument_category: Instrument category value (e.g. "benchtop").

        Returns:
            The sampled domain key.
        """
        posterior: Dict[str, float] = {}
        for domain, prior in self.config.domain_weights.items():
            cat_weights = self._instrument_category_weights(domain)
            total = sum(cat_weights.values())
            if total <= 0:
                continue
            likelihood = cat_weights.get(instrument_category, 0.0) / total
            posterior[domain] = prior * likelihood
        if sum(posterior.values()) > 0:
            return self._sample_categorical(posterior)
        return self.sample_domain()

    def sample_instrument(self, category: str) -> str:
        """
        Sample a specific instrument archetype name of the given category.

        Falls back to any known archetype if none matches ``category``.
        """
        matching = [
            name
            for name, archetype in INSTRUMENT_ARCHETYPES.items()
            if archetype.category.value == category
        ]
        if not matching:
            matching = list(INSTRUMENT_ARCHETYPES.keys())
        return str(self.rng.choice(matching))

    def sample_measurement_mode(self, instrument_category: str) -> str:
        """Sample a measurement mode from P(mode | instrument_category)."""
        if instrument_category in self.config.mode_given_category:
            weights = self.config.mode_given_category[instrument_category]
        else:
            weights = {
                "reflectance": 0.5, "transmittance": 0.3,
                "transflectance": 0.15, "atr": 0.05
            }
        return self._sample_categorical(weights)

    def sample_matrix_type(self, domain: str) -> str:
        """Sample a matrix type from P(matrix_type | domain)."""
        if domain in self.config.matrix_given_domain:
            weights = self.config.matrix_given_domain[domain]
        else:
            weights = {"powder": 0.3, "liquid": 0.3, "solid": 0.4}
        return self._sample_categorical(weights)

    def sample_temperature(self) -> float:
        """Sample a temperature uniformly from ``config.temperature_range``."""
        low, high = self.config.temperature_range
        return float(self.rng.uniform(low, high))

    def sample_particle_size(self, matrix_type: str) -> float:
        """
        Sample a particle size uniformly, with a range chosen by matrix type.

        Matrix types without a dedicated range use
        ``config.particle_size_range``.
        """
        low, high = self.config.particle_size_range
        if matrix_type == "powder":
            low, high = 5.0, 100.0
        elif matrix_type == "granular":
            low, high = 50.0, 500.0
        elif matrix_type in ("liquid", "emulsion"):
            low, high = 0.1, 10.0
        elif matrix_type == "solid":
            # Not really applicable, but return a value
            low, high = 100.0, 1000.0
        return float(self.rng.uniform(low, high))

    def sample_noise_level(self, instrument_category: str) -> float:
        """
        Sample a noise multiplier uniformly, with a range chosen by category.

        Categories without a dedicated range use ``config.noise_level_range``.
        """
        low, high = self.config.noise_level_range
        # Handheld instruments typically have higher noise
        if instrument_category == "handheld":
            low, high = 1.0, 3.0
        elif instrument_category == "embedded":
            low, high = 1.5, 3.5
        elif instrument_category == "ft_nir":
            low, high = 0.3, 1.0  # FT-NIR typically lower noise
        elif instrument_category == "benchtop":
            low, high = 0.5, 1.5
        return float(self.rng.uniform(low, high))

    def sample_n_samples(self) -> int:
        """Sample a dataset size from ``config.n_samples_range`` (inclusive)."""
        low, high = self.config.n_samples_range
        return int(self.rng.integers(low, high + 1))

    def sample_target_config(self) -> Dict[str, Any]:
        """
        Sample the target generation configuration.

        Returns:
            For regression: ``{"type", "n_targets", "nonlinearity"}``; for
            classification: ``{"type", "n_classes", "separation"}``. Counts
            are drawn from the inclusive config ranges.
        """
        target_type = self._sample_categorical(self.config.target_type_weights)
        if target_type == "regression":
            n_targets = int(self.rng.integers(
                self.config.n_targets_range[0],
                self.config.n_targets_range[1] + 1
            ))
            return {
                "type": "regression",
                "n_targets": n_targets,
                "nonlinearity": str(self.rng.choice(["none", "mild", "moderate"])),
            }
        n_classes = int(self.rng.integers(
            self.config.n_classes_range[0],
            self.config.n_classes_range[1] + 1
        ))
        return {
            "type": "classification",
            "n_classes": n_classes,
            "separation": str(self.rng.choice(["easy", "moderate", "hard"])),
        }

    def sample_components(self, domain: str, n_components: Optional[int] = None) -> List[str]:
        """
        Sample a set of distinct component names for ``domain``.

        Args:
            domain: Domain key. If its domain config cannot be loaded, a
                generic component list is used instead.
            n_components: Number of components. When ``None``, between 3 and
                7 are drawn (bounded by how many are available). Always
                capped at the number available.

        Returns:
            Distinct component names as plain ``str`` values (empty if the
            domain lists no components).
        """
        try:
            available = list(get_domain_config(domain).typical_components)
        except Exception:
            # Best-effort: any lookup failure falls back to generic components.
            available = ["water", "protein", "lipid", "carbohydrate", "cellulose"]
        if not available:
            return []
        if n_components is None:
            high = min(8, len(available) + 1)
            # Shrink the lower bound when fewer than 3 components exist, so
            # ``integers`` always gets low < high.
            low = min(3, high - 1)
            n_components = int(self.rng.integers(low, high))
        n_components = min(n_components, len(available))
        chosen = self.rng.choice(available, size=n_components, replace=False)
        return [str(c) for c in chosen]

    # ------------------------------------------------------------------
    # Full-configuration sampling
    # ------------------------------------------------------------------

    def _assemble_config(
        self,
        domain: str,
        instrument: str,
        archetype: Any,
        instrument_category: str,
        measurement_mode: str,
        matrix_type: str,
        n_samples: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Draw the remaining leaf parameters and build the configuration dict.

        Values are drawn in dict order (temperature, particle size, noise,
        components, n_samples, target config, random_state). Keep this order
        stable: changing it changes the configurations produced for a seed.
        """
        wavelength_range = archetype.wavelength_range
        return {
            # Domain and application
            "domain": domain,
            "domain_category": self._get_domain_category(domain),
            # Instrument configuration
            "instrument": instrument,
            "instrument_category": instrument_category,
            "wavelength_range": (wavelength_range[0], wavelength_range[1]),
            "spectral_resolution": archetype.spectral_resolution,
            # Measurement configuration
            "measurement_mode": measurement_mode,
            "matrix_type": matrix_type,
            # Environmental conditions
            "temperature": self.sample_temperature(),
            "particle_size": self.sample_particle_size(matrix_type),
            "noise_level": self.sample_noise_level(instrument_category),
            # Components
            "components": self.sample_components(domain),
            # Dataset configuration
            "n_samples": n_samples if n_samples is not None else self.sample_n_samples(),
            "target_config": self.sample_target_config(),
            # Metadata
            "random_state": int(self.rng.integers(0, 2**31)),
        }

    def sample(self) -> Dict[str, Any]:
        """
        Sample a complete dataset configuration from the prior.

        Returns:
            Dictionary with all configuration parameters.

        Example:
            >>> sampler = PriorSampler(random_state=42)
            >>> config = sampler.sample()
            >>> print(config["domain"])
            >>> print(config["instrument"])
        """
        # Root of the DAG; everything below is conditioned on the domain.
        return self.sample_for_domain(self.sample_domain())

    def _get_domain_category(self, domain: str) -> str:
        """Return the category value for ``domain``, or "research" if unavailable."""
        try:
            domain_config = get_domain_config(domain)
            return domain_config.category.value
        except Exception:
            # Best-effort: unknown domains are labelled "research".
            return "research"

    def sample_batch(self, n: int) -> List[Dict[str, Any]]:
        """
        Sample multiple configurations from the prior.

        Args:
            n: Number of configurations to sample.

        Returns:
            List of configuration dictionaries.
        """
        return [self.sample() for _ in range(n)]

    def sample_for_domain(
        self,
        domain: str,
        n_samples: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Sample a configuration constrained to a specific domain.

        Args:
            domain: Domain to sample for.
            n_samples: Optional number of samples (drawn from the prior if None).

        Returns:
            Configuration dictionary for the specified domain.
        """
        instrument_category = self.sample_instrument_category(domain)
        instrument = self.sample_instrument(instrument_category)
        measurement_mode = self.sample_measurement_mode(instrument_category)
        matrix_type = self.sample_matrix_type(domain)
        archetype = get_instrument_archetype(instrument)
        return self._assemble_config(
            domain, instrument, archetype, instrument_category,
            measurement_mode, matrix_type, n_samples,
        )

    def sample_for_instrument(
        self,
        instrument: str,
        n_samples: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Sample a configuration constrained to a specific instrument.

        The domain is drawn from the posterior P(domain | instrument category)
        (see ``sample_domain_given_instrument_category``). As a result, only
        domains that can actually produce this instrument category are chosen,
        weighted by the domain prior.

        Args:
            instrument: Instrument name to use.
            n_samples: Optional number of samples (drawn from the prior if None).

        Returns:
            Configuration dictionary for the specified instrument.
        """
        archetype = get_instrument_archetype(instrument)
        instrument_category = archetype.category.value
        domain = self.sample_domain_given_instrument_category(instrument_category)
        measurement_mode = self.sample_measurement_mode(instrument_category)
        matrix_type = self.sample_matrix_type(domain)
        return self._assemble_config(
            domain, instrument, archetype, instrument_category,
            measurement_mode, matrix_type, n_samples,
        )
# ============================================================================
# Convenience Functions
# ============================================================================
[docs]
def sample_prior(
    domain: Optional[str] = None,
    instrument: Optional[str] = None,
    random_state: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Sample one configuration from a fresh sampler using the default prior.

    At most one constraint is applied. A truthy ``domain`` fixes the domain,
    and ``instrument`` is then ignored. Otherwise a truthy ``instrument``
    fixes the instrument. With neither, every field is drawn from the prior.

    Args:
        domain: Optional domain key to fix (e.g. ``"dairy"``).
        instrument: Optional instrument archetype name to fix; ignored
            when ``domain`` is given.
        random_state: Seed for reproducibility.

    Returns:
        Configuration dictionary.

    Example:
        >>> config = sample_prior(domain="dairy", random_state=42)
        >>> print(config["domain"], config["instrument"])
    """
    sampler = PriorSampler(random_state=random_state)
    if not domain and not instrument:
        return sampler.sample()
    if not domain:
        return sampler.sample_for_instrument(instrument)
    return sampler.sample_for_domain(domain)
[docs]
def sample_prior_batch(
    n: int,
    random_state: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Draw ``n`` configurations from one freshly seeded default-prior sampler.

    All ``n`` draws share a single random stream, so the whole batch is
    reproducible from ``random_state``.

    Args:
        n: Number of configurations to sample.
        random_state: Seed for reproducibility.

    Returns:
        List of ``n`` configuration dictionaries.

    Example:
        >>> configs = sample_prior_batch(10, random_state=42)
        >>> for c in configs:
        ...     print(c["domain"], c["instrument"])
    """
    return PriorSampler(random_state=random_state).sample_batch(n)
[docs]
def get_domain_compatible_instruments(
    domain: str,
    *,
    config: Optional[NIRSPriorConfig] = None,
    threshold: float = 0.1,
) -> List[str]:
    """
    Get list of instruments commonly used with a domain.

    An instrument category counts as "common" for the domain when its raw
    weight in ``instrument_given_domain[domain]`` is strictly greater than
    ``threshold``. Every archetype in such a category is returned.

    Args:
        domain: Domain name.
        config: Prior configuration to read. Defaults to ``NIRSPriorConfig()``.
            Pass the same config given to ``PriorSampler`` to stay
            consistent with a customized prior.
        threshold: Exclusive lower bound on the category weight.

    Returns:
        Instrument archetype names, in archetype-registry order. Every known
        instrument is returned when the domain has no entry in the prior.

    Example:
        >>> instruments = get_domain_compatible_instruments("tablets")
        >>> print(instruments)
    """
    prior = config if config is not None else NIRSPriorConfig()
    if domain not in prior.instrument_given_domain:
        return list(INSTRUMENT_ARCHETYPES.keys())
    likely_categories = {
        cat for cat, w in prior.instrument_given_domain[domain].items()
        if w > threshold
    }
    return [
        name
        for name, archetype in INSTRUMENT_ARCHETYPES.items()
        if archetype.category.value in likely_categories
    ]
[docs]
def get_instrument_typical_modes(
    instrument: str,
    *,
    config: Optional[NIRSPriorConfig] = None,
    threshold: float = 0.05,
) -> List[str]:
    """
    Get typical measurement modes for an instrument.

    A mode is "typical" when its raw weight in
    ``mode_given_category[<instrument category>]`` is strictly greater than
    ``threshold``.

    Args:
        instrument: Instrument archetype name.
        config: Prior configuration to read. Defaults to ``NIRSPriorConfig()``.
            Pass the same config given to ``PriorSampler`` to stay
            consistent with a customized prior.
        threshold: Exclusive lower bound on the mode weight.

    Returns:
        Mode names in the order they appear in the prior table.
        ``["reflectance", "transmittance"]`` is returned when the instrument's
        category has no entry in the prior.

    Example:
        >>> modes = get_instrument_typical_modes("viavi_micronir")
        >>> print(modes)
    """
    prior = config if config is not None else NIRSPriorConfig()
    category = get_instrument_archetype(instrument).category.value
    if category not in prior.mode_given_category:
        return ["reflectance", "transmittance"]
    mode_weights = prior.mode_given_category[category]
    return [mode for mode, w in mode_weights.items() if w > threshold]