nirs4all.controllers.models.base_model module

Simplified Base Model Controller: clean, readable implementation

This module is a complete rewrite that follows the user’s pseudo-code specification. The controller is designed to be simple, clean, and readable, with its logic split across at most three files.

Key features:
  • Simple execute() method with clear train/prediction mode logic

  • Externalized prediction storage, model utils, and naming logic

  • Clean separation between training, finetuning, and prediction

  • Framework-specific models (sklearn, tensorflow) handle their own details

class nirs4all.controllers.models.base_model.BaseModelController

Bases: OperatorController, ABC

Abstract base controller for machine learning model training and prediction.

This controller provides a unified interface for training, finetuning, and predicting with machine learning models across different frameworks (scikit-learn, TensorFlow, PyTorch). It implements cross-validation, fold averaging, hyperparameter optimization, and comprehensive prediction tracking.

The controller delegates framework-specific operations to subclasses while handling:
  • Cross-validation fold management

  • Model identification and naming

  • Prediction storage and tracking

  • Score calculation and aggregation

  • Fold-averaged predictions (simple and weighted)
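The minimal sketch that follows shows how shared orchestration can be separated from framework-specific work. It is not the nirs4all implementation. SketchModelController, MeanController, run_folds(), _train_model() and _predict() are hypothetical names. RMSE is an assumed metric, and plain numpy arrays stand in for SpectroDataset.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Sequence, Tuple

import numpy as np


class SketchModelController(ABC):
    """Hypothetical stand-in for BaseModelController's division of labor."""

    def run_folds(self, X: np.ndarray, y: np.ndarray,
                  folds: Sequence[Tuple[np.ndarray, np.ndarray]]) -> List[float]:
        # Framework-agnostic part: iterate over folds and score each one (RMSE).
        scores = []
        for train_idx, val_idx in folds:
            model = self._train_model(X[train_idx], y[train_idx])
            y_pred = self._predict(model, X[val_idx])
            scores.append(float(np.sqrt(np.mean((y[val_idx] - y_pred) ** 2))))
        return scores

    @abstractmethod
    def _train_model(self, X: np.ndarray, y: np.ndarray) -> Any: ...

    @abstractmethod
    def _predict(self, model: Any, X: np.ndarray) -> np.ndarray: ...


class MeanController(SketchModelController):
    # Framework-specific part: a trivial "model" that predicts the training mean.
    def _train_model(self, X: np.ndarray, y: np.ndarray) -> float:
        return float(np.mean(y))

    def _predict(self, model: float, X: np.ndarray) -> np.ndarray:
        return np.full(len(X), model)
```

Only the two abstract hooks change from one framework to the next. The fold loop and scoring stay in the base class.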

Attributes:
  • optuna_manager (OptunaManager) – Manager for hyperparameter optimization.

  • identifier_generator (ModelIdentifierGenerator) – Component for model naming.

  • prediction_transformer (PredictionTransformer) – Component for prediction scaling.

  • prediction_assembler (PredictionDataAssembler) – Component for assembling prediction records.

  • score_calculator (ScoreCalculator) – Component for calculating evaluation scores.

  • index_normalizer (IndexNormalizer) – Component for normalizing sample indices.

  • prediction_store (Predictions) – External storage for predictions.

  • verbose (int) – Verbosity level for logging.

execute(step_info: ParsedStep, dataset: SpectroDataset, context: ExecutionContext, runtime_context: RuntimeContext, source: int = -1, mode: str = 'train', loaded_binaries: List[Tuple[str, bytes]] | None = None, prediction_store: Predictions = None) → Tuple[ExecutionContext, List[ArtifactMeta]]

Execute model training, finetuning, or prediction.

This is the main entry point for model execution. It handles:
  • Extracting model configuration

  • Restoring task type in predict/explain modes

  • Delegating to finetune() or train() based on configuration

  • Managing prediction storage

Parameters:
  • step_info – Parsed step containing model configuration and operator.

  • dataset – SpectroDataset with features and targets.

  • context – Execution context with step_id, partition info, etc.

  • runtime_context – Runtime context managing execution state.

  • source – Data source index (default: -1).

  • mode – Execution mode (‘train’, ‘finetune’, ‘predict’, ‘explain’).

  • loaded_binaries – Optional list of (name, bytes) tuples for prediction mode.

  • prediction_store – External Predictions storage instance.

Returns:

Tuple of (updated_context, list_of_artifact_metadata).
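The sketch below gives a rough picture of the dispatch described above. It rests on several assumptions made only for illustration. The config key finetune_params is assumed. The phase labels are invented. The idea that predict/explain mode reuses train() with loaded binaries is also assumed. None of these is confirmed behavior of execute(), and dispatch_execute() is a hypothetical helper.

```python
from typing import Any, Dict, List, Optional


def dispatch_execute(model_config: Dict[str, Any], mode: str,
                     stored_task_type: Optional[str] = None) -> List[str]:
    """Return the hypothetical sequence of phases execute() would run."""
    if mode not in {"train", "finetune", "predict", "explain"}:
        raise ValueError(f"unknown mode: {mode!r}")
    phases: List[str] = []
    if mode in ("predict", "explain"):
        # Assumed: the task type comes from the training run, not the new data.
        phases.append(f"restore_task_type:{stored_task_type}")
        phases.append("train(loaded_binaries)")
        return phases
    if model_config.get("finetune_params"):  # hypothetical config key
        # Optimize first, then train with the best parameters found.
        phases.extend(["finetune", "train(best_params)"])
    else:
        phases.append("train")
    return phases
```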

finetune(dataset: SpectroDataset, model_config: Dict[str, Any], X_train: Any, y_train: Any, X_test: Any, y_test: Any, folds: List | None, finetune_params: Dict[str, Any], predictions: Dict, context: ExecutionContext, runtime_context: RuntimeContext) → Dict[str, Any] | List[Dict[str, Any]]

Optimize hyperparameters using Optuna.

Delegates to OptunaManager for Bayesian hyperparameter optimization. Returns optimized parameters that will be used in subsequent training.

Parameters:
  • dataset – SpectroDataset for optimization.

  • model_config – Base model configuration.

  • X_train – Training features.

  • y_train – Training targets.

  • X_test – Test features.

  • y_test – Test targets.

  • folds – List of (train_idx, val_idx) tuples for cross-validation.

  • finetune_params – Optuna configuration with search space and trials.

  • predictions – Prediction storage dictionary.

  • context – Execution context.

  • runtime_context – Runtime context.

Returns:

Dictionary of optimized parameters (single model) or list of dicts (per-fold).
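The sketch below does not reproduce Optuna. It swaps in a plain exhaustive grid search as a stand-in, only to show the two possible return shapes: a single parameter dict, or one dict per fold. grid_search_stand_in(), finetune_sketch(), make_objective and the per_fold flag are all hypothetical.

```python
import itertools
from typing import Any, Callable, Dict, List, Sequence, Union


def grid_search_stand_in(search_space: Dict[str, Sequence[Any]],
                         objective: Callable[[Dict[str, Any]], float]) -> Dict[str, Any]:
    """Exhaustive grid search used in place of Optuna; lower objective wins."""
    names = list(search_space)
    best_params: Dict[str, Any] = {}
    best_value = float("inf")
    for values in itertools.product(*(search_space[n] for n in names)):
        params = dict(zip(names, values))
        value = objective(params)
        if value < best_value:
            best_params, best_value = params, value
    return best_params


def finetune_sketch(
    search_space: Dict[str, Sequence[Any]],
    make_objective: Callable[[Any], Callable[[Dict[str, Any]], float]],
    folds: Sequence[Any],
    per_fold: bool,
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    if per_fold:
        # One optimized parameter set per fold.
        return [grid_search_stand_in(search_space, make_objective(fold)) for fold in folds]
    # One parameter set shared by every fold.
    return grid_search_stand_in(search_space, make_objective(None))
```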

get_effective_layout(step_info: ParsedStep | None = None) → str

Get effective data layout, respecting force_layout if specified.

This method checks if the step configuration has a force_layout override. If not, it falls back to the controller’s preferred layout.

Parameters:

step_info – ParsedStep containing potential force_layout override.

Returns:

Data layout string to use for this step.
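The fallback rule can be sketched as follows. The sketch assumes the override lives in a config dict on the step, under the key force_layout. FakeStep, LayoutMixin and KerasLikeController are hypothetical stand-ins for ParsedStep and real controller subclasses.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class FakeStep:
    """Hypothetical stand-in for ParsedStep; only its config dict matters here."""
    config: Dict[str, Any] = field(default_factory=dict)


class LayoutMixin:
    def get_preferred_layout(self) -> str:
        return "2d"  # default: flat (n_samples, n_features) arrays

    def get_effective_layout(self, step_info: Optional[FakeStep] = None) -> str:
        # An explicit force_layout on the step wins over the framework default.
        if step_info is not None:
            forced = step_info.config.get("force_layout")
            if forced:
                return forced
        return self.get_preferred_layout()


class KerasLikeController(LayoutMixin):
    def get_preferred_layout(self) -> str:
        return "3d"  # subclass override, as the Note on get_preferred_layout suggests
```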

get_preferred_layout() → str

Get preferred data layout for the framework.

Returns:

Data layout string (‘2d’ for NumPy arrays, ‘3d’ for TensorFlow, etc.).

Note

Override in subclasses for framework-specific layouts.

get_xy(dataset: SpectroDataset, context: ExecutionContext) → Tuple[Any, Any, Any, Any, Any, Any]

Extract train/test splits with scaled and unscaled targets.

For classification tasks, both scaled and unscaled targets are transformed. For regression tasks, scaled targets are used for training while unscaled (numeric) targets are used for evaluation.

In prediction mode, all available data is used (partition=None) instead of a train/test split.

The method also handles sample_partitioner branches, which restrict the data to a subset of samples.

Parameters:
  • dataset – SpectroDataset with partitioned data.

  • context – Execution context with partition and preprocessing info.

Returns:

Tuple of (X_train, y_train, X_test, y_test, y_train_unscaled, y_test_unscaled).
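For a regression task, the six-tuple can be illustrated with the sketch below. It assumes z-score target scaling and a per-sample array of partition labels. get_xy_sketch() is hypothetical. Prediction mode and sample_partitioner branches are left out.

```python
from typing import Tuple

import numpy as np


def get_xy_sketch(X: np.ndarray, y: np.ndarray, partition: np.ndarray,
                  y_mean: float, y_std: float) -> Tuple[np.ndarray, ...]:
    """Hypothetical regression-only version of the six-tuple returned by get_xy()."""
    train_mask = partition == "train"
    test_mask = partition == "test"

    def scale(values: np.ndarray) -> np.ndarray:
        # Standardized targets are what the model is trained on.
        return (values - y_mean) / y_std

    return (
        X[train_mask], scale(y[train_mask]),  # X_train, y_train (scaled)
        X[test_mask], scale(y[test_mask]),    # X_test, y_test (scaled)
        y[train_mask], y[test_mask],          # unscaled targets, kept for evaluation
    )
```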

launch_training(dataset, model_config, context, runtime_context, prediction_store, X_train, y_train, X_val, y_val, X_test, y_train_unscaled, y_val_unscaled, y_test_unscaled, train_indices=None, val_indices=None, fold_idx=None, best_params=None, loaded_binaries=None, mode='train', test_sample_ids=None)

Execute single model training or prediction.

This method delegates to modular components for:
  • Model identification and naming

  • Model loading in predict/explain modes

  • Training execution

  • Prediction transformation

  • Score calculation

  • Prediction data assembly

Parameters:
  • dataset – SpectroDataset instance.

  • model_config – Model configuration dictionary.

  • context – Execution context with step_id, target (y) processing info, etc.

  • runtime_context – Runtime context.

  • prediction_store – Predictions storage instance.

  • X_train – Training features.

  • y_train – Training targets (scaled).

  • X_val – Validation features.

  • y_val – Validation targets (scaled).

  • X_test – Test features.

  • y_train_unscaled – Training targets (unscaled, for evaluation).

  • y_val_unscaled – Validation targets (unscaled, for evaluation).

  • y_test_unscaled – Test targets (unscaled, for evaluation).

  • train_indices – Sample indices of the training partition.

  • val_indices – Sample indices of the validation partition.

  • fold_idx – Optional fold index for cross-validation.

  • best_params – Optional hyperparameters from optimization.

  • loaded_binaries – Optional binaries for predict/explain mode.

  • mode – Execution mode (‘train’, ‘finetune’, ‘predict’, ‘explain’).

  • test_sample_ids – List of actual sample IDs for the test partition (for stacking).

Returns:

Tuple of (trained_model, model_id, val_score, model_name, prediction_data).
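Keeping the unscaled targets lets scores be reported in the original units. The minimal sketch below shows that step. It assumes z-score target scaling and RMSE as the metric, so the real PredictionTransformer and ScoreCalculator may differ. score_on_original_scale() is a hypothetical helper.

```python
from typing import Sequence, Tuple

import numpy as np


def score_on_original_scale(y_pred_scaled: Sequence[float], y_true_unscaled: Sequence[float],
                            y_mean: float, y_std: float) -> Tuple[np.ndarray, float]:
    """Hypothetical: undo target scaling, then compute RMSE in original units."""
    y_pred = np.asarray(y_pred_scaled, dtype=float) * y_std + y_mean  # inverse of (y - mean) / std
    y_true = np.asarray(y_true_unscaled, dtype=float)
    rmse = float(np.sqrt(np.mean((y_true - y_pred) ** 2)))
    return y_pred, rmse
```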

load_model(filepath: str) → Any

Optional: Load model from framework-specific format.

The default implementation delegates to artifact_serialization.load(). Subclasses can override it to use framework-specific loading.

Parameters:

filepath – Path to load from.

Returns:

Loaded model instance.

priority: int = 15

process_hyperparameters(params: Dict[str, Any]) → Dict[str, Any]

Process hyperparameters before use.

Can be overridden by subclasses to structure parameters (e.g. nesting for TensorFlow).

Parameters:

params – Flat dictionary of sampled parameters.

Returns:

Processed dictionary of parameters.
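The sketch below shows one way a subclass might nest flat parameters. It assumes sampled keys use a double-underscore separator, such as fit__epochs. nest_params() and the key names are hypothetical, and the structure a TensorFlow controller actually expects may differ.

```python
from typing import Any, Dict


def nest_params(params: Dict[str, Any], sep: str = "__") -> Dict[str, Any]:
    """Turn {'fit__epochs': 10} into {'fit': {'epochs': 10}}; plain keys stay flat."""
    nested: Dict[str, Any] = {}
    for key, value in params.items():
        parts = key.split(sep)
        node = nested
        for part in parts[:-1]:
            node = node.setdefault(part, {})  # create intermediate dicts on demand
        node[parts[-1]] = value
    return nested
```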

save_model(model: Any, filepath: str) → None

Optional: Save model in framework-specific format.

The default implementation delegates to artifact_serialization.persist(). Subclasses can override it to use framework-specific formats:
  • TensorFlow: .h5 or .keras format

  • PyTorch: .ckpt or .pt format

  • sklearn: .joblib format

Parameters:
  • model – Trained model to save.

  • filepath – Path to save to, without extension (the implementation adds it).
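A subclass override of save_model() and load_model() might look like the sketch below. It uses the standard-library pickle module only to stay self-contained, unlike the .joblib format listed for sklearn. PickleModelIO and the .pkl extension are hypothetical.

```python
import os
import pickle
import tempfile
from typing import Any


class PickleModelIO:
    """Hypothetical subclass-style override of save_model()/load_model()."""

    extension = ".pkl"

    def save_model(self, model: Any, filepath: str) -> None:
        # filepath arrives without an extension; this implementation appends its own.
        with open(filepath + self.extension, "wb") as fh:
            pickle.dump(model, fh)

    def load_model(self, filepath: str) -> Any:
        # Accept the path with or without the extension.
        # Only unpickle files from trusted sources.
        if not filepath.endswith(self.extension):
            filepath += self.extension
        with open(filepath, "rb") as fh:
            return pickle.load(fh)


if __name__ == "__main__":
    # Round trip in a throwaway directory.
    with tempfile.TemporaryDirectory() as tmp:
        store = PickleModelIO()
        base = os.path.join(tmp, "model")
        store.save_model({"coef": [1.0, 2.0]}, base)
        print(os.path.exists(base + ".pkl"), store.load_model(base))
```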

classmethod supports_prediction_mode() → bool

Check if the controller should execute during prediction mode.

Returns:

True if the controller should execute in prediction mode, False if it should be skipped (e.g., chart controllers).
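The sketch below shows how a runner could use this classmethod to skip controllers when predicting. Controller, ModelController, ChartController and controllers_to_run() are hypothetical. The default of False in the base class is an assumption.

```python
from typing import List, Type


class Controller:
    @classmethod
    def supports_prediction_mode(cls) -> bool:
        return False  # assumed default: skip unless a subclass opts in


class ModelController(Controller):
    @classmethod
    def supports_prediction_mode(cls) -> bool:
        return True  # models must run to produce predictions


class ChartController(Controller):
    pass  # inherits False, so it is skipped in prediction mode


def controllers_to_run(controllers: List[Type[Controller]], mode: str) -> List[Type[Controller]]:
    # Filter only in prediction mode; every controller runs otherwise.
    if mode == "predict":
        return [c for c in controllers if c.supports_prediction_mode()]
    return list(controllers)
```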

train(dataset, model_config, context, runtime_context, prediction_store, X_train, y_train, X_test, y_test, y_train_unscaled, y_test_unscaled, folds, best_params=None, loaded_binaries=None, mode='train', train_sample_ids=None, test_sample_ids=None) → List[ArtifactMeta]

Orchestrate model training across folds with prediction tracking.

Manages the complete training workflow:
  • Iterates through cross-validation folds

  • Delegates to launch_training() for each fold

  • Creates fold-averaged predictions for regression tasks

  • Persists trained models as artifacts

  • Stores all predictions with weights

Parameters:
  • dataset – SpectroDataset with features and targets.

  • model_config – Model configuration dictionary.

  • context – Execution context with step_id and preprocessing info.

  • runtime_context – Runtime context.

  • prediction_store – External Predictions storage.

  • X_train – Training features (all folds).

  • y_train – Training targets (scaled).

  • X_test – Test features.

  • y_test – Test targets (scaled).

  • y_train_unscaled – Training targets (unscaled for evaluation).

  • y_test_unscaled – Test targets (unscaled for evaluation).

  • folds – List of (train_idx, val_idx) tuples or empty list.

  • best_params – Optional hyperparameters from finetuning.

  • loaded_binaries – Optional model binaries for prediction mode.

  • mode – Execution mode (‘train’, ‘finetune’, ‘predict’, ‘explain’).

  • train_sample_ids – List of actual sample IDs for train partition.

  • test_sample_ids – List of actual sample IDs for test partition.

Returns:

List of ArtifactMeta objects for persisted models.
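The simple and weighted fold averages can be sketched as follows. The weighting scheme is an assumption: each fold is weighted by its inverse validation error (lower is better), which is not necessarily what train() uses. fold_average() and its "avg"/"w_avg" keys are hypothetical.

```python
from typing import Dict, Sequence

import numpy as np


def fold_average(test_preds: Sequence[np.ndarray],
                 val_scores: Sequence[float]) -> Dict[str, np.ndarray]:
    """Average per-fold test predictions, plain and score-weighted."""
    preds = np.vstack(test_preds)  # shape: (n_folds, n_test_samples)
    simple = preds.mean(axis=0)
    # Assumed weighting: inverse validation error (scores must be > 0), normalized to sum to 1.
    inv = 1.0 / np.asarray(val_scores, dtype=float)
    weights = inv / inv.sum()
    weighted = weights @ preds
    return {"avg": simple, "w_avg": weighted, "weights": weights}
```

With two folds scoring 1.0 and 3.0, the weights come out as 0.75 and 0.25. The better fold therefore dominates the weighted average.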

classmethod use_multi_source() → bool

Check if the operator supports multi-source datasets.