"""
GPU-accelerated generation for synthetic NIRS data.
This module provides optional GPU acceleration for generating large
synthetic datasets using JAX or CuPy, and falls back to NumPy on the CPU
when neither is available.
Phase 4 Features:
- Automatic backend detection (JAX, CuPy, NumPy)
- Batch spectrum generation on GPU
- Significant speedup for large datasets (10x+)
- Graceful fallback to CPU when GPU unavailable
Note:
This module is optional. GPU acceleration requires additional
dependencies (jax[cuda] or cupy-cuda*).
"""
from __future__ import annotations
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
# ============================================================================
# Backend Detection
# ============================================================================
[docs]
class AcceleratorBackend(str, Enum):
    """Enumerate the array backends this module can generate data with.

    The enum also derives from ``str``, so every member compares equal to
    its plain string value (for example ``AcceleratorBackend.JAX == "jax"``).
    """

    JAX = "jax"
    CUPY = "cupy"
    NUMPY = "numpy"  # CPU fallback
def _check_jax_available() -> bool:
"""Check if JAX with GPU is available."""
try:
import jax
import jax.numpy as jnp
# Check for GPU
devices = jax.devices()
return any(d.platform == 'gpu' for d in devices)
except ImportError:
return False
except Exception:
return False
def _check_cupy_available() -> bool:
"""Check if CuPy is available."""
try:
import cupy as cp
# Try a simple operation
cp.array([1, 2, 3])
return True
except ImportError:
return False
except Exception:
return False
[docs]
def detect_best_backend() -> AcceleratorBackend:
    """
    Pick the preferred backend that is usable right now.

    Backends are probed in priority order (JAX first, then CuPy) and probing
    stops at the first one that reports itself available. NumPy is returned
    when neither probe succeeds.

    Returns:
        The selected AcceleratorBackend member.

    Example:
        >>> backend = detect_best_backend()
        >>> print(f"Using backend: {backend}")
    """
    # Ordered by preference; each probe only runs if the earlier ones failed.
    candidates = (
        (_check_jax_available, AcceleratorBackend.JAX),
        (_check_cupy_available, AcceleratorBackend.CUPY),
    )
    for probe, candidate in candidates:
        if probe():
            return candidate
    return AcceleratorBackend.NUMPY
[docs]
def get_backend_info() -> Dict[str, Any]:
    """
    Collect availability flags and version details for every backend.

    Each backend is probed exactly once; the reported best backend is derived
    from those same flags, so it cannot disagree with them.

    Returns:
        Dictionary that always contains ``"jax_available"``,
        ``"cupy_available"`` and ``"best_backend"`` (the backend's string
        value). When JAX is available, ``"jax_version"`` and
        ``"jax_devices"`` are added; when CuPy is available,
        ``"cupy_version"`` and ``"cuda_version"`` are added. The detail keys
        are best-effort and are omitted if querying them raises.
    """
    jax_available = _check_jax_available()
    cupy_available = _check_cupy_available()

    # Same priority order as backend auto-detection (JAX, then CuPy, then
    # NumPy), computed from the flags above instead of probing a second time.
    if jax_available:
        best_backend = AcceleratorBackend.JAX
    elif cupy_available:
        best_backend = AcceleratorBackend.CUPY
    else:
        best_backend = AcceleratorBackend.NUMPY

    info: Dict[str, Any] = {
        "jax_available": jax_available,
        "cupy_available": cupy_available,
        "best_backend": best_backend.value,
    }

    if jax_available:
        try:
            import jax

            info["jax_version"] = jax.__version__
            info["jax_devices"] = [str(device) for device in jax.devices()]
        except Exception:
            # Details are optional extras; the availability flags above are
            # still valid, so a failure here must not fail the whole report.
            pass

    if cupy_available:
        try:
            import cupy as cp

            info["cupy_version"] = cp.__version__
            info["cuda_version"] = cp.cuda.runtime.runtimeGetVersion()
        except Exception:
            # Same best-effort policy as for the JAX details.
            pass

    return info
