# Source code for nirs4all.controllers.models.torch_model

"""
PyTorch Model Controller - Controller for PyTorch models

This controller handles PyTorch models with support for:
- Training on tensor data with proper device management (CPU/GPU)
- Custom training loops with loss functions and optimizers
- Learning rate scheduling and model checkpointing
- Integration with Optuna for hyperparameter tuning
- Model persistence and prediction storage

Matches PyTorch nn.Module objects and model configurations.

Lazy loading pattern: PyTorch is only imported when actually needed
for training or prediction, not at module import time.
"""

from typing import Any, Dict, List, Tuple, Optional, TYPE_CHECKING
import numpy as np
import copy

from ..models.base_model import BaseModelController
from nirs4all.controllers.registry import register_controller
from nirs4all.core.logging import get_logger
from nirs4all.utils.backend import is_available, require_backend, is_gpu_available

# Module-level logger, named after this module.
logger = get_logger(__name__)

# Fast availability check at module level - no imports.
# Every public entry point of the controller consults this flag before
# touching anything PyTorch-specific.
PYTORCH_AVAILABLE = is_available('torch')

# These imports exist only for static type checkers; nothing below is
# executed at runtime, so importing this module never loads PyTorch.
if TYPE_CHECKING:
    from nirs4all.pipeline.runner import PipelineRunner
    from nirs4all.data.dataset import SpectroDataset
    from nirs4all.pipeline.config.context import ExecutionContext
    from nirs4all.pipeline.steps.parser import ParsedStep
    try:
        import torch
        import torch.nn as nn
        import torch.optim as optim
        from torch.utils.data import DataLoader, TensorDataset
    except ImportError:
        # PyTorch is optional: type checking must still work without it.
        pass


# Lazy-loaded module cache.
# Filled by _get_torch() with the keys 'torch', 'nn' and 'optim'.
_torch_modules: Dict[str, Any] = {}


def _get_torch():
    """Return the ``torch`` module, importing it on first use.

    The first call verifies the backend through ``require_backend`` and
    stores ``torch``, ``torch.nn`` and ``torch.optim`` in the module-level
    cache; later calls are served from that cache.
    """
    try:
        return _torch_modules['torch']
    except KeyError:
        pass

    require_backend('torch', feature='PyTorch neural networks')
    import torch

    _torch_modules['torch'] = torch
    _torch_modules['nn'] = torch.nn
    _torch_modules['optim'] = torch.optim
    return _torch_modules['torch']


def _get_nn():
    """Return ``torch.nn`` from the cache, loading PyTorch first if needed."""
    if 'nn' in _torch_modules:
        return _torch_modules['nn']

    # Cache miss: _get_torch() fills the 'nn' entry as part of its work.
    _get_torch()
    return _torch_modules['nn']


