nirs4all.controllers.models.tensorflow_model module

TensorFlow Model Controller - Controller for TensorFlow/Keras models

This controller handles TensorFlow/Keras models with support for:

  • Training on 2D/3D data with proper tensor formatting

  • Model compilation with loss functions and metrics

  • Early stopping and callbacks support

  • Integration with Optuna for hyperparameter tuning

  • Model persistence and prediction storage

Matches TensorFlow/Keras model objects and model configurations.

Lazy loading pattern: TensorFlow is only imported when actually needed for training or prediction, not at module import time.

class nirs4all.controllers.models.tensorflow_model.TensorFlowModelController

Bases: BaseModelController

Controller for TensorFlow/Keras models.

This controller manages the complete lifecycle of TensorFlow/Keras models, including:

  • Model instantiation from various configuration formats

  • Data preparation with proper tensor formatting (2D/3D)

  • Model compilation with task-appropriate loss functions and metrics

  • Training with callbacks (early stopping, model checkpointing)

  • Hyperparameter tuning via Optuna integration

  • Model evaluation and prediction

  • Binary serialization for model persistence

The controller automatically detects TensorFlow models and functions decorated with @framework('tensorflow'). It uses lazy loading to avoid importing TensorFlow until it is actually needed.
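A sketch of how a `@framework('tensorflow')` decorator could tag a model-building function so that a controller can recognize it without calling it. It assumes the decorator stores the framework name as a function attribute. The attribute name `framework`, the helper `is_tensorflow_function`, and the factory `build_cnn` are hypothetical; the actual nirs4all decorator may store this information differently.

```python
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def framework(name: str) -> Callable[[F], F]:
    """Hypothetical decorator that tags a model factory with its framework name."""
    def decorator(func: F) -> F:
        setattr(func, "framework", name)  # assumption: stored as a plain attribute
        return func
    return decorator


def is_tensorflow_function(obj: Any) -> bool:
    """Return True if obj is a callable tagged with framework('tensorflow')."""
    return callable(obj) and getattr(obj, "framework", None) == "tensorflow"


@framework("tensorflow")
def build_cnn(input_shape, params=None):
    # A real factory would import Keras lazily here and return a model.
    raise NotImplementedError
```

Because the check only reads an attribute, detection never executes the factory and therefore never imports TensorFlow.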

priority

Controller priority for matching (4).

Type:

int

execute(step_info: ParsedStep, dataset: SpectroDataset, context: ExecutionContext, runtime_context: RuntimeContext, source: int = -1, mode: str = 'train', loaded_binaries: List[Tuple[str, bytes]] | None = None, prediction_store: Predictions = None) → Tuple[ExecutionContext, List[Tuple[str, bytes]]]

Execute TensorFlow model training, finetuning, or prediction.

Sets the preferred data layout to '3d_transpose' for TensorFlow Conv1D models, then delegates to the base class execute method.

Parameters:
  • step_info – Parsed step containing model configuration and operator.

  • dataset – SpectroDataset with features, targets, and fold information.

  • context – Execution context with step_id, processing history, partition info.

  • runtime_context – Runtime context managing execution state.

  • source – Data source index (default: -1 for primary source).

  • mode – Execution mode: 'train', 'finetune', 'predict', or 'explain'.

  • loaded_binaries – Optional list of (name, bytes) tuples for prediction mode, containing serialized model and preprocessing artifacts.

  • prediction_store – External Predictions storage instance for managing prediction results across pipeline steps.

Returns:

Tuple of (updated_context, list_of_artifact_metadata), where:

  • updated_context: Updated execution context with added model information

  • artifact_metadata: List of serialized binary artifacts for persistence

Return type:

Tuple[ExecutionContext, List[Tuple[str, bytes]]]

Raises:

ImportError – If TensorFlow is not installed.
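The "set the layout, then delegate" behavior of execute can be sketched as a subclass that changes only the requested layout and defers everything else to its parent. `BaseModelController` below is a minimal stand-in written for this example, plain dicts replace `ExecutionContext`, and passing the layout through a `"layout"` context key is an assumption; the real base class performs training, finetuning, prediction, or explanation where the stub only echoes the layout it received.

```python
from typing import Any, Dict, List, Optional, Tuple

Artifacts = List[Tuple[str, bytes]]


class BaseModelController:
    """Hypothetical stand-in for nirs4all's BaseModelController."""

    def execute(self, step_info: Any, dataset: Any, context: Dict[str, Any],
                runtime_context: Any, source: int = -1, mode: str = "train",
                loaded_binaries: Optional[Artifacts] = None,
                prediction_store: Any = None) -> Tuple[Dict[str, Any], Artifacts]:
        # The real base class would load data in the requested layout and run the model;
        # this stub only records which layout it was asked to use.
        layout = context.get("layout", "2d")  # assumed default for non-TF models
        return {**context, "mode": mode, "layout_used": layout}, []


class TensorFlowModelController(BaseModelController):
    def get_preferred_layout(self) -> str:
        return "3d_transpose"

    def execute(self, step_info, dataset, context, runtime_context, source=-1,
                mode="train", loaded_binaries=None, prediction_store=None):
        # Step 1: request the Conv1D-friendly layout without mutating the caller's context.
        context = {**context, "layout": self.get_preferred_layout()}
        # Step 2: delegate everything else to the base class.
        return super().execute(step_info, dataset, context, runtime_context,
                               source, mode, loaded_binaries, prediction_store)
```

Keeping the override this thin means framework-specific behavior lives in small hooks such as get_preferred_layout, while the shared train/predict logic stays in one place.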

get_preferred_layout() → str

Return the preferred data layout for TensorFlow models.

TensorFlow Conv1D expects per-sample input of shape (features, channels), where:

  • features = number of wavelengths/spectral points (the timesteps the convolution slides over)

  • channels = number of preprocessing methods

The '3d_transpose' layout returns arrays of shape (samples, features, processings), which matches the channels-last (batch, steps, channels) input that Conv1D expects.
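To make the shapes concrete, the sketch below assumes the dataset holds a 3D block ordered (samples, processings, features) and shows that swapping the last two axes yields the (samples, features, processings) array described above. The source ordering and the array sizes are assumptions for illustration; only the target shape comes from this documentation.

```python
import numpy as np

n_samples, n_processings, n_features = 32, 3, 200

# Assumed storage order: one spectrum per preprocessing method for each sample.
X_3d = np.random.default_rng(0).normal(size=(n_samples, n_processings, n_features))

# '3d_transpose': move processings to the last axis -> (samples, features, processings),
# i.e. (batch, steps, channels) for a channels-last Conv1D.
X_tf = np.transpose(X_3d, (0, 2, 1))

print(X_tf.shape)  # (32, 200, 3)
```

In Keras, such an array would feed a model whose input is declared per sample, for example `keras.Input(shape=(200, 3))` followed by a `keras.layers.Conv1D` layer.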

classmethod matches(step: Any, operator: Any, keyword: str) → bool

Determine if this controller should handle the given step.

Matches TensorFlow/Keras models, functions decorated with @framework('tensorflow'), and serialized model configurations containing TensorFlow components.

Parameters:
  • step – Pipeline step configuration (dict, model instance, or function).

  • operator – Optional operator instance extracted from step.

  • keyword – Optional keyword identifier for the step.

Returns:

True if this controller should handle the step, False otherwise. Returns False immediately if TensorFlow is not installed.

process_hyperparameters(params: Dict[str, Any]) → Dict[str, Any]

Process hyperparameters for TensorFlow model tuning.

Supports TensorFlow-specific parameter organization:

  • Parameters prefixed with 'compile_' are grouped under the 'compile' key (e.g., 'compile_learning_rate' → compile['learning_rate'])

  • Parameters prefixed with 'fit_' are grouped under the 'fit' key (e.g., 'fit_batch_size' → fit['batch_size'])

  • Other parameters are treated as model architecture parameters

Parameters:

params – Dictionary of sampled parameters.

Returns:

Dictionary of processed hyperparameters with proper nesting for TensorFlow compilation and fitting.
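A minimal sketch of the prefix grouping described above. It follows the documented rules ('compile_' → 'compile', 'fit_' → 'fit', everything else kept as architecture parameters). Keeping architecture parameters at the top level of the result and omitting empty groups are assumptions, and the sampled values are illustrative only.

```python
from typing import Any, Dict


def process_hyperparameters(params: Dict[str, Any]) -> Dict[str, Any]:
    """Nest compile_/fit_ prefixed parameters; keep the rest as model parameters."""
    processed: Dict[str, Any] = {}
    compile_params: Dict[str, Any] = {}
    fit_params: Dict[str, Any] = {}

    for name, value in params.items():
        if name.startswith("compile_"):
            compile_params[name[len("compile_"):]] = value
        elif name.startswith("fit_"):
            fit_params[name[len("fit_"):]] = value
        else:
            processed[name] = value  # architecture parameter, e.g. number of filters

    # Assumption: only add the nested groups when they contain something.
    if compile_params:
        processed["compile"] = compile_params
    if fit_params:
        processed["fit"] = fit_params
    return processed


sampled = {"filters": 16, "compile_learning_rate": 1e-3, "fit_batch_size": 32, "fit_epochs": 50}
print(process_hyperparameters(sampled))
# {'filters': 16, 'compile': {'learning_rate': 0.001}, 'fit': {'batch_size': 32, 'epochs': 50}}
```

Grouping by prefix lets a flat Optuna search space map cleanly onto the separate architecture, `compile()`, and `fit()` stages of a Keras model.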