"""
JAX Model Controller - Controller for JAX/Flax models
This controller handles JAX models (specifically Flax) with support for:
- Training on JAX arrays
- Custom training loops with Optax optimizers
- Integration with Optuna for hyperparameter tuning
- Model persistence and prediction storage
Matches Flax Module objects and model configurations.
Lazy loading pattern: JAX is only imported when actually needed
for training or prediction, not at module import time.
"""
from typing import Any, Dict, List, Tuple, Optional, TYPE_CHECKING
import numpy as np
import copy
from ..models.base_model import BaseModelController
from nirs4all.controllers.registry import register_controller
from nirs4all.core.logging import get_logger
from nirs4all.utils.backend import is_available, require_backend, is_gpu_available
logger = get_logger(__name__)
# Fast availability check at module level - no imports
JAX_AVAILABLE = is_available('jax')
if TYPE_CHECKING:
from nirs4all.pipeline.runner import PipelineRunner
from nirs4all.data.dataset import SpectroDataset
from nirs4all.pipeline.config.context import ExecutionContext
from nirs4all.pipeline.steps.parser import ParsedStep
try:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state
except ImportError:
pass
# Lazy-loaded module cache
_jax_modules: Dict[str, Any] = {}
def _get_jax():
    """Return the ``jax`` package, importing it on first use.

    On the first call the backend is checked through ``require_backend`` and
    both ``jax`` and ``jax.numpy`` are stored in ``_jax_modules`` under the
    keys ``'jax'`` and ``'jnp'``. Later calls return the cached module.
    """
    if 'jax' in _jax_modules:
        return _jax_modules['jax']

    require_backend('jax', feature='JAX/Flax neural networks')
    import jax as jax_module
    import jax.numpy as jnp_module

    _jax_modules['jax'] = jax_module
    _jax_modules['jnp'] = jnp_module
    return _jax_modules['jax']
def _get_flax():
    """Return ``flax.linen``, importing it on first use.

    JAX is loaded first via ``_get_jax``; the imported ``flax.linen`` module
    is then cached in ``_jax_modules`` under the key ``'flax'``.
    """
    if 'flax' in _jax_modules:
        return _jax_modules['flax']

    _get_jax()  # JAX has to be loaded before Flax
    import flax.linen as linen_module

    _jax_modules['flax'] = linen_module
    return _jax_modules['flax']
def _get_optax():
    """Return the ``optax`` package, importing it on first use.

    JAX is loaded first via ``_get_jax``; the imported ``optax`` module is
    then cached in ``_jax_modules`` under the key ``'optax'``.
    """
    if 'optax' in _jax_modules:
        return _jax_modules['optax']

    _get_jax()  # JAX has to be loaded before Optax
    import optax as optax_module

    _jax_modules['optax'] = optax_module
    return _jax_modules['optax']
[docs]
@register_controller
class JaxModelController(BaseModelController):
    """Controller for JAX/Flax models.

    Uses lazy loading pattern - JAX is only imported when
    training or prediction is actually performed.
    """

    # Priority used when several controllers could handle the same step.
    # NOTE(review): the inline comment implies lower numbers win (4 beats
    # Sklearn's 6) - confirm against the controller registry.
    priority = 4  # Higher priority than Sklearn (6)
[docs]
@classmethod
def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
    """Tell whether this controller should handle the given step.

    Checks are tried in order and the first hit wins:

    1. ``step`` is a dict with a ``'model'`` entry that ``_is_jax_model``
       accepts, or whose ``'class'`` string contains ``jax`` or ``flax``.
    2. ``step`` itself is accepted by ``_is_jax_model``.
    3. ``operator`` is not ``None`` and is accepted by ``_is_jax_model``.

    Always returns ``False`` when ``JAX_AVAILABLE`` is false. ``keyword`` is
    not used by this implementation.
    """
    if not JAX_AVAILABLE:
        return False

    # Step given as a dict carrying a model (object or config dict)
    if isinstance(step, dict) and 'model' in step:
        candidate = step['model']
        if cls._is_jax_model(candidate):
            return True

        # Config dict naming the model class by string
        if isinstance(candidate, dict) and 'class' in candidate:
            dotted_name = candidate['class']
            if isinstance(dotted_name, str) and any(
                tag in dotted_name for tag in ('jax', 'flax')
            ):
                return True

    # Step is itself a JAX object
    if cls._is_jax_model(step):
        return True

    # Finally fall back to the operator, when one is supplied
    return bool(operator is not None and cls._is_jax_model(operator))