# ============================================================================
# Abstract Accelerator Interface
# ============================================================================
[docs]
@dataclass
class AcceleratedArrays:
    """Bundle of array callables that all belong to one backend.

    Generation code calls these attributes instead of a concrete array
    library, so the same routine can run on whichever backend the bundle
    was built for. Fields are grouped as: backend tag, array constructors,
    element-wise/reduction/linear-algebra operations, random sampling, and
    conversion back to NumPy.
    """

    # Which backend the callables below come from.
    backend: AcceleratorBackend

    # --- Array constructors ---
    zeros: Callable
    ones: Callable
    arange: Callable
    linspace: Callable
    array: Callable

    # --- Math, reductions and products ---
    exp: Callable
    log: Callable
    sqrt: Callable
    sin: Callable
    cos: Callable
    sum: Callable
    dot: Callable
    matmul: Callable

    # --- Random sampling; the factories in this module accept a shape ---
    random_normal: Callable
    random_uniform: Callable

    # --- Conversion of a backend array to a NumPy array ---
    to_numpy: Callable
def _create_numpy_arrays(seed: Optional[int] = None) -> AcceleratedArrays:
    """Build the NumPy (CPU) implementation of ``AcceleratedArrays``.

    Args:
        seed: Seed for the private ``numpy.random.Generator`` behind
            ``random_normal`` / ``random_uniform``. ``None`` (the default)
            leaves the generator unseeded, so callers that pass no argument
            keep getting non-reproducible draws.

    Returns:
        AcceleratedArrays whose callables are NumPy functions.
    """
    rng = np.random.default_rng(seed)

    def random_normal(shape):
        return rng.standard_normal(shape)

    def random_uniform(shape):
        return rng.uniform(size=shape)

    return AcceleratedArrays(
        backend=AcceleratorBackend.NUMPY,
        zeros=np.zeros,
        ones=np.ones,
        arange=np.arange,
        linspace=np.linspace,
        array=np.array,
        exp=np.exp,
        log=np.log,
        sqrt=np.sqrt,
        sin=np.sin,
        cos=np.cos,
        sum=np.sum,
        dot=np.dot,
        matmul=np.matmul,
        random_normal=random_normal,
        random_uniform=random_uniform,
        to_numpy=np.asarray,
    )


def _create_jax_arrays(seed: int = 0) -> AcceleratedArrays:
    """Build the JAX implementation of ``AcceleratedArrays``.

    JAX is imported lazily so the module stays importable without it.

    Args:
        seed: Value used to create the initial PRNG key.

    Returns:
        AcceleratedArrays whose callables are ``jax.numpy`` functions.
    """
    import jax.numpy as jnp
    from jax import random

    key = random.PRNGKey(seed)

    def next_subkey():
        # The key is split before every draw and the carried half replaces
        # ``key``, so successive calls never sample with the same subkey.
        nonlocal key
        key, subkey = random.split(key)
        return subkey

    def random_normal(shape):
        return random.normal(next_subkey(), shape)

    def random_uniform(shape):
        return random.uniform(next_subkey(), shape)

    return AcceleratedArrays(
        backend=AcceleratorBackend.JAX,
        zeros=jnp.zeros,
        ones=jnp.ones,
        arange=jnp.arange,
        linspace=jnp.linspace,
        array=jnp.array,
        exp=jnp.exp,
        log=jnp.log,
        sqrt=jnp.sqrt,
        sin=jnp.sin,
        cos=jnp.cos,
        sum=jnp.sum,
        dot=jnp.dot,
        matmul=jnp.matmul,
        random_normal=random_normal,
        random_uniform=random_uniform,
        to_numpy=np.asarray,
    )


def _create_cupy_arrays(seed: int = 0) -> AcceleratedArrays:
    """Build the CuPy implementation of ``AcceleratedArrays``.

    CuPy is imported lazily so the module stays importable without it.

    Args:
        seed: Value passed to ``cupy.random.seed``.

    Returns:
        AcceleratedArrays whose callables are CuPy functions.
    """
    import cupy as cp

    # Sampling below goes through the module-level ``cp.random`` functions,
    # so this seeds CuPy's shared random state rather than a private
    # generator owned by the returned bundle.
    cp.random.seed(seed)

    def random_normal(shape):
        return cp.random.standard_normal(shape)

    def random_uniform(shape):
        return cp.random.uniform(size=shape)

    return AcceleratedArrays(
        backend=AcceleratorBackend.CUPY,
        zeros=cp.zeros,
        ones=cp.ones,
        arange=cp.arange,
        linspace=cp.linspace,
        array=cp.array,
        exp=cp.exp,
        log=cp.log,
        sqrt=cp.sqrt,
        sin=cp.sin,
        cos=cp.cos,
        sum=cp.sum,
        dot=cp.dot,
        matmul=cp.matmul,
        random_normal=random_normal,
        random_uniform=random_uniform,
        to_numpy=cp.asnumpy,
    )


