"""
Product-level synthetic NIRS generator for neural network training.
This module provides high-level APIs to generate diverse, realistic product
samples with controlled variability for training neural networks. Unlike
the base SyntheticNIRSGenerator which operates at the component level,
ProductGenerator works with predefined product templates that include
realistic composition variability, component correlations, and bounds.
Key Features:
- Predefined product templates with realistic composition ranges
- Controlled variability types (FIXED, UNIFORM, NORMAL, LOGNORMAL, CORRELATED)
- Composition constraints (sum to 1.0, realistic bounds)
- Correlation preservation between components
- Target flexibility (any component as regression target)
- Efficient batch generation for NN training (10k-100k samples)
- Integration with custom wavelength grids
Example:
>>> from nirs4all.data.synthetic import ProductGenerator, list_product_templates
>>>
>>> # List available templates
>>> print(list_product_templates(category="dairy"))
['milk_variable_fat', 'cheese_variable_moisture']
>>>
>>> # Generate dairy product samples
>>> generator = ProductGenerator("milk_variable_fat")
>>> dataset = generator.generate(n_samples=10000, target="fat")
>>>
>>> # High-variability dataset for NN training
>>> generator = ProductGenerator("food_cholesterol_variable")
>>> dataset = generator.generate(n_samples=50000, target="cholesterol")
References:
[1] USDA FoodData Central (https://fdc.nal.usda.gov/)
[2] Osborne, B. G., Fearn, T., & Hindle, P. H. (1993). Practical NIR Spectroscopy.
[3] Williams, P. (2001). Implementation of Near-Infrared Technology.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
from .components import ComponentLibrary
if TYPE_CHECKING:
from nirs4all.data.dataset import SpectroDataset
[docs]
class VariationType(Enum):
"""
Type of variation for component concentrations.
Attributes:
FIXED: No variation, use exact specified value.
UNIFORM: Uniform distribution between min and max.
NORMAL: Normal (Gaussian) distribution with mean and std.
LOGNORMAL: Log-normal distribution for non-negative values.
CORRELATED: Value derived from correlation with another component.
COMPUTED: Value computed from other components (e.g., 1 - sum(others)).
"""
FIXED = auto()
UNIFORM = auto()
NORMAL = auto()
LOGNORMAL = auto()
CORRELATED = auto()
COMPUTED = auto()
[docs]
@dataclass
class ComponentVariation:
"""
Specification for how a component's concentration varies.
Attributes:
component: Name of the spectral component (must exist in library).
variation_type: Type of variation (FIXED, UNIFORM, NORMAL, etc.).
value: For FIXED type, the exact value.
min_value: For UNIFORM/NORMAL, the minimum bound.
max_value: For UNIFORM/NORMAL, the maximum bound.
mean: For NORMAL/LOGNORMAL, the distribution mean.
std: For NORMAL/LOGNORMAL, the distribution standard deviation.
correlated_with: For CORRELATED, the source component name.
correlation: For CORRELATED, the correlation coefficient.
compute_as: For COMPUTED, a string describing the computation
(currently supports "remainder" for 1 - sum(others)).
Example:
>>> # Fixed moisture content
>>> moisture = ComponentVariation("moisture", VariationType.FIXED, value=0.12)
>>>
>>> # Variable protein with uniform distribution
>>> protein = ComponentVariation(
... "protein", VariationType.UNIFORM,
... min_value=0.08, max_value=0.18
... )
>>>
>>> # Starch negatively correlated with protein
>>> starch = ComponentVariation(
... "starch", VariationType.CORRELATED,
... correlated_with="protein", correlation=-0.85,
... min_value=0.55, max_value=0.72
... )
"""
component: str
variation_type: VariationType
value: Optional[float] = None
min_value: Optional[float] = None
max_value: Optional[float] = None
mean: Optional[float] = None
std: Optional[float] = None
correlated_with: Optional[str] = None
correlation: Optional[float] = None
compute_as: Optional[str] = None
[docs]
def __post_init__(self) -> None:
"""Validate specification based on variation type."""
vtype = self.variation_type
if vtype == VariationType.FIXED:
if self.value is None:
raise ValueError("FIXED variation requires 'value'")
elif vtype == VariationType.UNIFORM:
if self.min_value is None or self.max_value is None:
raise ValueError("UNIFORM variation requires 'min_value' and 'max_value'")
if self.min_value > self.max_value:
raise ValueError("min_value must be <= max_value")
elif vtype == VariationType.NORMAL:
if self.mean is None or self.std is None:
raise ValueError("NORMAL variation requires 'mean' and 'std'")
if self.std < 0:
raise ValueError("std must be non-negative")
elif vtype == VariationType.LOGNORMAL:
if self.mean is None or self.std is None:
raise ValueError("LOGNORMAL variation requires 'mean' and 'std'")
elif vtype == VariationType.CORRELATED:
if self.correlated_with is None or self.correlation is None:
raise ValueError("CORRELATED variation requires 'correlated_with' and 'correlation'")
if not -1.0 <= self.correlation <= 1.0:
raise ValueError("correlation must be between -1 and 1")
elif vtype == VariationType.COMPUTED:
if self.compute_as is None:
raise ValueError("COMPUTED variation requires 'compute_as'")
[docs]
@dataclass
class ProductTemplate:
"""
Template defining a product type with composition variability.
A ProductTemplate describes a realistic product type (e.g., wheat grain,
milk, pharmaceutical tablet) along with specifications for how each
component's concentration can vary. This enables generation of diverse
samples suitable for neural network training.
Attributes:
name: Unique identifier for the template.
description: Human-readable description.
category: Product category (e.g., "dairy", "grain", "pharma").
domain: Application domain (e.g., "agriculture", "food", "pharmaceutical").
components: List of ComponentVariation specifications.
default_target: Default component to use as regression target.
tags: Classification tags for filtering.
references: Literature or data source citations.
Example:
>>> milk_template = ProductTemplate(
... name="milk_variable_fat",
... description="Milk with variable fat content (skim to whole)",
... category="dairy",
... domain="food",
... components=[
... ComponentVariation("water", VariationType.COMPUTED, compute_as="remainder"),
... ComponentVariation("lipid", VariationType.UNIFORM, min_value=0.005, max_value=0.06),
... ComponentVariation("casein", VariationType.NORMAL, mean=0.028, std=0.003),
... ComponentVariation("whey", VariationType.FIXED, value=0.006),
... ComponentVariation("lactose", VariationType.NORMAL, mean=0.05, std=0.003),
... ],
... default_target="lipid",
... )
"""
name: str
description: str
category: str
domain: str
components: List[ComponentVariation]
default_target: str = ""
tags: List[str] = field(default_factory=list)
references: List[str] = field(default_factory=list)
[docs]
def __post_init__(self) -> None:
"""Validate template consistency."""
comp_names = [c.component for c in self.components]
# Check for duplicates
if len(comp_names) != len(set(comp_names)):
raise ValueError(f"Duplicate components in template '{self.name}'")
# Check correlated components reference valid sources
for comp_var in self.components:
if comp_var.variation_type == VariationType.CORRELATED:
if comp_var.correlated_with not in comp_names:
raise ValueError(
f"Component '{comp_var.component}' correlates with "
f"'{comp_var.correlated_with}' which is not in template"
)
@property
def component_names(self) -> List[str]:
"""Return list of component names in this template."""
return [c.component for c in self.components]
[docs]
def info(self) -> str:
"""Return formatted information about the template."""
lines = [
f"ProductTemplate: {self.name}",
f"Description: {self.description}",
f"Category: {self.category}",
f"Domain: {self.domain}",
f"Default Target: {self.default_target or 'N/A'}",
f"Components ({len(self.components)}):",
]
for comp in self.components:
vtype = comp.variation_type.name
if comp.variation_type == VariationType.FIXED:
detail = f"= {comp.value:.3f}"
elif comp.variation_type == VariationType.UNIFORM:
detail = f"~ U({comp.min_value:.3f}, {comp.max_value:.3f})"
elif comp.variation_type == VariationType.NORMAL:
detail = f"~ N({comp.mean:.3f}, {comp.std:.3f})"
elif comp.variation_type == VariationType.LOGNORMAL:
detail = f"~ LogN({comp.mean:.3f}, {comp.std:.3f})"
elif comp.variation_type == VariationType.CORRELATED:
detail = f"corr={comp.correlation:.2f} with {comp.correlated_with}"
elif comp.variation_type == VariationType.COMPUTED:
detail = f"computed as {comp.compute_as}"
else:
detail = ""
lines.append(f" - {comp.component}: {vtype} {detail}")
if self.tags:
lines.append(f"Tags: {', '.join(self.tags)}")
return "\n".join(lines)
# =============================================================================
# Predefined Product Templates
# =============================================================================
PRODUCT_TEMPLATES: Dict[str, ProductTemplate] = {}
def _register_templates() -> None:
"""Register all predefined product templates."""
global PRODUCT_TEMPLATES
# =========================================================================
# DAIRY
# =========================================================================
PRODUCT_TEMPLATES["milk_variable_fat"] = ProductTemplate(
name="milk_variable_fat",
description="Milk with variable fat content (skim to whole)",
category="dairy",
domain="food",
components=[
ComponentVariation("lipid", VariationType.UNIFORM, min_value=0.005, max_value=0.06),
ComponentVariation("casein", VariationType.NORMAL, mean=0.028, std=0.003, min_value=0.022, max_value=0.034),
ComponentVariation("whey", VariationType.NORMAL, mean=0.006, std=0.001, min_value=0.004, max_value=0.008),
ComponentVariation("lactose", VariationType.NORMAL, mean=0.05, std=0.003, min_value=0.045, max_value=0.055),
ComponentVariation("water", VariationType.COMPUTED, compute_as="remainder"),
],
default_target="lipid",
tags=["dairy", "milk", "liquid", "nn_training"],
references=["USDA FoodData Central"],
)
PRODUCT_TEMPLATES["cheese_variable_moisture"] = ProductTemplate(
name="cheese_variable_moisture",
description="Cheese with variable moisture (aged to fresh)",
category="dairy",
domain="food",
components=[
ComponentVariation("moisture", VariationType.UNIFORM, min_value=0.28, max_value=0.45),
ComponentVariation("lipid", VariationType.CORRELATED, correlated_with="moisture", correlation=-0.6,
min_value=0.25, max_value=0.40, mean=0.33, std=0.04),
ComponentVariation("casein", VariationType.CORRELATED, correlated_with="moisture", correlation=-0.5,
min_value=0.18, max_value=0.32, mean=0.25, std=0.03),
ComponentVariation("lactose", VariationType.NORMAL, mean=0.02, std=0.01, min_value=0.005, max_value=0.035),
],
default_target="moisture",
tags=["dairy", "cheese", "solid", "nn_training"],
references=["USDA FoodData Central"],
)
PRODUCT_TEMPLATES["yogurt_variable_fat"] = ProductTemplate(
name="yogurt_variable_fat",
description="Yogurt with variable fat content (non-fat to full-fat)",
category="dairy",
domain="food",
components=[
ComponentVariation("lipid", VariationType.UNIFORM, min_value=0.001, max_value=0.10),
ComponentVariation("casein", VariationType.NORMAL, mean=0.035, std=0.005, min_value=0.025, max_value=0.045),
ComponentVariation("whey", VariationType.NORMAL, mean=0.015, std=0.003, min_value=0.010, max_value=0.020),
ComponentVariation("lactose", VariationType.NORMAL, mean=0.04, std=0.005, min_value=0.030, max_value=0.050),
ComponentVariation("lactic_acid", VariationType.UNIFORM, min_value=0.005, max_value=0.015),
ComponentVariation("water", VariationType.COMPUTED, compute_as="remainder"),
],
default_target="lipid",
tags=["dairy", "yogurt", "fermented", "nn_training"],
references=["USDA FoodData Central"],
)
# =========================================================================
# MEAT
# =========================================================================
PRODUCT_TEMPLATES["meat_variable_fat"] = ProductTemplate(
name="meat_variable_fat",
description="Meat with variable fat content (very lean to marbled)",
category="meat",
domain="food",
components=[
ComponentVariation("lipid", VariationType.UNIFORM, min_value=0.01, max_value=0.35),
ComponentVariation("protein", VariationType.CORRELATED, correlated_with="lipid", correlation=-0.6,
min_value=0.14, max_value=0.24, mean=0.20, std=0.03),
ComponentVariation("water", VariationType.CORRELATED, correlated_with="lipid", correlation=-0.8,
min_value=0.45, max_value=0.78, mean=0.70, std=0.08),
ComponentVariation("collagen", VariationType.NORMAL, mean=0.015, std=0.005, min_value=0.005, max_value=0.030),
],
default_target="lipid",
tags=["meat", "protein_source", "nn_training"],
references=["USDA FoodData Central"],
)
PRODUCT_TEMPLATES["meat_variable_protein"] = ProductTemplate(
name="meat_variable_protein",
description="Meat with variable protein content across cuts",
category="meat",
domain="food",
components=[
ComponentVariation("protein", VariationType.UNIFORM, min_value=0.14, max_value=0.28),
ComponentVariation("lipid", VariationType.CORRELATED, correlated_with="protein", correlation=-0.5,
min_value=0.02, max_value=0.25, mean=0.10, std=0.06),
ComponentVariation("water", VariationType.COMPUTED, compute_as="remainder"),
ComponentVariation("collagen", VariationType.NORMAL, mean=0.012, std=0.004, min_value=0.005, max_value=0.025),
],
default_target="protein",
tags=["meat", "protein_source", "nn_training"],
references=["USDA FoodData Central"],
)
# =========================================================================
# GRAIN
# =========================================================================
PRODUCT_TEMPLATES["wheat_variable_protein"] = ProductTemplate(
name="wheat_variable_protein",
description="Wheat grain with variable protein (feed to bread wheat)",
category="grain",
domain="agriculture",
components=[
ComponentVariation("protein", VariationType.UNIFORM, min_value=0.08, max_value=0.18),
ComponentVariation("starch", VariationType.CORRELATED, correlated_with="protein", correlation=-0.85,
min_value=0.55, max_value=0.72, mean=0.65, std=0.04),
ComponentVariation("moisture", VariationType.NORMAL, mean=0.12, std=0.02, min_value=0.08, max_value=0.16),
ComponentVariation("lipid", VariationType.NORMAL, mean=0.02, std=0.005, min_value=0.01, max_value=0.03),
ComponentVariation("cellulose", VariationType.NORMAL, mean=0.08, std=0.015, min_value=0.05, max_value=0.12),
],
default_target="protein",
tags=["grain", "cereal", "agriculture", "nn_training"],
references=["USDA FoodData Central", "Osborne1993"],
)
PRODUCT_TEMPLATES["corn_grain"] = ProductTemplate(
name="corn_grain",
description="Corn/maize grain with typical composition variation",
category="grain",
domain="agriculture",
components=[
ComponentVariation("starch", VariationType.NORMAL, mean=0.72, std=0.03, min_value=0.65, max_value=0.78),
ComponentVariation("protein", VariationType.UNIFORM, min_value=0.06, max_value=0.12),
ComponentVariation("moisture", VariationType.NORMAL, mean=0.11, std=0.015, min_value=0.08, max_value=0.15),
ComponentVariation("lipid", VariationType.UNIFORM, min_value=0.03, max_value=0.06),
ComponentVariation("cellulose", VariationType.NORMAL, mean=0.03, std=0.01, min_value=0.02, max_value=0.05),
],
default_target="protein",
tags=["grain", "cereal", "agriculture", "nn_training"],
references=["USDA FoodData Central"],
)
PRODUCT_TEMPLATES["soybean"] = ProductTemplate(
name="soybean",
description="Soybean with variable protein and oil content",
category="legume",
domain="agriculture",
components=[
ComponentVariation("protein", VariationType.UNIFORM, min_value=0.32, max_value=0.42),
ComponentVariation("lipid", VariationType.CORRELATED, correlated_with="protein", correlation=-0.4,
min_value=0.16, max_value=0.24, mean=0.20, std=0.02),
ComponentVariation("moisture", VariationType.NORMAL, mean=0.10, std=0.015, min_value=0.07, max_value=0.14),
ComponentVariation("starch", VariationType.FIXED, value=0.05),
ComponentVariation("cellulose", VariationType.NORMAL, mean=0.15, std=0.02, min_value=0.10, max_value=0.20),
ComponentVariation("sucrose", VariationType.NORMAL, mean=0.05, std=0.01, min_value=0.03, max_value=0.07),
],
default_target="protein",
tags=["legume", "oilseed", "agriculture", "nn_training"],
references=["USDA FoodData Central"],
)
PRODUCT_TEMPLATES["rice_grain"] = ProductTemplate(
name="rice_grain",
description="Polished rice grain composition",
category="grain",
domain="agriculture",
components=[
ComponentVariation("starch", VariationType.NORMAL, mean=0.78, std=0.02, min_value=0.74, max_value=0.82),
ComponentVariation("protein", VariationType.UNIFORM, min_value=0.05, max_value=0.10),
ComponentVariation("moisture", VariationType.NORMAL, mean=0.12, std=0.01, min_value=0.10, max_value=0.14),
ComponentVariation("lipid", VariationType.FIXED, value=0.01),
ComponentVariation("cellulose", VariationType.FIXED, value=0.01),
],
default_target="protein",
tags=["grain", "cereal", "agriculture", "nn_training"],
references=["USDA FoodData Central"],
)
PRODUCT_TEMPLATES["barley_grain"] = ProductTemplate(
name="barley_grain",
description="Barley grain for malting and feed",
category="grain",
domain="agriculture",
components=[
ComponentVariation("starch", VariationType.NORMAL, mean=0.60, std=0.03, min_value=0.54, max_value=0.66),
ComponentVariation("protein", VariationType.UNIFORM, min_value=0.08, max_value=0.15),
ComponentVariation("moisture", VariationType.NORMAL, mean=0.11, std=0.01, min_value=0.09, max_value=0.13),
ComponentVariation("lipid", VariationType.FIXED, value=0.02),
ComponentVariation("cellulose", VariationType.NORMAL, mean=0.14, std=0.02, min_value=0.10, max_value=0.18),
],
default_target="protein",
tags=["grain", "cereal", "malting", "agriculture", "nn_training"],
references=["USDA FoodData Central"],
)
# =========================================================================
# PHARMACEUTICAL
# =========================================================================
PRODUCT_TEMPLATES["tablet_variable_api"] = ProductTemplate(
name="tablet_variable_api",
description="Tablet with variable API content (process monitoring)",
category="solid_dosage",
domain="pharmaceutical",
components=[
# API content varies (simulating blend uniformity issues)
ComponentVariation("paracetamol", VariationType.NORMAL, mean=0.50, std=0.05, min_value=0.40, max_value=0.60),
ComponentVariation("microcrystalline_cellulose", VariationType.NORMAL, mean=0.25, std=0.03,
min_value=0.18, max_value=0.32),
ComponentVariation("starch", VariationType.NORMAL, mean=0.15, std=0.02, min_value=0.10, max_value=0.20),
ComponentVariation("moisture", VariationType.UNIFORM, min_value=0.02, max_value=0.08),
],
default_target="paracetamol",
tags=["pharma", "tablet", "api", "process_monitoring", "nn_training"],
references=["Reich2005"],
)
PRODUCT_TEMPLATES["tablet_moisture_stability"] = ProductTemplate(
name="tablet_moisture_stability",
description="Tablet with variable moisture (stability study)",
category="solid_dosage",
domain="pharmaceutical",
components=[
ComponentVariation("moisture", VariationType.UNIFORM, min_value=0.01, max_value=0.12),
ComponentVariation("paracetamol", VariationType.NORMAL, mean=0.50, std=0.02, min_value=0.46, max_value=0.54),
ComponentVariation("microcrystalline_cellulose", VariationType.FIXED, value=0.25),
ComponentVariation("starch", VariationType.FIXED, value=0.15),
],
default_target="moisture",
tags=["pharma", "tablet", "stability", "moisture", "nn_training"],
references=["Reich2005"],
)
PRODUCT_TEMPLATES["capsule_blend_uniformity"] = ProductTemplate(
name="capsule_blend_uniformity",
description="Capsule fill with blend uniformity variation",
category="solid_dosage",
domain="pharmaceutical",
components=[
ComponentVariation("ibuprofen", VariationType.NORMAL, mean=0.40, std=0.04, min_value=0.32, max_value=0.48),
ComponentVariation("starch", VariationType.CORRELATED, correlated_with="ibuprofen", correlation=-0.3,
min_value=0.25, max_value=0.40, mean=0.32, std=0.03),
ComponentVariation("lactose", VariationType.NORMAL, mean=0.20, std=0.02, min_value=0.15, max_value=0.25),
ComponentVariation("moisture", VariationType.UNIFORM, min_value=0.03, max_value=0.08),
],
default_target="ibuprofen",
tags=["pharma", "capsule", "blend_uniformity", "nn_training"],
references=["Reich2005"],
)
# =========================================================================
# HIGH-VARIABILITY NN TRAINING TEMPLATES
# =========================================================================
PRODUCT_TEMPLATES["food_cholesterol_variable"] = ProductTemplate(
name="food_cholesterol_variable",
description="Food matrix with wide cholesterol variability for robust NN training",
category="nn_training",
domain="food",
components=[
# Cholesterol with very wide range (eggs, meats, dairy, processed foods)
ComponentVariation("cholesterol", VariationType.LOGNORMAL, mean=0.01, std=0.015,
min_value=0.0001, max_value=0.05),
ComponentVariation("lipid", VariationType.CORRELATED, correlated_with="cholesterol", correlation=0.7,
min_value=0.01, max_value=0.50, mean=0.15, std=0.12),
ComponentVariation("protein", VariationType.UNIFORM, min_value=0.05, max_value=0.30),
ComponentVariation("water", VariationType.COMPUTED, compute_as="remainder"),
],
default_target="cholesterol",
tags=["food", "cholesterol", "high_variability", "nn_training", "robust"],
references=["USDA FoodData Central"],
)
PRODUCT_TEMPLATES["universal_fat_predictor"] = ProductTemplate(
name="universal_fat_predictor",
description="Wide fat range across food categories for universal fat NN",
category="nn_training",
domain="food",
components=[
# Fat spanning skim milk to butter
ComponentVariation("lipid", VariationType.UNIFORM, min_value=0.001, max_value=0.85),
ComponentVariation("protein", VariationType.UNIFORM, min_value=0.01, max_value=0.35),
ComponentVariation("starch", VariationType.UNIFORM, min_value=0.0, max_value=0.40),
ComponentVariation("water", VariationType.COMPUTED, compute_as="remainder"),
],
default_target="lipid",
tags=["food", "fat", "universal", "high_variability", "nn_training"],
references=["USDA FoodData Central"],
)
PRODUCT_TEMPLATES["universal_protein_predictor"] = ProductTemplate(
name="universal_protein_predictor",
description="Wide protein range across food/feed for universal protein NN",
category="nn_training",
domain="food",
components=[
# Protein spanning vegetables to pure protein isolates
ComponentVariation("protein", VariationType.UNIFORM, min_value=0.01, max_value=0.95),
ComponentVariation("lipid", VariationType.UNIFORM, min_value=0.001, max_value=0.30),
ComponentVariation("starch", VariationType.UNIFORM, min_value=0.0, max_value=0.50),
ComponentVariation("cellulose", VariationType.UNIFORM, min_value=0.0, max_value=0.20),
ComponentVariation("water", VariationType.COMPUTED, compute_as="remainder"),
],
default_target="protein",
tags=["food", "protein", "universal", "high_variability", "nn_training"],
references=["USDA FoodData Central"],
)
PRODUCT_TEMPLATES["universal_moisture_predictor"] = ProductTemplate(
name="universal_moisture_predictor",
description="Wide moisture range for universal moisture NN",
category="nn_training",
domain="food",
components=[
# Moisture from dried products to liquids
ComponentVariation("water", VariationType.UNIFORM, min_value=0.02, max_value=0.98),
ComponentVariation("protein", VariationType.UNIFORM, min_value=0.01, max_value=0.40),
ComponentVariation("starch", VariationType.UNIFORM, min_value=0.0, max_value=0.80),
ComponentVariation("lipid", VariationType.UNIFORM, min_value=0.0, max_value=0.50),
],
default_target="water",
tags=["food", "moisture", "water", "universal", "high_variability", "nn_training"],
references=["USDA FoodData Central"],
)
# =========================================================================
# FRUITS AND VEGETABLES
# =========================================================================
PRODUCT_TEMPLATES["fruit_sugar_variable"] = ProductTemplate(
name="fruit_sugar_variable",
description="Fruit with variable sugar content (ripeness variation)",
category="fruit",
domain="food",
components=[
ComponentVariation("fructose", VariationType.UNIFORM, min_value=0.02, max_value=0.12),
ComponentVariation("glucose", VariationType.CORRELATED, correlated_with="fructose", correlation=0.9,
min_value=0.01, max_value=0.08, mean=0.04, std=0.02),
ComponentVariation("sucrose", VariationType.UNIFORM, min_value=0.005, max_value=0.08),
ComponentVariation("malic_acid", VariationType.CORRELATED, correlated_with="fructose", correlation=-0.6,
min_value=0.001, max_value=0.015, mean=0.005, std=0.003),
ComponentVariation("water", VariationType.COMPUTED, compute_as="remainder"),
],
default_target="fructose",
tags=["fruit", "sugar", "ripeness", "nn_training"],
references=["USDA FoodData Central"],
)
# Register templates on module load
_register_templates()
# =============================================================================
# ProductGenerator Class
# =============================================================================
[docs]
class ProductGenerator:
"""
Generator for product-level synthetic NIRS spectra.
ProductGenerator creates realistic synthetic spectra based on predefined
product templates with controlled composition variability. It handles
correlation constraints, compositional bounds, and efficient batch
generation for neural network training.
Attributes:
template: The ProductTemplate used for generation.
library: ComponentLibrary with the required spectral components.
rng: NumPy random generator for reproducibility.
Args:
template: Template name (str) or ProductTemplate object.
random_state: Random seed for reproducibility.
wavelength_start: Start wavelength in nm (default: 1000).
wavelength_end: End wavelength in nm (default: 2500).
wavelength_step: Wavelength step in nm (default: 2).
wavelengths: Custom wavelength array (overrides start/end/step).
instrument_wavelength_grid: Predefined instrument grid name.
complexity: Spectral complexity ('simple', 'realistic', 'complex').
Example:
>>> # Generate milk samples with variable fat
>>> generator = ProductGenerator("milk_variable_fat", random_state=42)
>>> dataset = generator.generate(n_samples=1000, target="lipid")
>>>
>>> # High-variability training data
>>> generator = ProductGenerator("universal_protein_predictor")
>>> dataset = generator.generate(n_samples=50000, target="protein")
>>>
>>> # Match specific instrument wavelengths
>>> generator = ProductGenerator(
... "wheat_variable_protein",
... instrument_wavelength_grid="foss_xds"
... )
"""
def __init__(
self,
template: Union[str, ProductTemplate],
random_state: Optional[int] = None,
wavelength_start: float = 1000.0,
wavelength_end: float = 2500.0,
wavelength_step: float = 2.0,
wavelengths: Optional[np.ndarray] = None,
instrument_wavelength_grid: Optional[str] = None,
complexity: str = "realistic",
) -> None:
"""Initialize the product generator."""
# Get template
if isinstance(template, str):
self.template = get_product_template(template)
else:
self.template = template
self._random_state = random_state
self.rng = np.random.default_rng(random_state)
self.complexity = complexity
# Store wavelength config
self._wavelength_start = wavelength_start
self._wavelength_end = wavelength_end
self._wavelength_step = wavelength_step
self._wavelengths = wavelengths
self._instrument_wavelength_grid = instrument_wavelength_grid
# Create component library from template components
self.library = ComponentLibrary.from_predefined(
self.template.component_names,
random_state=random_state,
)
def _sample_compositions(self, n_samples: int) -> np.ndarray:
"""
Sample component compositions respecting variability and correlations.
This method generates realistic concentration matrices by:
1. Sampling independent components according to their variation types
2. Applying correlation constraints
3. Computing "remainder" components
4. Ensuring non-negative values and reasonable totals
Args:
n_samples: Number of composition samples to generate.
Returns:
Concentration matrix of shape (n_samples, n_components).
"""
n_components = len(self.template.components)
comp_names = self.template.component_names
# Initialize result matrix
concentrations = np.zeros((n_samples, n_components))
# Track which components have been sampled
sampled = set()
# First pass: sample independent components
for i, comp_var in enumerate(self.template.components):
vtype = comp_var.variation_type
if vtype == VariationType.FIXED:
concentrations[:, i] = comp_var.value
sampled.add(comp_var.component)
elif vtype == VariationType.UNIFORM:
concentrations[:, i] = self.rng.uniform(
comp_var.min_value, comp_var.max_value, n_samples
)
sampled.add(comp_var.component)
elif vtype == VariationType.NORMAL:
values = self.rng.normal(comp_var.mean, comp_var.std, n_samples)
if comp_var.min_value is not None and comp_var.max_value is not None:
values = np.clip(values, comp_var.min_value, comp_var.max_value)
concentrations[:, i] = values
sampled.add(comp_var.component)
elif vtype == VariationType.LOGNORMAL:
# Convert mean/std to log-space parameters
mean, std = comp_var.mean, comp_var.std
sigma_sq = np.log(1 + (std / mean) ** 2)
mu = np.log(mean) - sigma_sq / 2
sigma = np.sqrt(sigma_sq)
values = self.rng.lognormal(mu, sigma, n_samples)
if comp_var.min_value is not None and comp_var.max_value is not None:
values = np.clip(values, comp_var.min_value, comp_var.max_value)
concentrations[:, i] = values
sampled.add(comp_var.component)
# Second pass: sample correlated components
for i, comp_var in enumerate(self.template.components):
if comp_var.variation_type != VariationType.CORRELATED:
continue
if comp_var.component in sampled:
continue
# Get source component values
source_name = comp_var.correlated_with
source_idx = comp_names.index(source_name)
source_values = concentrations[:, source_idx]
# Get source component variation for normalization
source_comp = None
for cv in self.template.components:
if cv.component == source_name:
source_comp = cv
break
# Normalize source to [0, 1] range
if source_comp is not None:
if source_comp.min_value is not None and source_comp.max_value is not None:
source_range = source_comp.max_value - source_comp.min_value
source_normalized = (source_values - source_comp.min_value) / source_range
else:
source_normalized = source_values / source_values.max()
else:
source_normalized = source_values / source_values.max()
# Generate correlated values
correlation = comp_var.correlation
mean = comp_var.mean if comp_var.mean is not None else 0.5
std = comp_var.std if comp_var.std is not None else 0.1
# Use Cholesky decomposition for correlation
noise = self.rng.normal(0, 1, n_samples)
correlated = correlation * source_normalized + np.sqrt(1 - correlation**2) * noise
# Scale to target range
if comp_var.min_value is not None and comp_var.max_value is not None:
target_range = comp_var.max_value - comp_var.min_value
target_center = (comp_var.min_value + comp_var.max_value) / 2
# Map correlated values to target range
values = target_center + (correlated - 0.5) * target_range
values = np.clip(values, comp_var.min_value, comp_var.max_value)
else:
# Use mean/std
values = mean + correlated * std
values = np.maximum(values, 0)
concentrations[:, i] = values
sampled.add(comp_var.component)
# Third pass: compute "remainder" components
for i, comp_var in enumerate(self.template.components):
if comp_var.variation_type != VariationType.COMPUTED:
continue
if comp_var.compute_as == "remainder":
# Sum all other components
mask = np.ones(n_components, dtype=bool)
mask[i] = False
other_sum = concentrations[:, mask].sum(axis=1)
# Compute remainder
remainder = 1.0 - other_sum
# Clip to reasonable bounds
remainder = np.clip(remainder, 0.01, 0.99)
concentrations[:, i] = remainder
else:
raise ValueError(f"Unknown compute_as: {comp_var.compute_as}")
return concentrations
[docs]
def generate(
self,
n_samples: int = 1000,
target: Optional[str] = None,
train_ratio: float = 0.8,
include_batch_effects: bool = False,
n_batches: int = 1,
return_concentrations: bool = False,
) -> Union["SpectroDataset", Tuple["SpectroDataset", np.ndarray]]:
"""
Generate synthetic product samples.
Args:
n_samples: Number of samples to generate.
target: Component to use as regression target.
If None, uses template's default_target.
train_ratio: Proportion of samples for training partition.
include_batch_effects: Whether to add batch/session effects.
n_batches: Number of batches (if include_batch_effects=True).
return_concentrations: If True, also return the full concentration matrix.
Returns:
SpectroDataset with train/test partitions.
If return_concentrations=True, returns (dataset, concentrations).
Example:
>>> generator = ProductGenerator("milk_variable_fat")
>>> dataset = generator.generate(n_samples=1000, target="lipid")
>>> print(f"Train: {dataset.n_train}, Test: {dataset.n_test}")
"""
from .generator import SyntheticNIRSGenerator
from nirs4all.data.dataset import SpectroDataset
# Determine target component
if target is None:
target = self.template.default_target
if not target:
# Use first component if no default
target = self.template.component_names[0]
# Sample compositions
concentrations = self._sample_compositions(n_samples)
# Create generator with matching library
generator = SyntheticNIRSGenerator(
wavelength_start=self._wavelength_start,
wavelength_end=self._wavelength_end,
wavelength_step=self._wavelength_step,
wavelengths=self._wavelengths,
instrument_wavelength_grid=self._instrument_wavelength_grid,
component_library=self.library,
complexity=self.complexity,
random_state=self._random_state,
)
# Generate spectra from concentrations
X, metadata = generator.generate_from_concentrations(
concentrations,
include_batch_effects=include_batch_effects,
n_batches=n_batches,
)
# Get target values
target_idx = self.template.component_names.index(target)
y = concentrations[:, target_idx]
# Calculate split
n_train = int(n_samples * train_ratio)
# Create dataset
dataset = SpectroDataset(name=f"synthetic_{self.template.name}")
# Create wavelength headers
headers = [str(int(wl)) for wl in generator.wavelengths]
# Add training samples
dataset.add_samples(
X[:n_train],
indexes={"partition": "train"},
headers=headers,
header_unit="nm",
)
dataset.add_targets(y[:n_train])
# Add test samples
dataset.add_samples(
X[n_train:],
indexes={"partition": "test"},
headers=headers,
header_unit="nm",
)
dataset.add_targets(y[n_train:])
if return_concentrations:
return dataset, concentrations
return dataset
[docs]
def generate_dataset_for_target(
self,
target: str,
n_samples: int = 1000,
target_range: Optional[Tuple[float, float]] = None,
**kwargs: Any,
) -> "SpectroDataset":
"""
Generate dataset optimized for a specific target component.
This is a convenience method that generates a dataset and optionally
scales the target values to a specified range.
Args:
target: Component to use as regression target.
n_samples: Number of samples to generate.
target_range: Optional (min, max) to scale target values.
**kwargs: Additional arguments passed to generate().
Returns:
SpectroDataset ready for pipeline use.
Example:
>>> generator = ProductGenerator("wheat_variable_protein")
>>> dataset = generator.generate_dataset_for_target(
... target="protein",
... n_samples=10000,
... target_range=(0, 100) # Scale to percentage
... )
"""
# Generate with return_concentrations to get the target component
dataset, concentrations = self.generate(
n_samples=n_samples,
target=target,
return_concentrations=True,
**kwargs
)
if target_range is not None:
# Get target values from concentrations
target_idx = self.template.component_names.index(target)
y = concentrations[:, target_idx]
y_min, y_max = y.min(), y.max()
if y_max > y_min:
target_min, target_max = target_range
y_scaled = (y - y_min) / (y_max - y_min) * (target_max - target_min) + target_min
# Re-create dataset with scaled targets
from nirs4all.data.dataset import SpectroDataset as DS
from .generator import SyntheticNIRSGenerator
# Get wavelengths from generator
generator = SyntheticNIRSGenerator(
wavelength_start=self._wavelength_start,
wavelength_end=self._wavelength_end,
wavelength_step=self._wavelength_step,
wavelengths=self._wavelengths,
instrument_wavelength_grid=self._instrument_wavelength_grid,
component_library=self.library,
complexity=self.complexity,
random_state=self._random_state,
)
# Calculate split
train_ratio = kwargs.get("train_ratio", 0.8)
n_train = int(n_samples * train_ratio)
# Get X from original dataset
X = dataset.x({}, layout="2d")
# Create new dataset with scaled y
new_dataset = DS(name=f"synthetic_{self.template.name}")
headers = [str(int(wl)) for wl in generator.wavelengths]
# Add training samples
new_dataset.add_samples(
X[:n_train],
indexes={"partition": "train"},
headers=headers,
header_unit="nm",
)
new_dataset.add_targets(y_scaled[:n_train])
# Add test samples
new_dataset.add_samples(
X[n_train:],
indexes={"partition": "test"},
headers=headers,
header_unit="nm",
)
new_dataset.add_targets(y_scaled[n_train:])
return new_dataset
return dataset
[docs]
def __repr__(self) -> str:
"""Return string representation."""
return (
f"ProductGenerator(template='{self.template.name}', "
f"n_components={len(self.template.components)}, "
f"default_target='{self.template.default_target}')"
)
# =============================================================================
# CategoryGenerator Class
# =============================================================================
[docs]
class CategoryGenerator:
"""
Generator combining multiple product templates for diverse datasets.
CategoryGenerator enables creation of training datasets that span
multiple product types, useful for building robust models that
generalize across categories.
Attributes:
templates: List of ProductTemplate objects.
generators: List of ProductGenerator objects for each template.
Args:
templates: List of template names or ProductTemplate objects.
random_state: Random seed for reproducibility.
**kwargs: Additional arguments passed to ProductGenerator.
Example:
>>> # Combine dairy products
>>> gen = CategoryGenerator(["milk_variable_fat", "cheese_variable_moisture"])
>>> dataset = gen.generate(n_samples=2000, target="lipid")
>>>
>>> # Universal fat predictor training
>>> gen = CategoryGenerator([
... "milk_variable_fat",
... "cheese_variable_moisture",
... "meat_variable_fat",
... ])
>>> dataset = gen.generate(n_samples=10000, target="lipid")
"""
def __init__(
self,
templates: List[Union[str, ProductTemplate]],
random_state: Optional[int] = None,
**kwargs: Any,
) -> None:
"""Initialize the category generator."""
self._random_state = random_state
self.rng = np.random.default_rng(random_state)
# Convert template names to ProductTemplate objects
self.templates: List[ProductTemplate] = []
for template in templates:
if isinstance(template, str):
self.templates.append(get_product_template(template))
else:
self.templates.append(template)
# Create generators for each template
# Use different random states for each generator
self.generators: List[ProductGenerator] = []
for i, template in enumerate(self.templates):
seed = random_state + i if random_state is not None else None
self.generators.append(
ProductGenerator(template, random_state=seed, **kwargs)
)
[docs]
def generate(
self,
n_samples: int = 1000,
target: Optional[str] = None,
samples_per_template: Optional[List[int]] = None,
train_ratio: float = 0.8,
shuffle: bool = True,
include_template_labels: bool = False,
) -> "SpectroDataset":
"""
Generate combined dataset from multiple templates.
Args:
n_samples: Total number of samples to generate.
target: Component to use as regression target.
Must exist in all templates.
samples_per_template: Number of samples per template.
If None, divides equally.
train_ratio: Proportion of samples for training partition.
shuffle: Whether to shuffle samples across templates.
include_template_labels: If True, adds template index as metadata.
Returns:
SpectroDataset combining samples from all templates.
Example:
>>> gen = CategoryGenerator(["milk_variable_fat", "meat_variable_fat"])
>>> dataset = gen.generate(n_samples=2000, target="lipid")
"""
from nirs4all.data.dataset import SpectroDataset
# Determine samples per template
if samples_per_template is None:
n_templates = len(self.templates)
base_samples = n_samples // n_templates
samples_per_template = [base_samples] * n_templates
# Add remainder to last template
samples_per_template[-1] += n_samples % n_templates
# Collect data from all templates
all_X: List[np.ndarray] = []
all_y: List[np.ndarray] = []
all_template_ids: List[np.ndarray] = []
wavelengths = None
for i, (gen, n) in enumerate(zip(self.generators, samples_per_template)):
# Use target from first template if not specified
t = target if target else gen.template.default_target
# Check if target exists in this template
if t not in gen.template.component_names:
raise ValueError(
f"Target component '{t}' not found in template '{gen.template.name}'. "
f"Available components: {gen.template.component_names}"
)
# Generate with 100% train ratio (we'll split later)
dataset, concentrations = gen.generate(
n_samples=n,
target=t,
train_ratio=1.0,
return_concentrations=True,
)
# Get X
X = dataset.x({}, layout="2d")
# Get y from concentrations (using the target index for this template)
target_idx = gen.template.component_names.index(t)
y = concentrations[:, target_idx]
all_X.append(X)
all_y.append(y)
all_template_ids.append(np.full(n, i))
if wavelengths is None:
wavelengths = dataset.wavelengths_nm()
# Concatenate
X_combined = np.vstack(all_X)
y_combined = np.concatenate(all_y)
template_ids = np.concatenate(all_template_ids)
# Shuffle if requested
if shuffle:
indices = self.rng.permutation(len(X_combined))
X_combined = X_combined[indices]
y_combined = y_combined[indices]
template_ids = template_ids[indices]
# Split train/test
n_total = len(X_combined)
n_train = int(n_total * train_ratio)
# Create dataset
dataset = SpectroDataset(name="synthetic_category")
# Create wavelength headers
headers = [str(int(wl)) for wl in wavelengths]
# Add training samples
train_meta = {"partition": "train"}
if include_template_labels:
train_meta["template_id"] = template_ids[:n_train]
dataset.add_samples(
X_combined[:n_train],
indexes=train_meta,
headers=headers,
header_unit="nm",
)
dataset.add_targets(y_combined[:n_train])
# Add test samples
test_meta = {"partition": "test"}
if include_template_labels:
test_meta["template_id"] = template_ids[n_train:]
dataset.add_samples(
X_combined[n_train:],
indexes=test_meta,
headers=headers,
header_unit="nm",
)
dataset.add_targets(y_combined[n_train:])
return dataset
[docs]
def __repr__(self) -> str:
"""Return string representation."""
template_names = [t.name for t in self.templates]
return f"CategoryGenerator(templates={template_names})"
# =============================================================================
# Convenience Functions
# =============================================================================
[docs]
def list_product_templates(
category: Optional[str] = None,
domain: Optional[str] = None,
tags: Optional[List[str]] = None,
) -> List[str]:
"""
List available product templates with optional filtering.
Args:
category: Filter by category (e.g., "dairy", "grain", "pharma").
domain: Filter by domain (e.g., "food", "agriculture", "pharmaceutical").
tags: Filter by tags (any match).
Returns:
Sorted list of template names matching the criteria.
Example:
>>> # List all templates
>>> all_templates = list_product_templates()
>>>
>>> # List dairy templates
>>> dairy = list_product_templates(category="dairy")
>>>
>>> # List NN training templates
>>> nn_templates = list_product_templates(tags=["nn_training"])
"""
results = []
for name, template in PRODUCT_TEMPLATES.items():
if category and template.category != category:
continue
if domain and template.domain != domain:
continue
if tags:
if not any(t in template.tags for t in tags):
continue
results.append(name)
return sorted(results)
[docs]
def get_product_template(name: str) -> ProductTemplate:
"""
Get a product template by name.
Args:
name: Template name.
Returns:
ProductTemplate object.
Raises:
ValueError: If template name is not found.
Example:
>>> template = get_product_template("milk_variable_fat")
>>> print(template.description)
Milk with variable fat content (skim to whole)
"""
if name not in PRODUCT_TEMPLATES:
available = list(PRODUCT_TEMPLATES.keys())
raise ValueError(f"Unknown product template: '{name}'. Available: {available}")
return PRODUCT_TEMPLATES[name]
[docs]
def generate_product_samples(
template: Union[str, ProductTemplate],
n_samples: int = 1000,
target: Optional[str] = None,
random_state: Optional[int] = None,
**kwargs: Any,
) -> "SpectroDataset":
"""
Generate synthetic product samples (convenience function).
This is a shorthand for creating a ProductGenerator and calling generate().
Args:
template: Template name or ProductTemplate object.
n_samples: Number of samples to generate.
target: Component to use as regression target.
random_state: Random seed for reproducibility.
**kwargs: Additional arguments passed to ProductGenerator.generate().
Returns:
SpectroDataset with synthetic samples.
Example:
>>> from nirs4all.data.synthetic import generate_product_samples
>>>
>>> # Generate milk samples
>>> dataset = generate_product_samples(
... "milk_variable_fat",
... n_samples=1000,
... target="lipid",
... random_state=42
... )
"""
generator = ProductGenerator(template, random_state=random_state)
return generator.generate(n_samples=n_samples, target=target, **kwargs)
[docs]
def product_template_info(name: str) -> str:
"""
Return formatted information about a product template.
Args:
name: Template name.
Returns:
Human-readable string with template details.
Example:
>>> print(product_template_info("wheat_variable_protein"))
"""
template = get_product_template(name)
return template.info()
[docs]
def list_product_categories() -> List[str]:
"""
List all unique product categories.
Returns:
Sorted list of category names.
Example:
>>> categories = list_product_categories()
>>> print(categories)
['dairy', 'fruit', 'grain', 'legume', 'meat', 'nn_training', 'solid_dosage']
"""
categories = set()
for template in PRODUCT_TEMPLATES.values():
categories.add(template.category)
return sorted(categories)
[docs]
def list_product_domains() -> List[str]:
"""
List all unique product domains.
Returns:
Sorted list of domain names.
Example:
>>> domains = list_product_domains()
>>> print(domains)
['agriculture', 'food', 'pharmaceutical']
"""
domains = set()
for template in PRODUCT_TEMPLATES.values():
domains.add(template.domain)
return sorted(domains)