@classmethod
def _is_jax_model(cls, obj: Any) -> bool:
    """Check if object is a JAX/Flax model.

    Cheap checks that need no JAX import run first:

    * an attribute ``framework == 'jax'`` on the object;
    * a dict with ``type == 'function'`` and ``framework == 'jax'``
      (the format produced by ``deserialize_component`` according to the
      original comment);
    * a module-name filter over the object's class hierarchy.

    Only objects that pass the module-name filter trigger the Flax import
    and the authoritative ``isinstance(obj, nn.Module)`` test.

    Args:
        obj: Candidate object (model instance, config dict, or anything else).

    Returns:
        ``True`` if ``obj`` is recognised as a JAX/Flax model, else ``False``.
        Always ``False`` when ``JAX_AVAILABLE`` is false or ``obj`` is ``None``.
    """
    if not JAX_AVAILABLE or obj is None:
        return False

    # Check for framework attribute first (no import needed)
    if hasattr(obj, 'framework') and obj.framework == 'jax':
        return True

    # Check for dict format from deserialize_component
    if isinstance(obj, dict) and obj.get('type') == 'function' and obj.get('framework') == 'jax':
        return True

    # Quick check via module names (no import needed). The whole MRO is
    # inspected rather than only the concrete class: a user-defined
    # ``nn.Module`` subclass lives in the user's own module, whose name does
    # not have to contain "jax" or "flax", while its Flax base class does.
    mro = getattr(type(obj), '__mro__', (type(obj),))
    module_names = (getattr(klass, '__module__', '') or '' for klass in mro)
    if not any('jax' in name or 'flax' in name for name in module_names):
        return False

    try:
        nn = _get_flax()
        return isinstance(obj, nn.Module)
    except Exception as exc:
        # Matching is best-effort: a broken/missing Flax install must not
        # crash controller dispatch, but leave a trace for debugging.
        logger.debug(f"Could not check {type(obj).__name__} against flax.linen.Module: {exc}")
        return False
def _get_model_instance(self, dataset: 'SpectroDataset', model_config: Dict[str, Any], force_params: Optional[Dict[str, Any]] = None) -> Any:
    """Build a JAX model from ``model_config`` through ``ModelFactory``.

    Args:
        dataset: Dataset forwarded to the factory.
        model_config: Model configuration forwarded to the factory.
        force_params: Optional parameter overrides; an empty dict is passed
            when this is ``None`` or empty.

    Returns:
        Whatever ``ModelFactory.build_single_model`` returns.
    """
    require_backend('jax', feature='JAX/Flax models')

    # Deferred import: avoids a circular import at module load time
    from .factory import ModelFactory

    overrides = force_params if force_params else {}
    return ModelFactory.build_single_model(model_config, dataset, overrides)
def _create_train_state(self, rng, model, input_shape, learning_rate):
    """Create the initial Flax training state for ``model``.

    The model is initialised on an all-ones array of shape ``input_shape``;
    its ``'params'`` collection and, when present, its ``'batch_stats'``
    collection are stored in the returned state together with an Adam
    optimizer.

    Args:
        rng: PRNG key passed to ``model.init``.
        model: Object exposing ``init`` and ``apply`` (a Flax module).
        input_shape: Shape of the dummy input used for initialisation.
        learning_rate: Learning rate handed to ``optax.adam``.

    Returns:
        A ``TrainState`` (subclass of ``flax.training.train_state.TrainState``)
        with an extra ``batch_stats`` field, which is ``None`` when the model
        has no ``'batch_stats'`` collection.
    """
    # Make sure the module cache is populated before reading 'jnp' from it;
    # otherwise a call made before any other JAX helper raises KeyError.
    _get_jax()
    jnp = _jax_modules['jnp']
    optax = _get_optax()
    from flax.training import train_state

    class TrainState(train_state.TrainState):
        # Running statistics of normalisation layers (None if the model has none)
        batch_stats: Any

    variables = model.init(rng, jnp.ones(input_shape))
    params = variables['params']
    batch_stats = variables.get('batch_stats')
    tx = optax.adam(learning_rate)

    return TrainState.create(
        apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats
    )