@register_controller
class PyTorchModelController(BaseModelController):
    """Controller for PyTorch models.

    Uses lazy loading pattern - PyTorch is only imported when training
    or prediction is actually performed.
    """

    priority = 4  # Higher priority than Sklearn (6)

    @classmethod
    def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
        """Match PyTorch models and model configurations.

        Args:
            step: Pipeline step; either a dict that may hold a ``'model'``
                entry, or the model object itself.
            operator: Optional operator object attached to the step.
            keyword: Step keyword (not used by this controller).

        Returns:
            True when the step or operator is recognised as a PyTorch model,
            or when ``step['model']`` is a dict whose ``'class'`` string
            contains ``'torch'``. Always False when PyTorch is unavailable.
        """
        if not PYTORCH_AVAILABLE:
            return False

        # Check if step contains a PyTorch model
        if isinstance(step, dict) and 'model' in step:
            model = step['model']
            if cls._is_pytorch_model(model):
                return True

            # Handle dictionary config for model
            if isinstance(model, dict) and 'class' in model:
                class_name = model['class']
                if isinstance(class_name, str) and 'torch' in class_name:
                    return True

        # Check direct PyTorch objects
        if cls._is_pytorch_model(step):
            return True

        # Check operator if provided
        if operator is not None and cls._is_pytorch_model(operator):
            return True

        return False

    @classmethod
    def _is_pytorch_model(cls, obj: Any) -> bool:
        """Check if object is a PyTorch model.

        Uses module introspection first to avoid importing PyTorch
        for non-PyTorch objects.

        Args:
            obj: Candidate object (model instance, config dict, or None).

        Returns:
            True if ``obj`` declares ``framework == 'pytorch'``, is a
            function-type dict tagged with the pytorch framework, or is an
            ``nn.Module`` instance; False otherwise.
        """
        if not PYTORCH_AVAILABLE:
            return False

        if obj is None:
            return False

        # Check for framework attribute first (no import needed)
        if hasattr(obj, 'framework') and obj.framework == 'pytorch':
            return True

        # Check for dict format from deserialize_component
        if isinstance(obj, dict) and obj.get('type') == 'function' and obj.get('framework') == 'pytorch':
            return True

        # Quick check via module name (no import needed)
        module = getattr(type(obj), '__module__', '')
        if 'torch' not in module:
            return False

        try:
            nn = _get_nn()
            return isinstance(obj, nn.Module)
        except Exception:
            # Best effort: any failure to load or inspect means "not a match".
            return False

    def _get_model_instance(
        self,
        dataset: 'SpectroDataset',
        model_config: Dict[str, Any],
        force_params: Optional[Dict[str, Any]] = None
    ) -> 'nn.Module':
        """Create PyTorch model instance from configuration.

        Args:
            dataset: Dataset forwarded to the model factory.
            model_config: Model configuration forwarded to the factory.
            force_params: Optional parameter overrides; an empty dict is
                passed to the factory when omitted.

        Returns:
            The model built by ``ModelFactory.build_single_model``.
        """
        require_backend('torch', feature='PyTorch models')

        # Import factory here to avoid circular imports at module level
        from .factory import ModelFactory

        return ModelFactory.build_single_model(
            model_config,
            dataset,
            force_params or {}
        )

    def _train_model(
        self,
        model: 'nn.Module',
        X_train: Any,
        y_train: Any,
        X_val: Optional[Any] = None,
        y_val: Optional[Any] = None,
        **kwargs
    ) -> 'nn.Module':
        """Train PyTorch model with custom training loop.

        Args:
            model: Model to train; it is moved to CUDA when available,
                otherwise CPU.
            X_train: Training features; must support ``.to(device)``.
            y_train: Training targets; must support ``.to(device)``.
            X_val: Optional validation features.
            y_val: Optional validation targets.
            **kwargs: Training parameters:
                ``optimizer`` - name of a ``torch.optim`` class, a dict of
                optimizer keyword arguments with an optional ``'type'`` entry
                (default ``'Adam'``), or any other object used as the
                optimizer as-is (default ``'Adam'``);
                ``lr`` / ``learning_rate`` - learning rate used when
                ``optimizer`` is a string (default 0.001);
                ``loss`` - ``'mse'``, ``'mae'``, ``'crossentropy'``
                (case-insensitive), the name of a ``torch.nn`` loss class, or
                a callable loss (default ``'MSELoss'``; unknown names fall
                back to ``nn.MSELoss``);
                ``epochs`` (default 100), ``batch_size`` (default 32),
                ``patience`` (default 10), ``verbose`` (default 0).

        Returns:
            The trained model. When validation data was supplied, the weights
            with the lowest validation loss are restored before returning.
        """
        require_backend('torch', feature='PyTorch training')

        # Import PyTorch here (lazy loading)
        torch = _get_torch()
        nn = _get_nn()
        optim = _torch_modules['optim']
        from torch.utils.data import DataLoader, TensorDataset

        train_params = kwargs
        verbose = train_params.get('verbose', 0)

        if not is_gpu_available('torch') and verbose > 0:
            logger.warning("No GPU detected. Training PyTorch model on CPU may be slow.")

        # Setup device
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)

        # Data is already prepared as tensors by _prepare_data, just move to device
        X_train = X_train.to(device)
        y_train = y_train.to(device)

        if X_val is not None:
            X_val = X_val.to(device)
        if y_val is not None:
            y_val = y_val.to(device)

        # Setup optimizer
        optimizer_config = train_params.get('optimizer', 'Adam')
        lr = train_params.get('lr', train_params.get('learning_rate', 0.001))

        if isinstance(optimizer_config, str):
            optimizer_class = getattr(optim, optimizer_config)
            optimizer = optimizer_class(model.parameters(), lr=lr)
        elif isinstance(optimizer_config, dict):
            # Work on a copy: popping 'type' from the caller's dict would make
            # a later call with the same config silently fall back to 'Adam'.
            optimizer_kwargs = dict(optimizer_config)
            opt_type = optimizer_kwargs.pop('type', 'Adam')
            optimizer_class = getattr(optim, opt_type)
            optimizer = optimizer_class(model.parameters(), **optimizer_kwargs)
        else:
            optimizer = optimizer_config

        # Setup loss function
        loss_fn_config = train_params.get('loss', 'MSELoss')
        if isinstance(loss_fn_config, str):
            # Handle common loss names
            if loss_fn_config.lower() == 'mse':
                loss_fn = nn.MSELoss()
            elif loss_fn_config.lower() == 'mae':
                loss_fn = nn.L1Loss()
            elif loss_fn_config.lower() == 'crossentropy':
                loss_fn = nn.CrossEntropyLoss()
            elif hasattr(nn, loss_fn_config):
                loss_fn = getattr(nn, loss_fn_config)()
            else:
                loss_fn = nn.MSELoss()  # Default
        else:
            loss_fn = loss_fn_config

        # Training parameters
        epochs = train_params.get('epochs', 100)
        batch_size = train_params.get('batch_size', 32)
        patience = train_params.get('patience', 10)

        # Create data loaders
        train_dataset = TensorDataset(X_train, y_train)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        val_loader = None
        if X_val is not None and y_val is not None:
            val_dataset = TensorDataset(X_val, y_val)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        # Training loop with early stopping
        best_val_loss = float('inf')
        best_model_state = None
        patience_counter = 0

        for epoch in range(epochs):
            # Training phase
            model.train()
            train_loss = 0.0
            for batch_X, batch_y in train_loader:
                optimizer.zero_grad()

                # Support for models that need targets during forward (e.g., FCK-PLS)
                if hasattr(model, 'set_targets'):
                    model.set_targets(batch_y)

                outputs = model(batch_X)
                loss = loss_fn(outputs, batch_y)

                # Support for models with custom regularization
                if hasattr(model, 'kernel_regularization'):
                    loss = loss + model.kernel_regularization()

                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            # Mean of the per-batch losses
            train_loss /= len(train_loader)

            # Validation phase
            val_loss = 0.0
            if val_loader is not None:
                model.eval()
                with torch.no_grad():
                    for batch_X, batch_y in val_loader:
                        # Support for models that need targets during forward
                        if hasattr(model, 'set_targets'):
                            model.set_targets(batch_y)

                        outputs = model(batch_X)
                        loss = loss_fn(outputs, batch_y)
                        val_loss += loss.item()
                val_loss /= len(val_loader)

                # Early stopping
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # Deep copy so later optimizer steps cannot alter the snapshot
                    best_model_state = copy.deepcopy(model.state_dict())
                    patience_counter = 0
                else:
                    patience_counter += 1

                if verbose > 1 and (epoch + 1) % 10 == 0:
                    logger.debug(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

                if patience_counter >= patience:
                    if verbose > 0:
                        logger.info(f"Early stopping at epoch {epoch+1}")
                    break
            else:
                if verbose > 1 and (epoch + 1) % 10 == 0:
                    logger.debug(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}")

        # Load best model weights if we have them
        if best_model_state is not None:
            model.load_state_dict(best_model_state)

        return model

    def _predict_model(self, model: 'nn.Module', X: Any) -> np.ndarray:
        """Generate predictions with PyTorch model.

        Args:
            model: Trained model; inference runs on the device of its first
                parameter.
            X: Input features; non-tensor input is converted with
                ``PyTorchDataPreparation.prepare_features``.

        Returns:
            A 2D array with one column. Outputs with more than one column are
            treated as class scores and reduced with ``argmax`` to float32
            labels; 1D outputs are reshaped to a column.
        """
        torch = _get_torch()

        # Import data prep here (lazy)
        from .torch.data_prep import PyTorchDataPreparation

        device = next(model.parameters()).device

        # Ensure X is a tensor
        if not isinstance(X, torch.Tensor):
            X = PyTorchDataPreparation.prepare_features(X, device)
        else:
            X = X.to(device)

        model.eval()
        with torch.no_grad():
            predictions = model(X)
            predictions = predictions.cpu().numpy()

        # Handle multiclass classification (convert logits/probs to labels)
        if predictions.ndim == 2 and predictions.shape[1] > 1:
            # Multi-output: likely multiclass classification with softmax/logits
            # Convert probabilities to class predictions (encoded labels 0-N)
            predictions = np.argmax(predictions, axis=1).reshape(-1, 1).astype(np.float32)
        elif predictions.ndim == 1:
            predictions = predictions.reshape(-1, 1)

        return predictions

    def _predict_proba_model(self, model: 'nn.Module', X: Any) -> Optional[np.ndarray]:
        """Get class probabilities from PyTorch classification model.

        The raw model output is interpreted as logits. A single output
        column is passed through a sigmoid and expanded to the two-column
        ``[1 - p, p]`` format; several columns are passed through a softmax.

        Note:
            This method does not detect regression models: a single-column
            output is always treated as a binary logit. Callers should only
            invoke it for classification models.

        Args:
            model: Trained PyTorch model.
            X: Input features.

        Returns:
            Class probabilities as (n_samples, n_classes) array.
        """
        torch = _get_torch()
        import torch.nn.functional as F

        # Import data prep here (lazy)
        from .torch.data_prep import PyTorchDataPreparation

        device = next(model.parameters()).device

        # Ensure X is a tensor
        if not isinstance(X, torch.Tensor):
            X = PyTorchDataPreparation.prepare_features(X, device)
        else:
            X = X.to(device)

        model.eval()
        with torch.no_grad():
            logits = model(X)
            logits = logits.cpu()

        # Normalise 1D output to a single column
        if logits.ndim == 1:
            logits = logits.unsqueeze(1)

        if logits.shape[1] == 1:
            # Binary classification with single output: the value is always
            # treated as a logit and squashed with a sigmoid.
            probs = torch.sigmoid(logits)
            probs = probs.numpy()
            return np.column_stack([1 - probs, probs])
        else:
            # Multiclass: apply softmax to get probabilities
            probs = F.softmax(logits, dim=1)
            return probs.numpy()

    def _prepare_data(
        self,
        X: np.ndarray,
        y: Optional[np.ndarray],
        context: 'ExecutionContext'
    ) -> Tuple[Any, Optional[Any]]:
        """Prepare data for PyTorch (convert to tensors).

        Args:
            X: Feature array.
            y: Optional target array.
            context: Execution context (not used here).

        Returns:
            The result of ``PyTorchDataPreparation.prepare_data(X, y)``.
        """
        # Import here to avoid loading PyTorch at module import time
        from .torch.data_prep import PyTorchDataPreparation
        return PyTorchDataPreparation.prepare_data(X, y)

    def _evaluate_model(self, model: 'nn.Module', X_val: Any, y_val: Any) -> float:
        """Evaluate PyTorch model.

        Args:
            model: Model to evaluate.
            X_val: Validation features; must support ``.to(device)``.
            y_val: Validation targets; must support ``.to(device)``.

        Returns:
            The MSE loss of the model on the validation data, or
            ``float('inf')`` if evaluation fails for any reason (the error is
            logged as a warning).
        """
        try:
            torch = _get_torch()
            nn = _get_nn()

            device = next(model.parameters()).device
            X_val = X_val.to(device)
            y_val = y_val.to(device)

            model.eval()
            with torch.no_grad():
                predictions = model(X_val)
                mse_loss = nn.MSELoss()
                loss = mse_loss(predictions, y_val)
                return loss.item()

        except Exception as e:
            # Deliberate best effort: a failed evaluation scores as worst case.
            logger.warning(f"Error in PyTorch model evaluation: {e}")
            return float('inf')

    def get_preferred_layout(self) -> str:
        """Return the preferred data layout for PyTorch models.

        PyTorch typically expects (samples, channels, features) for 1D convs.
        We use '3d' which gives (samples, processings, features) -> (N, C, L).
        """
        return "3d"

    def _clone_model(self, model: 'nn.Module') -> 'nn.Module':
        """Clone PyTorch model.

        Args:
            model: Model to clone.

        Returns:
            A deep copy of ``model`` in which every submodule exposing
            ``reset_parameters`` has had it called.
        """
        # For PyTorch, we can use deepcopy to get a fresh model with same architecture
        # But we need to reset parameters to ensure fresh weights
        cloned = copy.deepcopy(model)

        # Reset parameters
        def weight_reset(m):
            if hasattr(m, 'reset_parameters'):
                m.reset_parameters()

        cloned.apply(weight_reset)
        return cloned

    def process_hyperparameters(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Process hyperparameters for PyTorch model tuning.

        Keys starting with ``optimizer_`` are gathered, prefix removed, into
        a nested ``'optimizer'`` dict; every other key is passed through
        unchanged. If ``params`` also carries a plain ``'optimizer'`` entry,
        the two are merged regardless of key order: a string becomes the
        ``'type'`` of the nested dict, a dict is copied and then updated with
        the prefixed values. Any other ``'optimizer'`` value (e.g. an
        optimizer instance) is kept as-is and the prefixed values are
        dropped with a warning, since they cannot be applied to it.

        Args:
            params: Flat hyperparameter mapping. It is not modified.

        Returns:
            The processed parameter dict, or ``params`` itself when it is
            empty.
        """
        prefix = 'optimizer_'
        torch_params: Dict[str, Any] = {}
        optimizer_overrides: Dict[str, Any] = {}

        for key, value in params.items():
            if key.startswith(prefix):
                # Strip only the leading prefix, not later occurrences
                optimizer_overrides[key[len(prefix):]] = value
            else:
                # Model or training parameters
                torch_params[key] = value

        if optimizer_overrides:
            base = torch_params.get('optimizer')
            if base is None or isinstance(base, (str, dict)):
                if isinstance(base, str):
                    merged: Dict[str, Any] = {'type': base}
                elif isinstance(base, dict):
                    merged = dict(base)
                else:
                    merged = {}
                merged.update(optimizer_overrides)
                torch_params['optimizer'] = merged
            else:
                logger.warning(
                    "Ignoring 'optimizer_*' hyperparameters because 'optimizer' "
                    "is neither a name nor a configuration dict."
                )

        return torch_params if torch_params else params

    def execute(
        self,
        step_info: 'ParsedStep',
        dataset: 'SpectroDataset',
        context: 'ExecutionContext',
        runtime_context: 'RuntimeContext',
        source: int = -1,
        mode: str = "train",
        loaded_binaries: Optional[List[Tuple[str, bytes]]] = None,
        prediction_store: 'Predictions' = None
    ) -> Tuple['ExecutionContext', List[Tuple[str, bytes]]]:
        """Execute PyTorch model controller.

        Applies this controller's effective layout to the context and then
        delegates to the parent ``execute`` with all arguments forwarded.

        Raises:
            ImportError: If PyTorch is not available.
        """
        if not PYTORCH_AVAILABLE:
            raise ImportError(
                "PyTorch is not available. Please install with: "
                "pip install nirs4all[torch]"
            )

        # Set layout preference (force_layout overrides preferred)
        context = context.with_layout(self.get_effective_layout(step_info))

        # Call parent execute method
        return super().execute(step_info, dataset, context, runtime_context, source, mode, loaded_binaries, prediction_store)