Source code for nirs4all.controllers.models.jax_wrapper

"""
JAX Model Wrapper - Wrapper for Flax models to support pickling and prediction.
"""
import numpy as np
from typing import Any, Dict

class JaxModelWrapper:
    """Hold a Flax model definition together with its trained state.

    The wrapper keeps both objects as plain attributes and defines explicit
    ``__getstate__``/``__setstate__`` hooks so that pickling stores exactly
    the ``model`` and ``state`` pair.

    Attributes:
        model: The Flax model definition. It is stored but not used by
            ``predict``, which goes through ``state.apply_fn``.
        state: The trained state. ``predict`` reads ``state.params``,
            ``state.apply_fn`` and, when present, ``state.batch_stats``.
    """

    def __init__(self, model: Any, state: Any) -> None:
        """Store the model definition and its trained state.

        Args:
            model: The Flax model definition.
            state: The trained state object (see the class docstring for
                the attributes ``predict`` expects on it).
        """
        self.model = model
        self.state = state

    def predict(self, X: Any) -> np.ndarray:
        """Run the model in inference mode and return a NumPy array.

        Args:
            X: Input batch, forwarded unchanged to ``state.apply_fn``.

        Returns:
            The output of ``state.apply_fn(variables, X, train=False)``
            converted with ``np.array``.
        """
        variables: Dict[str, Any] = {'params': self.state.params}

        # Not every state object declares ``batch_stats`` (models without
        # batch-norm style collections have none), so a missing attribute
        # is treated exactly like ``None`` instead of raising
        # AttributeError.
        batch_stats = getattr(self.state, 'batch_stats', None)
        if batch_stats is not None:
            variables['batch_stats'] = batch_stats

        logits = self.state.apply_fn(variables, X, train=False)
        return np.array(logits)

    def __getstate__(self) -> Dict[str, Any]:
        """Return the picklable representation: the model and the state."""
        return {'model': self.model, 'state': self.state}

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore the wrapper from the mapping built by ``__getstate__``.

        Args:
            state: Mapping with the keys ``'model'`` and ``'state'``.
        """
        self.model = state['model']
        self.state = state['state']