def create_accelerated_arrays(
    backend: Optional[AcceleratorBackend] = None,
    seed: int = 0,
) -> AcceleratedArrays:
    """
    Create accelerated array operations for the specified backend.

    Args:
        backend: Backend to use (auto-detect if None).
        seed: Random seed. It is forwarded to every backend, including the
            NumPy fallback, so random draws are reproducible for a given
            seed whichever backend is selected.

    Returns:
        AcceleratedArrays with operations for the backend.
    """
    if backend is None:
        backend = detect_best_backend()

    if backend == AcceleratorBackend.JAX:
        return _create_jax_arrays(seed)
    if backend == AcceleratorBackend.CUPY:
        return _create_cupy_arrays(seed)
    # NUMPY, and any value that matches neither accelerator, uses the CPU.
    return _create_numpy_arrays(seed)
# ============================================================================
# Accelerated Generation Functions
# ============================================================================
[docs]
def generate_voigt_profiles_accelerated(
    wavelengths: np.ndarray,
    centers: np.ndarray,
    amplitudes: np.ndarray,
    sigmas: np.ndarray,
    gammas: np.ndarray,
    arrays: Optional[AcceleratedArrays] = None,
) -> np.ndarray:
    """
    Generate a spectrum as a sum of Pseudo-Voigt bands.

    Every band is a weighted mix of a Gaussian (width ``sigma``) and a
    Lorentzian (width ``gamma``) centred on ``center``; the Gaussian weight
    is ``1 / (1 + gamma / sigma)``. All bands are evaluated in a single
    broadcast expression of shape (n_bands, n_wavelengths) and then summed,
    so no per-band Python loop runs on the device.

    Args:
        wavelengths: Wavelength array (n_wavelengths,).
        centers: Band centers (n_bands,).
        amplitudes: Band amplitudes (n_bands,).
        sigmas: Gaussian widths (n_bands,).
        gammas: Lorentzian widths (n_bands,).
        arrays: Accelerated arrays (auto-create if None).

    Returns:
        Spectrum array (n_wavelengths,) as a NumPy array. With zero bands
        the result is all zeros.

    Raises:
        ValueError: If ``amplitudes``, ``sigmas`` or ``gammas`` does not have
            one entry per band center.
    """
    # Validate on the host before any transfer: broadcasting would otherwise
    # stretch a length-1 parameter array across all bands without complaint.
    n_bands = len(centers)
    per_band = (("amplitudes", amplitudes), ("sigmas", sigmas), ("gammas", gammas))
    for label, values in per_band:
        if len(values) != n_bands:
            raise ValueError(
                f"{label} has {len(values)} entries but centers has {n_bands}; "
                "every band needs exactly one value"
            )

    if arrays is None:
        arrays = create_accelerated_arrays()

    # Transfer to device
    wl = arrays.array(wavelengths)
    c = arrays.array(centers)
    a = arrays.array(amplitudes)
    s = arrays.array(sigmas)
    g = arrays.array(gammas)

    # A small epsilon keeps zero widths from dividing by zero.
    eps = 1e-10
    sigma_safe = s + eps
    gamma_safe = g + eps

    # Rows index bands, columns index wavelengths.
    offsets = wl[None, :] - c[:, None]

    # Pseudo-Voigt mixing parameter (weight of the Gaussian part), per band.
    gauss_weight = (1.0 / (1.0 + g / sigma_safe))[:, None]

    gaussian = arrays.exp(-0.5 * (offsets / sigma_safe[:, None]) ** 2)
    lorentzian = 1.0 / (1.0 + (offsets / gamma_safe[:, None]) ** 2)

    bands = a[:, None] * (gauss_weight * gaussian + (1 - gauss_weight) * lorentzian)
    spectrum = arrays.sum(bands, axis=0)

    return arrays.to_numpy(spectrum)