def _train_model(
    self,
    model: Any,
    X_train: Any,
    y_train: Any,
    X_val: Optional[Any] = None,
    y_val: Optional[Any] = None,
    **kwargs
) -> Any:
    """Train a JAX/Flax model with a custom mini-batch loop.

    The loop shuffles the training data every epoch, runs jitted gradient
    steps on consecutive batches and, when validation data is supplied,
    performs early stopping and restores the parameters with the lowest
    validation loss.

    Args:
        model: Flax module to train.
        X_train: Training inputs; the first axis is the sample axis.
        y_train: Training targets, indexed like ``X_train``.
        X_val: Optional validation inputs.
        y_val: Optional validation targets.
        **kwargs: Training options read from here:
            ``verbose`` (default 0), ``epochs`` (100), ``batch_size`` (32),
            ``lr`` or ``learning_rate`` (0.001), ``patience`` (10) and
            ``task_type`` (object with an ``is_classification`` attribute;
            selects cross-entropy instead of MSE).

    Returns:
        ``JaxModelWrapper(model, state)`` holding the trained state.

    Raises:
        ValueError: If ``X_train`` contains no samples.

    Notes:
        * ``batch_size`` is clamped to the number of training samples so that
          datasets smaller than one batch are still trained (previously this
          produced zero steps and a ``ZeroDivisionError``).
        * Samples beyond ``steps_per_epoch * batch_size`` in the shuffled
          order are skipped for that epoch.
        * The PRNG seed is fixed (``PRNGKey(0)``).
    """
    require_backend('jax', feature='JAX/Flax training')

    # Import JAX modules here (lazy loading)
    jax = _get_jax()
    jnp = _jax_modules['jnp']
    optax = _get_optax()

    # Import wrapper here (lazy)
    from .jax_wrapper import JaxModelWrapper

    train_params = kwargs
    verbose = train_params.get('verbose', 0)

    if not is_gpu_available('jax') and verbose > 0:
        logger.warning("No GPU detected. Training JAX model on CPU may be slow.")

    epochs = train_params.get('epochs', 100)
    batch_size = train_params.get('batch_size', 32)
    learning_rate = train_params.get('lr', train_params.get('learning_rate', 0.001))
    patience = train_params.get('patience', 10)

    n_samples = X_train.shape[0]
    if n_samples == 0:
        raise ValueError("Cannot train JAX model: X_train contains no samples")
    # Never ask for more samples per batch than exist, so that at least one
    # optimisation step runs per epoch.
    batch_size = max(1, min(int(batch_size), n_samples))
    steps_per_epoch = n_samples // batch_size

    # Initialize RNG
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)

    # Create TrainState
    # Input shape: (1, features) or (1, features, channels)
    input_shape = (1,) + X_train.shape[1:]
    state = self._create_train_state(init_rng, model, input_shape, learning_rate)

    # Loss selection (MSE for regression, CrossEntropy for classification)
    task_type = train_params.get('task_type')
    is_classification = bool(task_type and task_type.is_classification)

    def compute_loss(logits, targets):
        """Mean loss shared by the training and evaluation steps."""
        if is_classification:
            if targets.ndim == 1 or (targets.ndim == 2 and targets.shape[1] == 1):
                # Integer labels. reshape(-1) keeps the batch axis even for
                # a single-sample batch, which squeeze() would collapse.
                labels = targets.reshape(-1).astype(jnp.int32)
                per_sample = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
            else:
                # One-hot labels
                per_sample = optax.softmax_cross_entropy(logits, targets)
            return jnp.mean(per_sample)

        # Regression: align e.g. (B,) targets with (B, 1) outputs so the
        # subtraction cannot broadcast to a (B, B) matrix.
        if targets.shape != logits.shape and targets.size == logits.size:
            targets = targets.reshape(logits.shape)
        return jnp.mean((logits - targets) ** 2)

    @jax.jit
    def train_step(state, batch_X, batch_y, rng):
        has_batch_stats = state.batch_stats is not None

        def loss_fn(params):
            variables = {'params': params}
            rngs = {'dropout': rng}
            if has_batch_stats:
                # batch_stats must be declared mutable to be updated in training mode
                variables['batch_stats'] = state.batch_stats
                logits, new_model_state = state.apply_fn(
                    variables, batch_X, train=True, mutable=['batch_stats'], rngs=rngs
                )
            else:
                logits = state.apply_fn(
                    variables, batch_X, train=True, rngs=rngs
                )
                new_model_state = None
            return compute_loss(logits, batch_y), new_model_state

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, new_model_state), grads = grad_fn(state.params)

        new_batch_stats = state.batch_stats
        if new_model_state is not None and 'batch_stats' in new_model_state:
            new_batch_stats = new_model_state['batch_stats']

        state = state.apply_gradients(grads=grads, batch_stats=new_batch_stats)
        return state, loss

    @jax.jit
    def eval_step(state, batch_X, batch_y):
        variables = {'params': state.params}
        if state.batch_stats is not None:
            variables['batch_stats'] = state.batch_stats
        logits = state.apply_fn(variables, batch_X, train=False)
        return compute_loss(logits, batch_y)

    # Training loop
    best_val_loss = float('inf')
    best_params = None
    best_batch_stats = None
    patience_counter = 0
    has_validation = X_val is not None and y_val is not None

    for epoch in range(epochs):
        # Shuffle data
        rng, shuffle_rng = jax.random.split(rng)
        perms = jax.random.permutation(shuffle_rng, n_samples)
        X_train_shuffled = X_train[perms]
        y_train_shuffled = y_train[perms]

        epoch_loss = 0.0
        for i in range(steps_per_epoch):
            batch_idx = slice(i * batch_size, (i + 1) * batch_size)
            batch_X = X_train_shuffled[batch_idx]
            batch_y = y_train_shuffled[batch_idx]

            rng, step_rng = jax.random.split(rng)
            state, loss = train_step(state, batch_X, batch_y, step_rng)
            epoch_loss += loss

        epoch_loss /= steps_per_epoch  # steps_per_epoch >= 1 thanks to the clamp above

        if has_validation:
            val_loss = float(eval_step(state, X_val, y_val))

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_params = state.params
                best_batch_stats = state.batch_stats
                patience_counter = 0
            else:
                patience_counter += 1

            if verbose > 1 and (epoch + 1) % 10 == 0:
                logger.debug(f"Epoch {epoch+1}/{epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}")

            if patience_counter >= patience:
                if verbose > 0:
                    logger.info(f"Early stopping at epoch {epoch+1}")
                break
        else:
            if verbose > 1 and (epoch + 1) % 10 == 0:
                logger.debug(f"Epoch {epoch+1}/{epochs}, Train Loss: {epoch_loss:.4f}")

    # Restore best params
    if best_params is not None:
        state = state.replace(params=best_params, batch_stats=best_batch_stats)

    # Flax modules are stateless, so the trained state travels with the
    # model inside a wrapper that is later used for prediction.
    return JaxModelWrapper(model, state)
