nirs4all.controllers.models.jax_wrapper module

JAX Model Wrapper - Wrapper for Flax models to support pickling and prediction.

class nirs4all.controllers.models.jax_wrapper.JaxModelWrapper(model, state)

Bases: object

Wrapper that holds a Flax model definition together with its trained state.
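Flax (linen) modules are stateless. The module object carries only the architecture and hyperparameters, and the trained weights live in a separate variables pytree. A wrapper like this one pairs the two so they can be pickled as a single object. The sketch below illustrates that idea under stated assumptions. `TinyLinear` and `HypotheticalWrapper` are hypothetical names, not nirs4all code, and a plain dataclass stands in for a real Flax module so the example runs without JAX installed.

```python
import pickle
from dataclasses import dataclass


@dataclass
class TinyLinear:
    """Hypothetical stand-in for a Flax module definition (config only, no weights)."""
    features: int = 1


class HypotheticalWrapper:
    """Sketch only, not the nirs4all implementation: bundles model + trained state."""

    def __init__(self, model, state):
        self.model = model  # architecture / hyperparameters
        self.state = state  # trained parameters, assumed to be a variables dict


wrapper = HypotheticalWrapper(
    TinyLinear(features=1),
    {"params": {"w": [0.5, -1.0], "b": 0.1}},
)

# Round-trip through pickle: definition and weights travel together.
restored = pickle.loads(pickle.dumps(wrapper))
print(restored.model, restored.state["params"]["b"])  # TinyLinear(features=1) 0.1
```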

predict(X)
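`predict` is listed without a docstring. Here is a minimal sketch of how such a method could work. It assumes that `model` follows the Flax linen `Module.apply(variables, x)` convention and that `state` is a variables dict such as `{"params": ...}`. Both are assumptions, because the actual type of `state` is not documented here. `TinyLinear` and `HypotheticalWrapper` are hypothetical. `TinyLinear` is a NumPy stand-in with the same `apply` signature, so the example runs without JAX or Flax.

```python
import numpy as np


class TinyLinear:
    """Hypothetical stand-in exposing a Flax-style apply(variables, x)."""

    def apply(self, variables, x):
        p = variables["params"]
        return x @ p["kernel"] + p["bias"]


class HypotheticalWrapper:
    """Sketch only, not the nirs4all implementation."""

    def __init__(self, model, state):
        self.model = model
        self.state = state

    def predict(self, X):
        # Cast input to float32 (JAX's default float dtype).
        X = np.asarray(X, dtype=np.float32)
        # Forward pass using the stored definition and trained variables.
        y = self.model.apply(self.state, X)
        # Return a plain NumPy array so callers need not handle JAX arrays.
        return np.asarray(y)


state = {"params": {"kernel": np.array([[2.0], [1.0]], dtype=np.float32),
                    "bias": np.array([0.5], dtype=np.float32)}}
wrapper = HypotheticalWrapper(TinyLinear(), state)
preds = wrapper.predict([[1.0, 0.0], [0.0, 3.0]])
print(preds.ravel())  # [2.5 3.5]
```

If `state` were instead a `flax.training.train_state.TrainState`, the forward pass would more likely be `state.apply_fn({"params": state.params}, X)`.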