[docs]
def generate_spectra_batch_accelerated(
    n_samples: int,
    wavelengths: np.ndarray,
    component_spectra: np.ndarray,
    concentrations: np.ndarray,
    noise_level: float = 0.01,
    arrays: Optional[AcceleratedArrays] = None,
) -> np.ndarray:
    """
    Generate a batch of noisy mixture spectra.

    The noise-free spectra are the matrix product of ``concentrations`` and
    ``component_spectra``. Standard-normal noise is then added to every row,
    scaled by ``noise_level`` times that row's L2 norm divided by the number
    of wavelengths.

    Args:
        n_samples: Number of samples to generate; must equal the number of
            rows in ``concentrations``.
        wavelengths: Wavelength array; only its length is used, and it must
            equal the number of columns in ``component_spectra``.
        component_spectra: Pure component spectra (n_components, n_wavelengths).
        concentrations: Concentration matrix (n_samples, n_components).
        noise_level: Multiplier applied to the per-row noise scale.
        arrays: Accelerated arrays (auto-create if None).

    Returns:
        Generated spectra (n_samples, n_wavelengths) as a NumPy array.

    Raises:
        ValueError: If the mixed spectra do not have shape
            ``(n_samples, len(wavelengths))``.
    """
    if arrays is None:
        arrays = create_accelerated_arrays()

    n_wavelengths = len(wavelengths)

    # Transfer to device
    E = arrays.array(component_spectra)
    C = arrays.array(concentrations)

    # Beer-Lambert mixing: X = C @ E
    X = arrays.matmul(C, E)

    # Check explicitly: a noise array with a size-1 axis would otherwise
    # broadcast and reuse the same noise values across samples or wavelengths.
    expected_shape = (n_samples, n_wavelengths)
    if tuple(X.shape) != expected_shape:
        raise ValueError(
            f"mixed spectra have shape {tuple(X.shape)}, expected {expected_shape} "
            "from n_samples and len(wavelengths)"
        )

    # Add noise
    noise = arrays.random_normal(expected_shape)
    row_norm = arrays.sqrt(arrays.sum(X ** 2, axis=1, keepdims=True))
    # NOTE(review): the norm is divided by n_wavelengths, not by
    # sqrt(n_wavelengths), so the scale is below the row RMS — confirm this
    # is the intended meaning of "noise level".
    noise = noise * noise_level * (row_norm / n_wavelengths)
    X = X + noise

    return arrays.to_numpy(X)
# ============================================================================
# High-Level Accelerated Generator
# ============================================================================
[docs]
class AcceleratedGenerator:
    """
    High-level front end for batch generation of synthetic spectra.

    An instance resolves a backend once, builds the matching array
    operations, and reuses them for every call, so all methods of one
    generator run on the same backend and share one random stream.

    Args:
        backend: Backend to use (auto-detect if None).
        random_state: Random state for reproducibility; a missing value is
            stored as 0.

    Example:
        >>> gen = AcceleratedGenerator(random_state=42)
        >>> print(f"Using backend: {gen.backend}")
        >>>
        >>> # Generate 10000 spectra
        >>> X = gen.generate_batch(
        ...     n_samples=10000,
        ...     wavelengths=np.linspace(1000, 2500, 700),
        ...     component_spectra=E,
        ...     concentrations=C,
        ... )
    """

    def __init__(
        self,
        backend: Optional[AcceleratorBackend] = None,
        random_state: Optional[int] = None,
    ):
        # A falsy backend (None included) triggers auto-detection.
        if not backend:
            backend = detect_best_backend()
        self.backend = backend

        # A falsy random_state (None included) is normalized to 0.
        self.random_state = random_state if random_state else 0

        self.arrays = create_accelerated_arrays(self.backend, self.random_state)

    def generate_batch(
        self,
        n_samples: int,
        wavelengths: np.ndarray,
        component_spectra: np.ndarray,
        concentrations: np.ndarray,
        noise_level: float = 0.01,
    ) -> np.ndarray:
        """
        Produce noisy mixture spectra with this generator's backend.

        Delegates to ``generate_spectra_batch_accelerated`` using the array
        operations created in the constructor.

        Args:
            n_samples: Number of samples.
            wavelengths: Wavelength array.
            component_spectra: Component spectra (n_components, n_wavelengths).
            concentrations: Concentrations (n_samples, n_components).
            noise_level: Noise level.

        Returns:
            Generated spectra (n_samples, n_wavelengths).
        """
        batch = generate_spectra_batch_accelerated(
            n_samples,
            wavelengths,
            component_spectra,
            concentrations,
            noise_level=noise_level,
            arrays=self.arrays,
        )
        return batch

    def generate_voigt_profiles(
        self,
        wavelengths: np.ndarray,
        centers: np.ndarray,
        amplitudes: np.ndarray,
        sigmas: np.ndarray,
        gammas: np.ndarray,
    ) -> np.ndarray:
        """
        Build a component spectrum from Voigt-shaped bands.

        Delegates to ``generate_voigt_profiles_accelerated`` using the array
        operations created in the constructor.

        Args:
            wavelengths: Wavelength array.
            centers: Band centers.
            amplitudes: Band amplitudes.
            sigmas: Gaussian widths.
            gammas: Lorentzian widths.

        Returns:
            Spectrum array.
        """
        spectrum = generate_voigt_profiles_accelerated(
            wavelengths,
            centers,
            amplitudes,
            sigmas,
            gammas,
            arrays=self.arrays,
        )
        return spectrum