def _predict_model(self, model: Any, X: Any) -> np.ndarray:
    """Predict with a trained JAX model and return a 2D column-shaped result.

    * 2D output with more than one column: treated as per-class scores and
      reduced with ``argmax`` to one label per row, shaped ``(n, 1)``.
    * 1D output: reshaped to ``(n, 1)``.
    * Anything else is returned as produced by the wrapper.

    Raises:
        ValueError: If ``model`` is not a ``JaxModelWrapper``.
    """
    # Import wrapper here (lazy)
    from .jax_wrapper import JaxModelWrapper

    if not isinstance(model, JaxModelWrapper):
        raise ValueError("Model must be a JaxModelWrapper instance for prediction")

    outputs = model.predict(X)
    n_dims = outputs.ndim

    # Several columns: convert logits/probabilities to class labels
    if n_dims == 2 and outputs.shape[1] > 1:
        class_ids = np.argmax(outputs, axis=-1)
        return class_ids.reshape(-1, 1)

    # Flat vector: make it a single column (regression/binary)
    if n_dims == 1:
        return outputs.reshape(-1, 1)

    return outputs
def _predict_proba_model(self, model: Any, X: Any) -> Optional[np.ndarray]:
    """Turn the raw outputs of a JAX model into class probabilities.

    The wrapper's predictions are interpreted as logits:

    * a single output column goes through a sigmoid and is expanded to two
      columns ``[1 - p, p]``;
    * several output columns go through a row-wise softmax.

    No task-type check is made here, so callers decide whether calling
    this on a given model is meaningful.

    Args:
        model: Trained JAX model (``JaxModelWrapper``).
        X: Input features.

    Returns:
        Array of shape ``(n_samples, n_classes)``, or ``None`` when ``model``
        is not a ``JaxModelWrapper``.
    """
    _get_jax()  # backend check + cached import
    import jax.nn as jax_nn

    # Import wrapper here (lazy)
    from .jax_wrapper import JaxModelWrapper

    if not isinstance(model, JaxModelWrapper):
        return None

    # Raw model outputs (logits), always viewed as 2D
    logits = model.predict(X)
    if logits.ndim == 1:
        logits = logits.reshape(-1, 1)

    if logits.shape[1] != 1:
        # Multiclass: softmax over the class axis
        return np.asarray(jax_nn.softmax(logits, axis=-1))

    # Binary classification with a single output column
    positive = np.asarray(jax_nn.sigmoid(logits))
    return np.column_stack([1 - positive, positive])
