"""
TensorFlow Model Controller - Controller for TensorFlow/Keras models
This controller handles TensorFlow/Keras models with support for:
- Training on 2D/3D data with proper tensor formatting
- Model compilation with loss functions and metrics
- Early stopping and callbacks support
- Integration with Optuna for hyperparameter tuning
- Model persistence and prediction storage
Matches TensorFlow/Keras model objects and model configurations.
Lazy loading pattern: TensorFlow is only imported when actually needed
for training or prediction, not at module import time.
"""
from __future__ import annotations
from typing import Any, Dict, List, Tuple, Optional, TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from nirs4all.data.predictions import Predictions
from nirs4all.pipeline.config.context import ExecutionContext
from nirs4all.pipeline.steps.parser import ParsedStep
from nirs4all.pipeline.runner import PipelineRunner
from nirs4all.data.dataset import SpectroDataset
try:
from tensorflow import keras
except ImportError:
pass
from ..models.base_model import BaseModelController
from nirs4all.controllers.registry import register_controller
from nirs4all.core.logging import get_logger
from nirs4all.core.task_type import TaskType
from nirs4all.utils.backend import is_available, require_backend, is_gpu_available
logger = get_logger(__name__)
# Fast availability check at module level - no imports
TENSORFLOW_AVAILABLE = is_available('tensorflow')
# Lazy-loaded module cache
_tf_modules: Dict[str, Any] = {}
def _get_tf():
"""Lazy load TensorFlow with caching."""
if 'tf' not in _tf_modules:
require_backend('tensorflow', feature='TensorFlow neural networks')
import tensorflow as tf
_tf_modules['tf'] = tf
_tf_modules['keras'] = tf.keras
return _tf_modules['tf']
def _get_keras():
"""Lazy load Keras with caching."""
if 'keras' not in _tf_modules:
_get_tf() # This will populate the cache
return _tf_modules['keras']
[docs]
@register_controller
class TensorFlowModelController(BaseModelController):
"""Controller for TensorFlow/Keras models.
This controller manages the complete lifecycle of TensorFlow/Keras models including:
- Model instantiation from various configuration formats
- Data preparation with proper tensor formatting (2D/3D)
- Model compilation with task-appropriate loss functions and metrics
- Training with callbacks (early stopping, model checkpointing)
- Hyperparameter tuning via Optuna integration
- Model evaluation and prediction
- Binary serialization for model persistence
The controller automatically detects TensorFlow models and functions decorated
with @framework('tensorflow'). It uses lazy loading to avoid importing
TensorFlow until actually needed.
Attributes:
priority (int): Controller priority for matching (4).
"""
priority = 4 # Higher priority than Sklearn (6)
[docs]
@classmethod
def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
"""Determine if this controller should handle the given step.
Matches TensorFlow/Keras models, functions decorated with @framework('tensorflow'),
and serialized model configurations containing TensorFlow components.
Args:
step: Pipeline step configuration (dict, model instance, or function).
operator: Optional operator instance extracted from step.
keyword: Optional keyword identifier for the step.
Returns:
True if this controller should handle the step, False otherwise.
Returns False immediately if TensorFlow is not installed.
"""
if not TENSORFLOW_AVAILABLE:
return False
# Check if step contains a TensorFlow model or function
if isinstance(step, dict) and 'model' in step:
model = step['model']
if cls._is_tensorflow_model_or_function(model):
return True
# Handle dictionary config for model
if isinstance(model, dict) and 'class' in model:
class_name = model['class']
if isinstance(class_name, str) and ('tensorflow' in class_name or 'keras' in class_name):
return True
# Check direct TensorFlow objects or functions
if cls._is_tensorflow_model_or_function(step):
return True
# Check operator if provided
if operator is not None and cls._is_tensorflow_model_or_function(operator):
return True
return False
@classmethod
def _is_tensorflow_model(cls, obj: Any) -> bool:
"""Check if object is a TensorFlow/Keras model instance.
Uses module introspection first to avoid importing TensorFlow
for non-TensorFlow objects.
Args:
obj: Object to check.
Returns:
True if object is a keras.Model, keras.Sequential, or has fit/predict/compile
methods characteristic of Keras models. False otherwise or if TensorFlow
is not available.
"""
if not TENSORFLOW_AVAILABLE:
return False
if obj is None:
return False
# Quick check via module name first (no import needed)
module = getattr(type(obj), '__module__', '')
if 'tensorflow' not in module and 'keras' not in module:
# Also check for fit/predict/compile which are keras signatures
if not (hasattr(obj, 'fit') and hasattr(obj, 'predict') and hasattr(obj, 'compile')):
return False
try:
keras = _get_keras()
return (isinstance(obj, keras.Model) or
isinstance(obj, keras.Sequential))
except Exception:
return False
@classmethod
def _is_tensorflow_model_or_function(cls, obj: Any) -> bool:
"""Check if object is a TensorFlow model, function, or serialized configuration.
Recognizes:
- TensorFlow/Keras model instances
- Functions decorated with @framework('tensorflow')
- Serialized function dictionaries with 'function' key pointing to TensorFlow code
Args:
obj: Object to check (model, function, dict, or other).
Returns:
True if object is TensorFlow-related, False otherwise.
"""
if not TENSORFLOW_AVAILABLE:
return False
# Check if it's a TensorFlow model instance
if cls._is_tensorflow_model(obj):
return True
# Check if it's a function decorated with @framework('tensorflow')
if callable(obj) and hasattr(obj, 'framework'):
return obj.framework == 'tensorflow'
# Check if it's a serialized function dictionary
if isinstance(obj, dict):
# Check for type='function' format
if obj.get('type') == 'function' and obj.get('framework') == 'tensorflow':
return True
if 'function' in obj:
function_val = obj['function']
# Case 1: function is a callable object
if callable(function_val):
if hasattr(function_val, 'framework'):
return function_val.framework == 'tensorflow'
# Fallback: check module name
module_name = getattr(function_val, '__module__', '')
return 'tensorflow' in module_name or 'keras' in module_name
# Case 2: function is a string path
if isinstance(function_val, str):
function_path = function_val
# Try to import the function and check its framework
try:
mod_name, _, func_name = function_path.rpartition(".")
mod = __import__(mod_name, fromlist=[func_name])
func = getattr(mod, func_name)
return hasattr(func, 'framework') and func.framework == 'tensorflow'
except (ImportError, AttributeError, ValueError):
# If we can't import, check the path for tensorflow indicators
return 'tensorflow' in function_path.lower() or 'tf' in function_path.lower()
return False
def _get_model_instance(self, dataset: 'SpectroDataset', model_config: Dict[str, Any], force_params: Optional[Dict[str, Any]] = None) -> Any:
"""Create TensorFlow model instance from configuration.
Delegates to ModelFactory which handles various input formats:
- Model instances (returns as-is)
- Callables/functions (calls with input_dim, num_classes parameters)
- Serialized configs ({'function': path, 'params': {...}})
- File paths to saved models
The factory automatically extracts input dimensions from the dataset and
injects them as parameters when calling model factory functions.
Args:
dataset: SpectroDataset providing input dimensions and task type.
model_config: Model configuration (dict, instance, callable, or string path).
force_params: Optional parameters to override in the model configuration.
Returns:
Instantiated and compiled TensorFlow/Keras model ready for training.
Raises:
ImportError: If TensorFlow is not installed.
"""
require_backend('tensorflow', feature='TensorFlow models')
# Import factory here to avoid circular imports at module level
from .factory import ModelFactory
# Delegate entirely to ModelFactory
model = ModelFactory.build_single_model(
model_config,
dataset,
force_params or {}
)
return model
def _train_model(
self,
model: Any,
X_train: np.ndarray,
y_train: np.ndarray,
X_val: Optional[np.ndarray] = None,
y_val: Optional[np.ndarray] = None,
**kwargs
) -> Any:
"""Train TensorFlow/Keras model with comprehensive configuration.
Training pipeline:
1. Detects task type (regression/classification) from targets
2. Prepares compilation config (loss, optimizer, metrics)
3. Compiles the model
4. Prepares fit config (epochs, batch_size, callbacks, validation_split)
5. Trains the model with model.fit()
6. Logs training results if verbose > 1
Uses modular configuration components:
- TensorFlowCompilationConfig: Handles optimizer, loss, metrics
- TensorFlowFitConfig: Handles callbacks, validation, early stopping
Args:
model: Uncompiled or compiled TensorFlow/Keras model.
X_train: Training features as numpy array.
y_train: Training targets as numpy array.
X_val: Optional validation features.
y_val: Optional validation targets.
**kwargs: Training parameters including:
- epochs (int): Number of training epochs (default: 100)
- batch_size (int): Batch size for training (default: 32)
- patience (int): Early stopping patience (default: 10)
- verbose (int): Verbosity level (0-3)
- learning_rate (float): Optimizer learning rate
- loss (str): Loss function name
- optimizer (str): Optimizer name
- metrics (List[str]): Evaluation metrics
Returns:
Trained model with history attribute attached.
Raises:
ImportError: If TensorFlow is not installed.
"""
require_backend('tensorflow', feature='TensorFlow training')
# Import TensorFlow-specific utilities here (lazy)
from .tensorflow import (
TensorFlowCompilationConfig,
TensorFlowFitConfig,
)
train_params = kwargs
verbose = train_params.get('verbose', 0)
# Get task type from train_params (passed by base controller)
task_type = train_params.get('task_type')
if task_type is None:
raise ValueError("task_type must be provided in train_params")
if not is_gpu_available('tensorflow') and verbose > 0:
logger.warning("No GPU detected. Training TensorFlow model on CPU may be slow.")
# 1. Prepare compilation configuration
compile_config = TensorFlowCompilationConfig.prepare(train_params, task_type)
model.compile(**compile_config)
if verbose > 2:
print(f" Compilation config: {compile_config}")
# 2. Prepare fit configuration (includes callbacks)
fit_config = TensorFlowFitConfig.prepare(train_params, X_val, y_val, verbose)
# 3. Train the model
history = model.fit(
X_train, y_train,
**fit_config
)
# Store training history
model.history = history
# 4. Log training results
if verbose > 1:
self._log_training_results(model, X_train, y_train, X_val, y_val, task_type)
return model
def _log_training_results(
self,
model: Any,
X_train: np.ndarray,
y_train: np.ndarray,
X_val: Optional[np.ndarray],
y_val: Optional[np.ndarray],
task_type: TaskType
) -> None:
"""Log training and validation performance scores.
Args:
model: Trained TensorFlow model.
X_train: Training features.
y_train: Training targets.
X_val: Validation features (optional).
y_val: Validation targets (optional).
task_type: Task type enum for score calculation.
"""
print("\n Training completed - Evaluating performance:")
# Training scores
y_train_pred = self._predict_model(model, X_train)
train_scores = self._calculate_and_print_scores(
y_train, y_train_pred, task_type, partition="train",
model_name=model.__class__.__name__, show_detailed_scores=False
)
# Validation scores if available
if X_val is not None and y_val is not None:
y_val_pred = self._predict_model(model, X_val)
val_scores = self._calculate_and_print_scores(
y_val, y_val_pred, task_type, partition="validation",
model_name=model.__class__.__name__, show_detailed_scores=False
)
# Show comparison
primary_metric = 'accuracy' if task_type.is_classification else 'r2'
train_score = train_scores.get(primary_metric, 0)
val_score = val_scores.get(primary_metric, 0)
if task_type.is_classification:
print(f"\n Accuracy: Train={train_score:.4f} | Val={val_score:.4f}")
else:
print(f"\n R² Score: Train={train_score:.4f} | Val={val_score:.4f}")
# Warn about overfitting
if train_score - val_score > 0.15: # 15% difference
print(f" ⚠️ Warning: Possible overfitting detected (Train-Val diff: {train_score - val_score:.4f})")
def _predict_model(self, model: Any, X: np.ndarray) -> np.ndarray:
"""Generate predictions with TensorFlow model.
Handles output format normalization:
- Multiclass classification (2D with >1 columns): Converts softmax probabilities
to class indices via argmax
- Regression/binary (1D or 2D with 1 column): Ensures column vector shape
Args:
model: Trained TensorFlow/Keras model.
X: Input features as numpy array (will be prepared for TensorFlow format).
Returns:
Predictions as 2D numpy array (n_samples, 1) for regression/binary,
or class indices (0 to n_classes-1) for multiclass classification.
"""
# Prepare data to ensure correct shape for model
X_prepared, _ = self._prepare_data(X, None, {})
predictions = model.predict(X_prepared, verbose=0)
# For multiclass classification, convert probabilities to class indices
if predictions.ndim == 2 and predictions.shape[1] > 1:
# Multi-output: likely multiclass classification with softmax
# Convert probabilities to class predictions (encoded labels 0-N)
predictions = np.argmax(predictions, axis=1).reshape(-1, 1).astype(np.float32)
elif predictions.ndim == 1:
# Single output: reshape to column vector
predictions = predictions.reshape(-1, 1)
return predictions
def _predict_proba_model(self, model: Any, X: np.ndarray) -> Optional[np.ndarray]:
"""Get class probabilities from TensorFlow classification model.
Returns raw softmax/sigmoid outputs before argmax conversion.
For binary classification with single sigmoid output, converts to
2-column format [1-p, p].
Args:
model: Trained TensorFlow/Keras model.
X: Input features as numpy array.
Returns:
Class probabilities as (n_samples, n_classes) array,
or None for regression models.
"""
# Prepare data to ensure correct shape for model
X_prepared, _ = self._prepare_data(X, None, {})
predictions = model.predict(X_prepared, verbose=0)
# For regression (single output), return None
if predictions.ndim == 1:
predictions = predictions.reshape(-1, 1)
if predictions.shape[1] == 1:
# Binary classification with sigmoid: convert to 2-column format
# Or could be regression - check if values are in [0, 1] range
if np.all((predictions >= 0) & (predictions <= 1)):
return np.column_stack([1 - predictions, predictions])
else:
# Regression output, not probabilities
return None
# Multiclass: already has probability distribution
return predictions
def _prepare_data(
self,
X: np.ndarray,
y: Optional[np.ndarray],
context: Any
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""Prepare data for TensorFlow with proper tensor formatting.
Delegates to TensorFlowDataPreparation which handles:
- Conversion to float32
- 2D data: (samples, features) → (samples, features, 1) for Conv1D
- 3D data: Conditional transpose if shape[1] < shape[2]
(e.g., (samples, 3, 200) → (samples, 200, 3) for Conv1D)
- Target flattening for 2D targets with single column
Args:
X: Input features as numpy array (2D or 3D).
y: Optional target values as numpy array.
context: Execution context.
Returns:
Tuple of (prepared_X, prepared_y) in TensorFlow-compatible formats.
"""
# Import here to avoid loading TensorFlow at module import time
from .tensorflow import TensorFlowDataPreparation
return TensorFlowDataPreparation.prepare_data(X, y, context)
def _evaluate_model(self, model: Any, X_val: np.ndarray, y_val: np.ndarray) -> float:
"""Evaluate TensorFlow model on validation data.
Uses model.evaluate() to compute loss. Falls back to MSE calculation
from predictions if evaluation fails.
Args:
model: Trained TensorFlow/Keras model.
X_val: Validation features.
y_val: Validation targets.
Returns:
Loss value as float. Returns float('inf') if evaluation fails completely.
"""
try:
# Use model's evaluate method
loss = model.evaluate(X_val, y_val, verbose=0)
# If evaluate returns list (loss + metrics), take the loss
if isinstance(loss, list):
return loss[0]
else:
return loss
except (ValueError, TypeError, AttributeError) as e:
logger.warning(f"Error in TensorFlow model evaluation: {e}")
try:
# Fallback: use predictions and calculate MSE
y_pred = model.predict(X_val, verbose=0)
mse = np.mean((y_val - y_pred) ** 2)
return float(mse)
except (ValueError, TypeError, AttributeError):
return float('inf')
[docs]
def get_preferred_layout(self) -> str:
"""Return the preferred data layout for TensorFlow models.
TensorFlow Conv1D expects input shape (features, channels) where:
- features = number of wavelengths/spectral points (timesteps for convolution)
- channels = number of preprocessing methods
The '3d_transpose' layout returns (samples, features, processings) which is correct for Conv1D.
"""
return "3d_transpose"
def _clone_model(self, model: Any) -> Any:
"""Clone TensorFlow model using framework-specific cloning method.
Implements the abstract _clone_model from BaseModelController.
Cloning strategy:
- Callable functions: Return as-is (will be called with proper input_shape later)
- Keras model instances: Use keras.models.clone_model() to create fresh copy
- Other objects: Return as-is for ModelFactory to handle
Args:
model: Model instance, function, or configuration to clone.
Returns:
Cloned model (for instances) or original object (for functions/configs).
"""
if callable(model) and hasattr(model, 'framework') and model.framework == 'tensorflow':
# Don't clone functions - they will be called later with proper input shape
return model
if TENSORFLOW_AVAILABLE:
keras = _get_keras()
if isinstance(model, (keras.Model, keras.Sequential)):
# TensorFlow model instance: use clone_model
return keras.models.clone_model(model)
# Return as is (will be handled by ModelFactory)
return model
# Remove the _extract_model_config override - use base class implementation
# The base class correctly returns {'model_instance': operator, 'train_params': {...}}
# and ModelFactory now handles 'model_instance' key properly
[docs]
def process_hyperparameters(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""Process hyperparameters for TensorFlow model tuning.
Supports TensorFlow-specific parameter organization:
- Parameters prefixed with 'compile_' are grouped under 'compile' key
(e.g., 'compile_learning_rate' → compile['learning_rate'])
- Parameters prefixed with 'fit_' are grouped under 'fit' key
(e.g., 'fit_batch_size' → fit['batch_size'])
- Other parameters are treated as model architecture parameters
Args:
params: Dictionary of sampled parameters.
Returns:
Dictionary of processed hyperparameters with proper nesting for
TensorFlow compilation and fitting.
"""
tf_params = {}
for key, value in params.items():
if key.startswith('compile_'):
# Parameters for model compilation
compile_key = key.replace('compile_', '')
if 'compile' not in tf_params:
tf_params['compile'] = {}
tf_params['compile'][compile_key] = value
elif key.startswith('fit_'):
# Parameters for model fitting
fit_key = key.replace('fit_', '')
if 'fit' not in tf_params:
tf_params['fit'] = {}
tf_params['fit'][fit_key] = value
else:
# Model architecture parameters
tf_params[key] = value
return tf_params if tf_params else params
[docs]
def execute(
self,
step_info: 'ParsedStep',
dataset: 'SpectroDataset',
context: ExecutionContext,
runtime_context: 'RuntimeContext',
source: int = -1,
mode: str = "train",
loaded_binaries: Optional[List[Tuple[str, bytes]]] = None,
prediction_store: 'Predictions' = None
) -> Tuple[ExecutionContext, List[Tuple[str, bytes]]]:
"""Execute TensorFlow model training, finetuning, or prediction.
Sets the preferred data layout to '3d_transpose' for TensorFlow Conv1D models,
then delegates to the base class execute method.
Args:
step_info: Parsed step containing model configuration and operator.
dataset: SpectroDataset with features, targets, and fold information.
context: Execution context with step_id, processing history, partition info.
runtime_context: Runtime context managing execution state.
source: Data source index (default: -1 for primary source).
mode: Execution mode - 'train', 'finetune', 'predict', or 'explain'.
loaded_binaries: Optional list of (name, bytes) tuples for prediction mode,
containing serialized model and preprocessing artifacts.
prediction_store: External Predictions storage instance for managing
prediction results across pipeline steps.
Returns:
Tuple of (updated_context, list_of_artifact_metadata) where:
- updated_context: Context dict with added model information
- artifact_metadata: List of serialized binary artifacts for persistence
Raises:
ImportError: If TensorFlow is not installed.
"""
if not TENSORFLOW_AVAILABLE:
raise ImportError(
"TensorFlow is not available. Please install with: "
"pip install nirs4all[tensorflow] or pip install nirs4all[gpu]"
)
# Set layout preference for TensorFlow models (force_layout overrides preferred)
context = context.with_layout(self.get_effective_layout(step_info))
# Call parent execute method
return super().execute(step_info, dataset, context, runtime_context, source, mode, loaded_binaries, prediction_store)