# ============================================================================
# Convenience Functions
# ============================================================================
[docs]
def is_gpu_available() -> bool:
    """
    Tell whether backend auto-detection selects something other than NumPy.

    Returns:
        True if the detected backend is JAX or CuPy, False if it is the
        NumPy CPU fallback.

    Example:
        >>> if is_gpu_available():
        ...     print("GPU acceleration enabled!")
    """
    selected = detect_best_backend()
    return selected != AcceleratorBackend.NUMPY
[docs]
def get_acceleration_speedup_estimate(n_samples: int) -> float:
    """
    Return a rough speedup factor for the detected backend.

    The factor comes from fixed tiers keyed on the batch size; nothing is
    measured at call time.

    Args:
        n_samples: Number of samples to generate.

    Returns:
        1.0 when the detected backend is NumPy. Otherwise 0.5 below 100
        samples, 2.0 below 1000, 5.0 below 10000, and 10.0 from there up.
    """
    if detect_best_backend() == AcceleratorBackend.NUMPY:
        return 1.0

    # (exclusive upper bound on n_samples, estimated factor), ascending.
    # The lowest tier is below 1.0: GPU overhead dominates small batches.
    tiers = (
        (100, 0.5),
        (1000, 2.0),
        (10000, 5.0),
    )
    for upper_bound, factor in tiers:
        if n_samples < upper_bound:
            return factor

    # Large batches benefit most.
    return 10.0
[docs]
def _time_batch_generation(
    arrays,
    n_samples,
    wavelengths,
    component_spectra,
    concentrations,
    n_trials,
    warmup,
):
    """Return the mean wall-clock seconds of one batch generation with ``arrays``."""
    import time

    def run_once():
        generate_spectra_batch_accelerated(
            n_samples, wavelengths, component_spectra, concentrations,
            arrays=arrays,
        )

    if warmup:
        # One untimed call, so first-call setup cost is not measured.
        run_once()

    durations = []
    for _ in range(n_trials):
        started = time.perf_counter()
        run_once()
        durations.append(time.perf_counter() - started)
    return float(np.mean(durations))


def benchmark_backends(
    n_samples: int = 1000,
    n_wavelengths: int = 700,
    n_components: int = 5,
    n_trials: int = 5,
) -> Dict[str, float]:
    """
    Benchmark available backends.

    The same randomly generated inputs are used for every backend. NumPy is
    always timed; JAX and CuPy are timed only when their availability check
    passes, and each of them gets one untimed warm-up call first.

    Args:
        n_samples: Number of samples to generate.
        n_wavelengths: Number of wavelengths.
        n_components: Number of components.
        n_trials: Number of timing trials.

    Returns:
        Dictionary of backend name (``"numpy"``, ``"jax"``, ``"cupy"``) to
        mean time in seconds. Unavailable backends are absent.

    Example:
        >>> results = benchmark_backends()
        >>> for backend, time in results.items():
        ...     print(f"{backend}: {time:.4f}s")
    """
    # Generate test data
    wavelengths = np.linspace(1000, 2500, n_wavelengths)
    component_spectra = np.random.randn(n_components, n_wavelengths)
    concentrations = np.abs(np.random.randn(n_samples, n_components))

    def measure(arrays, warmup):
        return _time_batch_generation(
            arrays, n_samples, wavelengths, component_spectra, concentrations,
            n_trials, warmup,
        )

    results: Dict[str, float] = {}

    # NumPy is always available and is timed without a warm-up call.
    results["numpy"] = measure(_create_numpy_arrays(), False)

    if _check_jax_available():
        results["jax"] = measure(_create_jax_arrays(), True)

    if _check_cupy_available():
        results["cupy"] = measure(_create_cupy_arrays(), True)

    return results