nirs4all.controllers.models.torch_model module

PyTorch Model Controller - Controller for PyTorch models

This controller handles PyTorch models with support for:

- Training on tensor data with proper device management (CPU/GPU)
- Custom training loops with loss functions and optimizers
- Learning rate scheduling and model checkpointing
- Integration with Optuna for hyperparameter tuning
- Model persistence and prediction storage
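Checkpointing in a custom training loop often means keeping the weights from the epoch with the lowest validation loss. The sketch below illustrates that idea with a hypothetical `BestCheckpoint` helper (not part of nirs4all). It is framework-agnostic: it accepts any mapping of weights, such as the result of `model.state_dict()` in PyTorch.

```python
import copy
import math
from typing import Any, Mapping, Optional


class BestCheckpoint:
    """Hypothetical helper: keep a copy of the weights with the lowest validation loss."""

    def __init__(self) -> None:
        self.best_loss: float = math.inf
        self.best_epoch: Optional[int] = None
        self.best_state: Optional[Mapping[str, Any]] = None

    def update(self, epoch: int, val_loss: float, state: Mapping[str, Any]) -> bool:
        """Store `state` if `val_loss` improves on the best so far; return True if stored."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_epoch = epoch
            # Deep copy so later optimizer steps cannot mutate the saved weights.
            self.best_state = copy.deepcopy(dict(state))
            return True
        return False
```

In a PyTorch loop you would call `ckpt.update(epoch, val_loss, model.state_dict())` after each validation pass and restore the best weights at the end with `model.load_state_dict(ckpt.best_state)`.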

Matches PyTorch nn.Module objects and model configurations.

Lazy loading pattern: PyTorch is only imported when actually needed for training or prediction, not at module import time.
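A minimal sketch of the lazy loading pattern, assuming a hypothetical `lazy_import` helper and a hypothetical controller class (neither is taken from nirs4all). Defining or importing the class does not import `torch`; the import happens inside `train`, the first time training actually runs.

```python
import importlib
from functools import lru_cache
from types import ModuleType


@lru_cache(maxsize=None)
def lazy_import(module_name: str) -> ModuleType:
    """Hypothetical helper: import a module on first use and memoize the result."""
    return importlib.import_module(module_name)


class HypotheticalTorchController:
    """Importing the module that defines this class does not import torch."""

    def train(self, model):
        # torch is imported here, on the first training call, not at module import time.
        torch = lazy_import("torch")
        # Device management: prefer the GPU when one is available.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return model.to(device)
```

Keeping the import inside the method means environments without PyTorch can still import the controller module; only pipelines that actually train a PyTorch model need it installed.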

class nirs4all.controllers.models.torch_model.PyTorchModelController

Bases: BaseModelController

Controller for PyTorch models.

Uses a lazy loading pattern: PyTorch is only imported when training or prediction is actually performed.

execute(step_info: ParsedStep, dataset: SpectroDataset, context: ExecutionContext, runtime_context: RuntimeContext, source: int = -1, mode: str = 'train', loaded_binaries: List[Tuple[str, bytes]] | None = None, prediction_store: Predictions = None) → Tuple[ExecutionContext, List[Tuple[str, bytes]]]

Execute the PyTorch model controller for the given pipeline step.

get_preferred_layout() → str

Return the preferred data layout for PyTorch models.

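PyTorch 1D convolutions typically expect input shaped (samples, channels, features), i.e. (N, C, L). The controller therefore uses the ‘3d’ layout, which yields (samples, processings, features), so each preprocessing is treated as an input channel: (N, C, L).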
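The NumPy sketch below illustrates how a (samples, processings, features) array lines up with the (N, C, L) convention. The array sizes and the idea of building it by stacking one 2D matrix per preprocessing are illustrative assumptions, not a description of how nirs4all assembles its data internally.

```python
import numpy as np

n_samples, n_processings, n_features = 8, 3, 100
rng = np.random.default_rng(0)

# Assumption: one (samples, features) matrix per preprocessing, e.g. raw, SNV, derivative.
per_processing = [rng.normal(size=(n_samples, n_features)) for _ in range(n_processings)]

# '3d'-style array: stack along axis 1 -> (samples, processings, features) == (N, C, L).
x_3d = np.stack(per_processing, axis=1)

# For contrast, a flat 2D view concatenates the processings along the feature axis.
x_2d = x_3d.reshape(n_samples, n_processings * n_features)
```

After conversion with `torch.from_numpy(x_3d).float()`, such an array can be fed to `torch.nn.Conv1d(in_channels=n_processings, ...)`, because the processings axis sits where Conv1d expects channels.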

classmethod matches(step: Any, operator: Any, keyword: str) → bool

Match PyTorch models and model configurations.
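One way a matcher can recognize `nn.Module` objects without breaking the lazy loading pattern is to inspect the class hierarchy by name instead of calling `isinstance` against an imported `torch.nn.Module`. The sketch below is a hypothetical illustration of that technique (`looks_like_torch_model` is not a nirs4all function), and it does not cover the model configurations that the real method also matches.

```python
from typing import Any


def _is_torch_module_class(cls: type) -> bool:
    # Look for torch.nn.Module in the MRO by name, so torch need not be imported.
    return any(
        base.__name__ == "Module" and base.__module__.startswith("torch.nn")
        for base in getattr(cls, "__mro__", ())
    )


def looks_like_torch_model(operator: Any) -> bool:
    """Hypothetical matcher: True for an nn.Module subclass or instance."""
    cls = operator if isinstance(operator, type) else type(operator)
    return _is_torch_module_class(cls)
```

In PyTorch, `torch.nn.Module.__module__` is `"torch.nn.modules.module"`, so any user-defined network that subclasses `nn.Module` passes this check, while unrelated objects do not.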

priority: int = 4

process_hyperparameters(params: Dict[str, Any]) → Dict[str, Any]

Process hyperparameters for PyTorch model tuning.
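During tuning, a flat dictionary of sampled values often mixes model-constructor arguments with training-loop settings. The sketch below shows one hypothetical way to separate them; the `TRAINING_KEYS` set and the `split_hyperparameters` function are assumptions for illustration and do not describe what `process_hyperparameters` actually does.

```python
from typing import Any, Dict, Tuple

# Hypothetical: keys consumed by the training loop rather than the model constructor.
TRAINING_KEYS = {"epochs", "batch_size", "lr", "weight_decay", "patience"}


def split_hyperparameters(
    params: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a flat parameter dict into (model_params, train_params)."""
    model_params = {k: v for k, v in params.items() if k not in TRAINING_KEYS}
    train_params = {k: v for k, v in params.items() if k in TRAINING_KEYS}
    return model_params, train_params
```

With Optuna, the input dictionary might be built inside an objective function from calls such as `trial.suggest_int("hidden_size", 16, 128)` and `trial.suggest_float("lr", 1e-4, 1e-2, log=True)`.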