nirs4all.controllers.models.tensorflow_model module
TensorFlow Model Controller - Controller for TensorFlow/Keras models
This controller handles TensorFlow/Keras models with support for:
  - Training on 2D/3D data with proper tensor formatting
  - Model compilation with loss functions and metrics
  - Early stopping and callbacks support
  - Integration with Optuna for hyperparameter tuning
  - Model persistence and prediction storage
Matches TensorFlow/Keras model objects and model configurations.
Lazy loading pattern: TensorFlow is only imported when actually needed for training or prediction, not at module import time.
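The lazy loading pattern can be sketched with the standard library alone. This is a minimal illustration, not the nirs4all implementation: the helper names `tensorflow_available` and `load_tensorflow` are hypothetical. `importlib.util.find_spec` checks whether the package can be found without executing it, and the real import happens only when a caller asks for the module.

```python
import importlib
import importlib.util
from functools import lru_cache
from types import ModuleType


@lru_cache(maxsize=None)
def tensorflow_available() -> bool:
    # find_spec locates a top-level package on sys.path without importing it.
    return importlib.util.find_spec("tensorflow") is not None


def load_tensorflow() -> ModuleType:
    # Import deferred to call time; Python caches the module in sys.modules,
    # so only the first call pays the (large) import cost.
    if not tensorflow_available():
        raise ImportError("TensorFlow is not installed; run `pip install tensorflow`.")
    return importlib.import_module("tensorflow")
```

Because the availability check is cached and import-free, a controller can call it cheaply on every step it inspects, which is what allows a matcher to return early when TensorFlow is missing.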
- class nirs4all.controllers.models.tensorflow_model.TensorFlowModelController
Bases: BaseModelController
Controller for TensorFlow/Keras models.
This controller manages the complete lifecycle of TensorFlow/Keras models, including:
  - Model instantiation from various configuration formats
  - Data preparation with proper tensor formatting (2D/3D)
  - Model compilation with task-appropriate loss functions and metrics
  - Training with callbacks (early stopping, model checkpointing)
  - Hyperparameter tuning via Optuna integration
  - Model evaluation and prediction
  - Binary serialization for model persistence
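"Task-appropriate loss functions and metrics" can be illustrated as a lookup from task type to Keras `compile()` arguments. This is a hedged sketch: the task-type keys, the helper `default_compile_kwargs`, and the specific defaults are assumptions, not nirs4all's actual choices. Only the loss and metric strings themselves ("mse", "mae", "binary_crossentropy", "sparse_categorical_crossentropy", "accuracy") are standard Keras identifiers.

```python
from typing import Dict, List, Tuple

# Hypothetical defaults; the controller's real mapping may differ.
_TASK_DEFAULTS: Dict[str, Tuple[str, List[str]]] = {
    "regression": ("mse", ["mae"]),
    "binary_classification": ("binary_crossentropy", ["accuracy"]),
    "multiclass_classification": ("sparse_categorical_crossentropy", ["accuracy"]),
}


def default_compile_kwargs(task_type: str) -> Dict[str, object]:
    """Return keyword arguments suitable for keras.Model.compile()."""
    try:
        loss, metrics = _TASK_DEFAULTS[task_type]
    except KeyError:
        raise ValueError(f"unknown task type: {task_type!r}") from None
    # Copy the metrics list so callers cannot mutate the shared defaults.
    return {"loss": loss, "metrics": list(metrics)}
```

Usage would look like `model.compile(optimizer="adam", **default_compile_kwargs("regression"))`, with user-supplied compile parameters overriding these defaults.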
The controller automatically detects TensorFlow models and functions decorated with @framework('tensorflow'). It uses lazy loading to avoid importing TensorFlow until it is actually needed.
- execute(step_info: ParsedStep, dataset: SpectroDataset, context: ExecutionContext, runtime_context: RuntimeContext, source: int = -1, mode: str = 'train', loaded_binaries: List[Tuple[str, bytes]] | None = None, prediction_store: Predictions = None) → Tuple[ExecutionContext, List[Tuple[str, bytes]]]
Execute TensorFlow model training, finetuning, or prediction.
Sets the preferred data layout to ‘3d_transpose’ for TensorFlow Conv1D models, then delegates to the base class execute method.
- Parameters:
step_info – Parsed step containing model configuration and operator.
dataset – SpectroDataset with features, targets, and fold information.
context – Execution context with step_id, processing history, partition info.
runtime_context – Runtime context managing execution state.
source – Data source index (default: -1 for primary source).
mode – Execution mode: ‘train’, ‘finetune’, ‘predict’, or ‘explain’.
loaded_binaries – Optional list of (name, bytes) tuples for prediction mode, containing serialized model and preprocessing artifacts.
prediction_store – External Predictions storage instance for managing prediction results across pipeline steps.
- Returns:
Tuple of (updated_context, artifact_metadata), where:
  - updated_context: execution context updated with the model information
  - artifact_metadata: list of (name, bytes) tuples holding the serialized binary artifacts for persistence
- Return type:
Tuple[ExecutionContext, List[Tuple[str, bytes]]]
- Raises:
ImportError – If TensorFlow is not installed.
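In prediction mode, loaded_binaries arrives as a list of (name, bytes) tuples. The sketch below shows one way to retrieve an artifact from that structure. The helper `get_binary` and the artifact names in the usage line are hypothetical; how nirs4all actually names and resolves its artifacts is not specified here.

```python
from typing import List, Optional, Tuple


def get_binary(loaded_binaries: Optional[List[Tuple[str, bytes]]], name: str) -> bytes:
    """Return the payload stored under `name` (hypothetical helper)."""
    if not loaded_binaries:
        # Prediction without previously saved artifacts cannot proceed.
        raise ValueError("prediction mode requires loaded_binaries")
    for artifact_name, payload in loaded_binaries:
        if artifact_name == name:
            return payload
    raise KeyError(f"no artifact named {name!r}")


binaries = [("model", b"\x00serialized-model"), ("scaler", b"\x01scaler")]
model_bytes = get_binary(binaries, "model")
```

A linear scan is adequate here because a step typically carries only a handful of artifacts; building a dict would be the natural change if names were guaranteed unique and lookups frequent.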
- get_preferred_layout() → str
Return the preferred data layout for TensorFlow models.
TensorFlow Conv1D expects a per-sample input shape of (features, channels), where:
  - features = number of wavelengths/spectral points (the timesteps the convolution slides over)
  - channels = number of preprocessing methods
The ‘3d_transpose’ layout returns (samples, features, processings), which matches the (batch, steps, channels) input that Keras Conv1D expects with its default channels_last data format.
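The effect of the ‘3d_transpose’ layout can be shown with NumPy. This sketch assumes the untransposed 3D data is stored as (samples, processings, features); that source ordering and the helper name `to_conv1d_layout` are assumptions inferred from the layout's name, not taken from nirs4all.

```python
import numpy as np


def to_conv1d_layout(x: np.ndarray) -> np.ndarray:
    """Reorder (samples, processings, features) -> (samples, features, processings).

    Assumption: the source array is stored as (samples, processings, features).
    """
    if x.ndim != 3:
        raise ValueError(f"expected a 3D array, got shape {x.shape}")
    # Swap the last two axes so spectral points become Conv1D timesteps
    # and preprocessing methods become channels.
    return np.transpose(x, (0, 2, 1))


# 4 spectra, 3 preprocessing variants, 200 wavelengths each.
x = np.random.default_rng(0).normal(size=(4, 3, 200))
x_conv = to_conv1d_layout(x)
print(x_conv.shape)  # (4, 200, 3)
```

With this layout, a Keras model would declare its per-sample input shape as (200, 3), so each convolution kernel spans neighbouring wavelengths across all preprocessing channels at once.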
- classmethod matches(step: Any, operator: Any, keyword: str) → bool
Determine if this controller should handle the given step.
Matches TensorFlow/Keras models, functions decorated with @framework('tensorflow'), and serialized model configurations containing TensorFlow components.
- Parameters:
step – Pipeline step configuration (dict, model instance, or function).
operator – Optional operator instance extracted from step.
keyword – Optional keyword identifier for the step.
- Returns:
True if this controller should handle the step, False otherwise. Returns False immediately if TensorFlow is not installed.
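A matching heuristic of this kind can be sketched as follows. Everything here is illustrative: the stand-in `framework` decorator assumes the real one tags functions with a `framework` attribute, the `"model"` key for dict configurations is an assumption, and module-name inspection is one plausible way to recognise Keras objects, not necessarily the one the controller uses.

```python
import importlib.util
from typing import Any, Callable, Optional


def framework(name: str) -> Callable:
    """Hypothetical stand-in for the @framework decorator."""
    def decorate(fn):
        fn.framework = name  # assumption: the decorator tags the function
        return fn
    return decorate


def looks_like_tensorflow(step: Any, tf_available: Optional[bool] = None) -> bool:
    if tf_available is None:
        tf_available = importlib.util.find_spec("tensorflow") is not None
    if not tf_available:
        return False  # mirrors the documented early exit
    # Assumption: dict configurations keep the model under a "model" key.
    obj = step.get("model", step) if isinstance(step, dict) else step
    if getattr(obj, "framework", None) == "tensorflow":
        return True
    # Keras/TensorFlow model classes live in these top-level packages.
    module = type(obj).__module__ or ""
    return module.split(".")[0] in {"tensorflow", "keras", "tf_keras"}
```

The `tf_available` parameter exists only so the logic can be exercised without TensorFlow installed; by default the check runs the same import-free probe used for lazy loading.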
- process_hyperparameters(params: Dict[str, Any]) → Dict[str, Any]
Process hyperparameters for TensorFlow model tuning.
Supports TensorFlow-specific parameter organization:
  - Parameters prefixed with ‘compile_’ are grouped under the ‘compile’ key (e.g., ‘compile_learning_rate’ → compile[‘learning_rate’])
  - Parameters prefixed with ‘fit_’ are grouped under the ‘fit’ key (e.g., ‘fit_batch_size’ → fit[‘batch_size’])
  - Other parameters are treated as model architecture parameters
- Parameters:
params – Dictionary of sampled parameters.
- Returns:
Dictionary of processed hyperparameters with proper nesting for TensorFlow compilation and fitting.
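The prefix-grouping rule can be sketched in a few lines. This is an illustration under assumptions: the function name `group_tf_hyperparameters` is hypothetical, and leaving architecture parameters at the top level (rather than under a dedicated key) is a guess about the output shape.

```python
from typing import Any, Dict


def group_tf_hyperparameters(params: Dict[str, Any]) -> Dict[str, Any]:
    """Nest 'compile_*' and 'fit_*' entries; keep the rest as architecture params."""
    processed: Dict[str, Any] = {}
    compile_params: Dict[str, Any] = {}
    fit_params: Dict[str, Any] = {}
    for name, value in params.items():
        if name.startswith("compile_"):
            compile_params[name[len("compile_"):]] = value
        elif name.startswith("fit_"):
            fit_params[name[len("fit_"):]] = value
        else:
            # Assumption: architecture parameters stay at the top level.
            processed[name] = value
    if compile_params:
        processed["compile"] = compile_params
    if fit_params:
        processed["fit"] = fit_params
    return processed


# e.g. values an Optuna trial might have sampled
sampled = {"compile_learning_rate": 1e-3, "fit_batch_size": 32, "filters": 16}
grouped = group_tf_hyperparameters(sampled)
```

Keeping the prefixes flat during sampling lets a tuner such as Optuna treat every parameter uniformly; the nesting is applied only when building the model, compiling it, and calling fit.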