"""
TensorFlow Configuration Management
This module provides classes for managing TensorFlow model configuration:
- Compilation configuration (optimizer, loss, metrics)
- Fit configuration (epochs, batch_size, validation)
- Callback factory for creating various callbacks
"""
from typing import Any, Dict, List, Optional, TYPE_CHECKING
import numpy as np
from nirs4all.core.task_type import TaskType
from ..utilities import ModelControllerUtils as ModelUtils
from nirs4all.utils.backend import is_available, require_backend
from nirs4all.core.logging import get_logger
logger = get_logger(__name__)
# Fast availability check at module level - no imports
TF_AVAILABLE = is_available('tensorflow')
if TYPE_CHECKING:
try:
import tensorflow as tf
from tensorflow import keras
except ImportError:
pass
[docs]
class TensorFlowCompilationConfig:
    """Manages TensorFlow model compilation configuration.

    All methods are static; the class only serves as a namespace for building
    the ``optimizer`` / ``loss`` / ``metrics`` dictionary from user-supplied
    training parameters.
    """

    @staticmethod
    def prepare(train_params: Dict[str, Any], task_type: TaskType) -> Dict[str, Any]:
        """Prepare compilation configuration from training parameters.

        Values are resolved with increasing precedence: built-in defaults,
        then the nested ``train_params['compile']`` mapping, then the flat
        keys ``optimizer``, ``loss``, ``metrics``, ``learning_rate`` and ``lr``
        found directly in ``train_params``.

        Args:
            train_params: Dictionary with training parameters (may include 'compile' key).
            task_type: TaskType enum indicating classification or regression.

        Returns:
            Dictionary with 'optimizer', 'loss', and 'metrics' keys. Any
            'learning_rate' / 'lr' entry is consumed by the optimizer setup
            and is not present in the result when TensorFlow is available.
        """
        require_backend('tensorflow', feature='TensorFlow compilation')

        # A 'compile' entry explicitly set to None is treated like a missing one.
        nested_params = train_params.get('compile') or {}

        # Flat parameters (for convenience); they override the nested block.
        flat_params = {
            key: train_params[key]
            for key in ('optimizer', 'loss', 'metrics', 'learning_rate', 'lr')
            if key in train_params
        }

        compile_config: Dict[str, Any] = {
            'optimizer': 'adam',
            'loss': 'mse',
            'metrics': ['mae'],
        }
        compile_config.update(nested_params)
        compile_config.update(flat_params)

        # Auto-configure loss and metrics from the task type unless the caller
        # set them explicitly. The check is per key: a 'compile' block that only
        # names an optimizer must not suppress the task-specific loss/metrics.
        if 'loss' not in nested_params and 'loss' not in flat_params:
            compile_config['loss'] = ModelUtils.get_default_loss(task_type, framework='tensorflow')
        if 'metrics' not in nested_params and 'metrics' not in flat_params:
            compile_config['metrics'] = ModelUtils.get_default_metrics(task_type, framework='tensorflow')

        # Turn a named optimizer + learning rate into an optimizer instance.
        return TensorFlowCompilationConfig._configure_optimizer(compile_config, train_params)

    @staticmethod
    def _configure_optimizer(
        compile_config: Dict[str, Any],
        train_params: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Configure optimizer with learning rate if provided.

        Args:
            compile_config: Configuration dictionary with 'optimizer' and possibly
                'learning_rate' (or its alias 'lr'). It is modified in place.
            train_params: Training parameters to check for cyclic_lr.

        Returns:
            The same ``compile_config`` dictionary. When the optimizer is given
            by name and either a learning rate was supplied or ``cyclic_lr`` is
            enabled, 'optimizer' is replaced by an optimizer instance.
        """
        if not TF_AVAILABLE:
            return compile_config

        # Remove both spellings unconditionally so that neither is left behind
        # in the returned configuration; 'learning_rate' wins when both are set.
        learning_rate = compile_config.pop('learning_rate', None)
        lr_alias = compile_config.pop('lr', None)
        if learning_rate is None:
            learning_rate = lr_alias

        optimizer = compile_config.get('optimizer', 'adam')
        if not isinstance(optimizer, str):
            # An optimizer object was supplied; it is used as-is.
            return compile_config

        is_cyclic = bool(train_params and train_params.get('cyclic_lr', False))
        if learning_rate is None and is_cyclic:
            # Cyclic LR is enabled but no initial LR was provided: still build
            # an explicit optimizer instance, starting from 0.001.
            learning_rate = 0.001

        if learning_rate is not None:
            compile_config['optimizer'] = TensorFlowCompilationConfig.create_optimizer(
                optimizer, learning_rate
            )
        return compile_config

    @staticmethod
    def create_optimizer(optimizer_name: str, learning_rate: float) -> 'keras.optimizers.Optimizer':
        """Create optimizer instance with learning rate.

        Args:
            optimizer_name: Name of optimizer, matched case-insensitively against
                'adam', 'sgd', 'rmsprop', 'adagrad', 'adadelta', 'adamax', 'nadam'.
            learning_rate: Learning rate value.

        Returns:
            Configured optimizer instance.

        Raises:
            ValueError: If ``optimizer_name`` is not one of the supported names.
        """
        require_backend('tensorflow', feature='TensorFlow optimizer')
        from tensorflow import keras

        optimizer_classes = {
            'adam': keras.optimizers.Adam,
            'sgd': keras.optimizers.SGD,
            'rmsprop': keras.optimizers.RMSprop,
            'adagrad': keras.optimizers.Adagrad,
            'adadelta': keras.optimizers.Adadelta,
            'adamax': keras.optimizers.Adamax,
            'nadam': keras.optimizers.Nadam,
        }
        optimizer_class = optimizer_classes.get(optimizer_name.lower())
        if optimizer_class is None:
            raise ValueError(f"Unknown optimizer: {optimizer_name}")
        return optimizer_class(learning_rate=learning_rate)
[docs]
class TensorFlowFitConfig:
    """Builds the keyword arguments used when fitting a TensorFlow model."""

    @staticmethod
    def prepare(
        train_params: Dict[str, Any],
        X_val: Optional[np.ndarray],
        y_val: Optional[np.ndarray],
        verbose: int = 0
    ) -> Dict[str, Any]:
        """Assemble the fit configuration, including validation and callbacks.

        Resolution order (later wins): built-in defaults, the nested
        ``train_params['fit']`` mapping, then the flat keys ``epochs``,
        ``batch_size``, ``validation_split`` and ``verbose`` of ``train_params``.

        Args:
            train_params: Dictionary with training parameters (may include 'fit' key).
            X_val: Validation features (optional).
            y_val: Validation targets (optional).
            verbose: Verbosity level forwarded to the callback factory.

        Returns:
            Dictionary with fit parameters including 'callbacks'. When both
            validation arrays are given, 'validation_data' is set and
            'validation_split' is removed.
        """
        settings: Dict[str, Any] = dict(
            epochs=100,
            batch_size=32,
            validation_split=0.2,
            verbose=1,
        )

        # Nested block first ...
        if 'fit' in train_params:
            settings.update(train_params['fit'])

        # ... then the flat convenience keys, which take precedence.
        for name in ('epochs', 'batch_size', 'validation_split', 'verbose'):
            if name in train_params:
                settings[name] = train_params[name]

        # Explicit validation arrays replace the split-based validation.
        missing_validation = X_val is None or y_val is None
        if not missing_validation:
            settings.pop('validation_split', None)
            settings['validation_data'] = (X_val, y_val)

        # Callbacks supplied through the nested block are handed to the factory.
        inherited_callbacks = settings.get('callbacks', [])
        settings['callbacks'] = TensorFlowCallbackFactory.create_callbacks(
            train_params, inherited_callbacks, verbose
        )
        return settings
[docs]
class TensorFlowCallbackFactory:
    """Factory for creating TensorFlow callbacks.

    All methods are static. ``create_callbacks`` is the entry point; the other
    methods each build one specific callback from the training parameters.
    """

    @staticmethod
    def create_callbacks(
        train_params: Dict[str, Any],
        existing_callbacks: List[Any],
        verbose: int = 0
    ) -> List[Any]:
        """Create comprehensive callback system.

        Callbacks are appended in this order: early stopping (only if none is
        already present and 'early_stopping' or 'patience' is set), cyclic LR
        ('cyclic_lr'), reduce-LR-on-plateau ('reduce_lr_on_plateau'), best
        model memory ('best_model_memory', enabled by default), and finally
        any 'custom_callbacks'.

        Args:
            train_params: Training parameters with callback configuration.
            existing_callbacks: List of existing callback instances.
            verbose: Verbosity level for logging.

        Returns:
            List of callback instances. When TensorFlow is not available,
            ``existing_callbacks`` is returned unchanged.
        """
        if not TF_AVAILABLE:
            return existing_callbacks

        from tensorflow import keras

        # Work on a copy so the caller's list is never modified.
        callbacks = list(existing_callbacks or [])

        # === EARLY STOPPING ===
        already_present = any(
            isinstance(cb, keras.callbacks.EarlyStopping) for cb in callbacks
        )
        if not already_present and (
            train_params.get('early_stopping') or train_params.get('patience')
        ):
            callbacks.append(
                TensorFlowCallbackFactory.create_early_stopping(train_params, verbose)
            )

        # === CYCLIC LEARNING RATE ===
        if train_params.get('cyclic_lr', False):
            callbacks.append(
                TensorFlowCallbackFactory.create_cyclic_lr(train_params, verbose)
            )

        # === REDUCE LR ON PLATEAU ===
        if train_params.get('reduce_lr_on_plateau', False):
            callbacks.append(
                TensorFlowCallbackFactory.create_reduce_lr_on_plateau(train_params, verbose)
            )

        # === BEST MODEL MEMORY ===
        if train_params.get('best_model_memory', True):
            callbacks.append(
                TensorFlowCallbackFactory.create_best_model_memory(verbose)
            )

        # === CUSTOM CALLBACKS ===
        if 'custom_callbacks' in train_params:
            custom_cbs = train_params['custom_callbacks']
            # A single callback is wrapped; lists and tuples are spread out.
            if not isinstance(custom_cbs, (list, tuple)):
                custom_cbs = [custom_cbs]
            callbacks.extend(custom_cbs)

        return callbacks

    @staticmethod
    def create_early_stopping(train_params: Dict[str, Any], verbose: int = 0) -> 'keras.callbacks.EarlyStopping':
        """Create early stopping callback.

        Defaults are ``monitor='val_loss'``, ``patience=10`` and
        ``restore_best_weights=True``; entries of ``train_params['early_stopping']``
        override them. A flat ``train_params['patience']`` is used when the
        nested mapping does not define 'patience'.

        Args:
            train_params: Training parameters with early_stopping config. A
                boolean 'early_stopping' value is accepted and means "use defaults".
            verbose: Verbosity level.

        Returns:
            EarlyStopping callback instance.
        """
        require_backend('tensorflow', feature='TensorFlow callbacks')
        from tensorflow import keras

        # Copy the user's mapping: adding 'patience' below must not leak back
        # into train_params.
        raw_params = train_params.get('early_stopping')
        if raw_params is None or isinstance(raw_params, bool):
            early_stopping_params = {}
        else:
            early_stopping_params = dict(raw_params)

        # Handle flat patience parameter
        if 'patience' in train_params and 'patience' not in early_stopping_params:
            early_stopping_params['patience'] = train_params['patience']

        es_config = {
            'monitor': 'val_loss',
            'patience': 10,
            'restore_best_weights': True,
            'verbose': 1 if verbose > 0 else 0
        }
        es_config.update(early_stopping_params)
        return keras.callbacks.EarlyStopping(**es_config)

    @staticmethod
    def create_cyclic_lr(train_params: Dict[str, Any], verbose: int = 0) -> 'keras.callbacks.Callback':
        """Create cyclic learning rate callback.

        The callback recomputes the learning rate after every batch using a
        triangular wave between ``base_lr`` and ``max_lr`` with half-period
        ``step_size`` batches.

        Args:
            train_params: Training parameters; settings are read from the
                'cyclic_lr_params' mapping ('base_lr', 'max_lr', 'step_size', 'mode').
            verbose: Verbosity level.

        Returns:
            Custom cyclic LR callback instance.
        """
        require_backend('tensorflow', feature='TensorFlow callbacks')
        from tensorflow import keras

        cyclic_lr_params = train_params.get('cyclic_lr_params') or {}

        class CyclicLR(keras.callbacks.Callback):
            """Triangular cyclic learning-rate schedule applied per batch."""

            def __init__(self, base_lr=0.001, max_lr=0.006, step_size=2000, mode='triangular', verbose=0):
                super().__init__()
                self.base_lr = base_lr
                self.max_lr = max_lr
                self.step_size = step_size
                # Stored for reference only: the update rule below is always
                # the triangular one, whatever value is given here.
                self.mode = mode
                self.verbose = verbose
                self.clr_iterations = 0
                self.history = {}
                self._warned = False

            def on_train_begin(self, logs=None):
                if self.verbose > 0:
                    logger.debug(f"Cyclic LR: base={self.base_lr}, max={self.max_lr}, step_size={self.step_size}")

            def on_batch_end(self, batch, logs=None):
                self.clr_iterations += 1
                cycle = np.floor(1 + self.clr_iterations / (2 * self.step_size))
                x = np.abs(self.clr_iterations / self.step_size - 2 * cycle + 1)
                lr = float(self.base_lr + (self.max_lr - self.base_lr) * max(0, (1 - x)))
                try:
                    lr_variable = self.model.optimizer.learning_rate
                    if hasattr(lr_variable, 'assign'):
                        # Preferred path: assign directly on the variable, which
                        # does not depend on keras.backend.set_value being
                        # available in the installed Keras version.
                        lr_variable.assign(lr)
                    else:
                        keras.backend.set_value(lr_variable, lr)
                except (AttributeError, TypeError, ValueError) as e:
                    # Best effort: training continues with the optimizer's own
                    # learning rate, but the failure is reported once so the
                    # schedule does not become a silent no-op.
                    if not self._warned:
                        self._warned = True
                        logger.warning(f"Could not set learning rate for CyclicLR: {e}")

        return CyclicLR(
            base_lr=cyclic_lr_params.get('base_lr', 0.001),
            max_lr=cyclic_lr_params.get('max_lr', 0.006),
            step_size=cyclic_lr_params.get('step_size', 2000),
            mode=cyclic_lr_params.get('mode', 'triangular'),
            verbose=verbose
        )

    @staticmethod
    def create_reduce_lr_on_plateau(train_params: Dict[str, Any], verbose: int = 0) -> 'keras.callbacks.ReduceLROnPlateau':
        """Create reduce LR on plateau callback.

        Defaults are ``monitor='val_loss'``, ``factor=0.5``, ``patience=5`` and
        ``min_lr=1e-7``; entries of ``train_params['reduce_lr_on_plateau_params']``
        override them.

        Args:
            train_params: Training parameters with reduce_lr_on_plateau config.
            verbose: Verbosity level.

        Returns:
            ReduceLROnPlateau callback instance.
        """
        require_backend('tensorflow', feature='TensorFlow callbacks')
        from tensorflow import keras

        reduce_lr_params = train_params.get('reduce_lr_on_plateau_params') or {}
        rlr_config = {
            'monitor': 'val_loss',
            'factor': 0.5,
            'patience': 5,
            'min_lr': 1e-7,
            'verbose': 1 if verbose > 0 else 0
        }
        rlr_config.update(reduce_lr_params)
        return keras.callbacks.ReduceLROnPlateau(**rlr_config)

    @staticmethod
    def create_best_model_memory(verbose: int = 0) -> 'keras.callbacks.Callback':
        """Create best model memory callback (saves best weights during training).

        The callback keeps the weights of the epoch with the lowest 'val_loss'
        (or 'loss' when no validation loss is logged) and puts them back on
        the model when training ends.

        Args:
            verbose: Verbosity level.

        Returns:
            Custom best model memory callback instance.
        """
        require_backend('tensorflow', feature='TensorFlow callbacks')
        from tensorflow import keras

        class BestModelMemory(keras.callbacks.Callback):
            """Remembers the best weights seen so far and restores them at the end."""

            def __init__(self, verbose=0):
                super().__init__()
                self.best_weights = None
                self.best_loss = float('inf')
                self.verbose = verbose

            def on_epoch_end(self, epoch, logs=None):
                logs = logs or {}
                current_loss = logs.get('val_loss', logs.get('loss'))
                if current_loss is None:
                    # Nothing to compare against for this epoch.
                    return
                if current_loss < self.best_loss:
                    self.best_loss = current_loss
                    self.best_weights = self.model.get_weights()
                    if self.verbose > 1:
                        logger.debug(f"Best model saved at epoch {epoch + 1} with loss {current_loss:.4f}")

            def on_train_end(self, logs=None):
                if self.best_weights is not None:
                    self.model.set_weights(self.best_weights)
                    if self.verbose > 0:
                        logger.info(f"Restored best model with loss {self.best_loss:.4f}")

        return BestModelMemory(verbose)