def _prepare_data(
    self,
    X: np.ndarray,
    y: Optional[np.ndarray],
    context: 'ExecutionContext'
) -> Tuple[Any, Optional[Any]]:
    """Convert ``X`` and ``y`` for JAX via ``JaxDataPreparation.prepare_data``.

    ``context`` is accepted for interface compatibility and is not used.
    """
    # Deferred import so JAX is not loaded when this module is imported
    from .jax.data_prep import JaxDataPreparation

    prepared = JaxDataPreparation.prepare_data(X, y)
    return prepared
def _evaluate_model(self, model: Any, X_val: Any, y_val: Any) -> float:
    """Score a trained JAX model with mean squared error.

    Args:
        model: Trained model; must be a ``JaxModelWrapper`` to be scored.
        X_val: Validation inputs passed to ``model.predict``.
        y_val: Validation targets.

    Returns:
        The MSE between predictions and targets as a Python float, or
        ``float('inf')`` when ``model`` is not a ``JaxModelWrapper``.
        MSE is used regardless of the task type.
    """
    # Import wrapper here (lazy)
    from .jax_wrapper import JaxModelWrapper

    if not isinstance(model, JaxModelWrapper):
        return float('inf')

    predictions = np.asarray(model.predict(X_val))
    targets = np.asarray(y_val)

    # Align e.g. (n,) targets with (n, 1) predictions. Without this the
    # subtraction broadcasts to an (n, n) matrix and the MSE is wrong.
    if predictions.shape != targets.shape and predictions.size == targets.size:
        targets = targets.reshape(predictions.shape)

    return float(np.mean((predictions - targets) ** 2))
[docs]
def get_preferred_layout(self) -> str:
    """Name the data layout this controller asks for.

    Flax Dense layers take ``(batch, features)`` and Flax Conv layers take
    ``(batch, length, features)``, i.e. ``(N, L, C)``, which is why
    ``'3d_transpose'`` is chosen (it suits Conv1D).
    """
    layout = "3d_transpose"
    return layout
def _clone_model(self, model: Any) -> Any:
"""Clone JAX model."""
# Flax models are immutable dataclasses, so we can just return the model definition
# The state is created fresh in _train_model
return model
[docs]
def process_hyperparameters(self, params: Dict[str, Any]) -> Dict[str, Any]:
    """Hand the tuning parameters back unchanged.

    The JAX implementation needs no nesting or renaming yet, so the same
    dict object that was passed in is returned.
    """
    processed = params
    return processed
[docs]
def execute(
    self,
    step_info: 'ParsedStep',
    dataset: 'SpectroDataset',
    context: 'ExecutionContext',
    runtime_context: 'RuntimeContext',
    source: int = -1,
    mode: str = "train",
    loaded_binaries: Optional[List[Tuple[str, bytes]]] = None,
    prediction_store: 'Predictions' = None
) -> Tuple['ExecutionContext', List[Tuple[str, bytes]]]:
    """Run the step through the base controller with the JAX layout applied.

    The context is first updated with the layout returned by
    ``get_effective_layout(step_info)``; every argument is then forwarded
    to the parent ``execute``.

    Raises:
        ImportError: If ``JAX_AVAILABLE`` is false.
    """
    if not JAX_AVAILABLE:
        raise ImportError(
            "JAX is not available. Please install with: "
            "pip install nirs4all[jax]"
        )

    # Layout preference (a forced layout overrides the preferred one)
    layout = self.get_effective_layout(step_info)
    jax_context = context.with_layout(layout)

    # Delegate the actual work to the parent implementation
    return super().execute(
        step_info,
        dataset,
        jax_context,
        runtime_context,
        source,
        mode,
        loaded_binaries,
        prediction_store,
    )