nirs4all.controllers.models.jax_model module
JAX Model Controller - Controller for JAX/Flax models
This controller handles JAX models (specifically Flax) with support for:
- Training on JAX arrays
- Custom training loops with Optax optimizers
- Integration with Optuna for hyperparameter tuning
- Model persistence and prediction storage
Matches Flax Module objects and model configurations.
Lazy loading pattern: JAX is only imported when actually needed for training or prediction, not at module import time.
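The lazy-loading and matching behavior described above can be sketched with the standard library alone. This is a minimal illustration under stated assumptions, not nirs4all's actual implementation: `is_available`, `lazy_import`, `looks_like_flax_module` and `LazyJaxController` are hypothetical names. Recognizing a Flax model by finding a class named `Module` from a `flax.linen` submodule in its MRO is also an assumption. It shows one way a matcher could avoid importing Flax just to inspect an object.

```python
import importlib
import importlib.util
from types import ModuleType


def is_available(name: str) -> bool:
    """Report whether a top-level package is installed without importing it."""
    return importlib.util.find_spec(name) is not None


def lazy_import(name: str) -> ModuleType:
    """Import a module on first use; later calls hit the sys.modules cache."""
    return importlib.import_module(name)


def looks_like_flax_module(obj) -> bool:
    """Hypothetical matcher: detect an instance of a flax.linen.Module subclass
    by inspecting class names in its MRO, so flax is never imported here."""
    return any(
        cls.__name__ == "Module" and cls.__module__.startswith("flax.linen")
        for cls in type(obj).__mro__
    )


class LazyJaxController:
    """Hypothetical controller showing where the deferred import happens."""

    def matches(self, model) -> bool:
        # Matching relies on class metadata only; JAX stays unimported.
        return looks_like_flax_module(model)

    def train(self, x):
        # JAX is imported only when training actually runs.
        jnp = lazy_import("jax.numpy")
        return jnp.asarray(x)
```

With this split, importing the module that defines the controller costs nothing even when JAX is not installed. Only a call to `train` needs JAX, and `is_available("jax")` lets a caller check for it up front.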
- class nirs4all.controllers.models.jax_model.JaxModelController
Bases: BaseModelController
Controller for JAX/Flax models.
Uses lazy loading pattern - JAX is only imported when training or prediction is actually performed.
- execute(step_info: ParsedStep, dataset: SpectroDataset, context: ExecutionContext, runtime_context: RuntimeContext, source: int = -1, mode: str = 'train', loaded_binaries: List[Tuple[str, bytes]] | None = None, prediction_store: Predictions = None) → Tuple[ExecutionContext, List[Tuple[str, bytes]]]
Execute the JAX model controller.
- get_preferred_layout() → str
Return the preferred data layout for JAX models.
Flax Dense layers expect inputs of shape (batch, features), while Flax Conv layers expect channels-last inputs of shape (batch, length, features), i.e. (N, L, C). This is why ‘3d_transpose’ is the suitable layout for 1D convolutional models.
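The shape requirements above can be illustrated with NumPy. This is a sketch under an explicit assumption: it treats ‘3d_transpose’ as meaning that 3D data stored as (N, C, L) (samples, sources/channels, wavelengths) is transposed to channels-last (N, L, C). 2D spectra of shape (N, L) receive a singleton channel axis. The helper names `to_dense_layout` and `to_conv1d_layout` are hypothetical and are not part of nirs4all.

```python
import numpy as np


def to_dense_layout(x: np.ndarray) -> np.ndarray:
    """Flatten each sample to (batch, features) for Flax Dense layers."""
    return x.reshape(x.shape[0], -1)


def to_conv1d_layout(x: np.ndarray) -> np.ndarray:
    """Return channels-last (N, L, C) input for a 1D Flax Conv layer.

    Assumption: 3D data arrives as (N, C, L) and is transposed to (N, L, C);
    2D spectra (N, L) get a singleton channel axis, giving (N, L, 1).
    """
    if x.ndim == 2:
        return x[:, :, np.newaxis]
    if x.ndim == 3:
        return np.transpose(x, (0, 2, 1))
    raise ValueError(f"expected 2D or 3D input, got shape {x.shape}")


spectra = np.zeros((32, 200))      # 32 samples, 200 wavelengths
multi = np.zeros((32, 3, 200))     # 32 samples, 3 sources, 200 wavelengths
print(to_conv1d_layout(spectra).shape)  # (32, 200, 1)
print(to_conv1d_layout(multi).shape)    # (32, 200, 3)
print(to_dense_layout(multi).shape)     # (32, 600)
```

In the (N, L, C) layout, a `flax.linen.Conv` with a one-element `kernel_size` such as `(5,)` slides along the wavelength axis and treats the last axis as input channels.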