nirs4all.controllers.models package
Subpackages
- nirs4all.controllers.models.components package
- Submodules
- nirs4all.controllers.models.components.identifier_generator module
- nirs4all.controllers.models.components.index_normalizer module
- nirs4all.controllers.models.components.prediction_assembler module
- nirs4all.controllers.models.components.prediction_transformer module
- nirs4all.controllers.models.components.score_calculator module
- Module contents
- Submodules
- nirs4all.controllers.models.stacking package
- Submodules
- nirs4all.controllers.models.stacking.branch_validator module
- nirs4all.controllers.models.stacking.classification module
- nirs4all.controllers.models.stacking.config module
- nirs4all.controllers.models.stacking.crossbranch module
- nirs4all.controllers.models.stacking.exceptions module
BranchFeatureAlignmentError, BranchMismatchError, BranchingError, CircularDependencyError, CrossBranchStackingError, CrossPartitionStackingError, DisjointSampleSetsError, FeatureOrderMismatchError, FoldMismatchAcrossBranchesError, GeneratorSyntaxStackingWarning, IncompatibleBranchSamplesError, IncompatibleBranchTypeError, InconsistentLevelError, InvalidMetaModelArtifactError, MaxStackingLevelExceededError, MetaModelError, MetaModelPredictionError, MetaModelSerializationError, MissingDependencyError, MissingSourceModelError, MultiLevelStackingError, NestedBranchStackingError, NoSourcePredictionsError, SourcePredictionError
- nirs4all.controllers.models.stacking.multilevel module
- nirs4all.controllers.models.stacking.reconstructor module
- nirs4all.controllers.models.stacking.serialization module
- Module contents
BranchFeatureAlignmentError
BranchInfo
BranchMismatchError
BranchPredictionInfo: branch_id, branch_name, model_names, sample_indices, n_samples, n_folds, branch_type
BranchType
BranchValidationResult
BranchValidator
BranchingError
CircularDependencyError
ClassificationFeatureExtractor
ClassificationInfo: task_type, n_classes, class_labels, has_probabilities, proba_shape, get_n_features_per_model(), is_binary, is_classification, is_multiclass
CrossBranchCompatibility
CrossBranchStackingError
CrossBranchValidationResult: is_compatible, compatibility, branches, common_samples, alignment_issues, warnings, errors, add_error(), add_warning(), total_models
CrossBranchValidator
CrossPartitionStackingError
DisjointSampleSetsError
FeatureNameGenerator
FeatureOrderMismatchError
FoldAlignmentValidator
FoldMismatchAcrossBranchesError
GeneratorSyntaxStackingWarning
IncompatibleBranchSamplesError
IncompatibleBranchTypeError
InconsistentLevelError
InvalidMetaModelArtifactError
LevelValidationResult: is_valid, detected_level, source_levels, circular_dependencies, warnings, errors, add_error(), add_warning()
MaxStackingLevelExceededError
MetaFeatureInfo: feature_names, source_models, feature_to_model, classification_info, n_features_per_model, aggregate_importance_by_model(), get_model_for_feature()
MetaModelArtifact: meta_model_type, meta_model_name, meta_learner_class, source_models, feature_columns, stacking_config, selector_config, branch_context, use_proba, n_folds, coverage_ratio, artifact_id, training_timestamp, feature_to_model_mapping, n_classes, task_type, from_dict(), from_json(), to_dict(), to_json(), get_source_artifact_ids(), get_source_by_index(), validate_feature_alignment()
MetaModelError
MetaModelPredictionError
MetaModelSerializationError
MetaModelSerializer
MissingDependencyError
MissingSourceModelError
ModelLevelInfo
MultiLevelStackingError
MultiLevelValidator: prediction_store, max_level, log_warnings, META_MODEL_PATTERNS, clear_cache(), detect_level(), filter_by_level(), get_all_levels(), validate_sources()
NestedBranchStackingError
NoSourcePredictionsError
ReconstructionResult: X_train_meta, X_test_meta, y_train, y_test, feature_names, source_models, valid_train_mask, valid_test_mask, validation_result, n_folds, coverage_ratio, meta_feature_info, classification_info
ReconstructorConfig: validate_fold_alignment, validate_sample_coverage, log_warnings, max_missing_fold_ratio, allow_partial_sources, feature_name_pattern, excluded_fold_ids, __post_init__()
SourceModelReference: model_name, model_classname, step_idx, artifact_id, feature_index, fold_id, branch_id, branch_name, branch_path, val_score, metric, from_dict(), to_dict()
SourcePredictionError
StackingCompatibility
StackingTaskType: REGRESSION, BINARY_CLASSIFICATION, MULTICLASS_CLASSIFICATION, UNKNOWN, is_classification, n_classes
TaskTypeDetector
TrainingSetReconstructor: prediction_store, source_model_names, stacking_config, reconstructor_config, fold_validator, reconstruct(), validate_branch_compatibility()
ValidationResult: errors, warnings, is_valid, add_error(), add_warning(), format_errors(), format_warnings(), merge()
Functions: build_meta_feature_info(), detect_branch_type(), detect_stacking_level(), get_disjoint_branch_info(), is_disjoint_branch(), is_stacking_compatible(), stacking_config_from_dict(), stacking_config_to_dict(), validate_all_branches_scope(), validate_multi_level_stacking()
- Submodules
- nirs4all.controllers.models.tensorflow package
Submodules
- nirs4all.controllers.models.autogluon_model module
- nirs4all.controllers.models.base_model module
BaseModelController: optuna_manager, identifier_generator, prediction_transformer, prediction_assembler, score_calculator, index_normalizer, prediction_store, verbose, priority, execute(), finetune(), get_effective_layout(), get_preferred_layout(), get_xy(), launch_training(), load_model(), process_hyperparameters(), save_model(), supports_prediction_mode(), train(), use_multi_source()
- nirs4all.controllers.models.factory module
ModelFactory: build_single_model(), compute_input_shape(), detect_framework(), filter_params(), get_num_classes(), import_class(), import_object(), is_meta_estimator(), prepare_and_call(), reconstruct_object()
- nirs4all.controllers.models.jax_model module
- nirs4all.controllers.models.jax_wrapper module
- nirs4all.controllers.models.meta_model module
- nirs4all.controllers.models.sklearn_model module
- nirs4all.controllers.models.tensorflow_model module
- nirs4all.controllers.models.torch_model module
- nirs4all.controllers.models.utilities module
ModelControllerUtils: DEFAULT_LOSSES, DEFAULT_METRICS, SKLEARN_SCORING, detect_task_type(), format_scores(), get_best_score_metric(), get_default_loss(), get_default_metrics(), get_scoring_metric(), validate_loss_compatibility()
ModelUtils
Module contents
Model controllers module for nirs4all.
This module contains model controllers for different machine learning frameworks. All model controllers support training, fine-tuning with Optuna, and prediction modes.
Controllers follow the operator-controller pattern, where:
- Operators (in nirs4all.operators.models) define WHAT models to use
- Controllers (here) define HOW to execute them
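As a rough illustration of this split, the sketch below pairs plain operator objects with controllers chosen through a `matches()` classmethod and a numeric `priority`. It assumes that a lower priority number is consulted first, which is what the priority notes on this page suggest. Every class and function name in it (`Controller`, `SklearnLikeController`, `MetaLikeController`, `select_controller`, `PlainEstimator`, `StackedEstimator`) is hypothetical and is not the nirs4all registry.

```python
from typing import Any, List, Type


class Controller:
    """Hypothetical base class: decides HOW an operator is executed."""
    priority = 100  # assumption: lower value = consulted first

    @classmethod
    def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
        return False


class SklearnLikeController(Controller):
    priority = 6

    @classmethod
    def matches(cls, step, operator, keyword):
        # Anything exposing fit/predict is treated as a generic estimator.
        return hasattr(operator, "fit") and hasattr(operator, "predict")


class MetaLikeController(Controller):
    priority = 5

    @classmethod
    def matches(cls, step, operator, keyword):
        # Hypothetical marker attribute for a stacking operator.
        return getattr(operator, "is_meta", False)


def select_controller(controllers: List[Type[Controller]], step, operator, keyword=""):
    """Return the matching controller with the smallest priority value."""
    for ctrl in sorted(controllers, key=lambda c: c.priority):
        if ctrl.matches(step, operator, keyword):
            return ctrl
    raise LookupError("no controller matches this step")


class PlainEstimator:
    """Operator: only says WHAT to run."""
    def fit(self, X, y):
        return self

    def predict(self, X):
        return [0.0 for _ in X]


class StackedEstimator(PlainEstimator):
    is_meta = True


REGISTRY = [SklearnLikeController, MetaLikeController]
chosen_plain = select_controller(REGISTRY, step=None, operator=PlainEstimator())
chosen_meta = select_controller(REGISTRY, step=None, operator=StackedEstimator())
```

Both operators expose `fit`/`predict`, so the more specific controller only wins because it is checked first.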
- class nirs4all.controllers.models.AutoGluonModelController[source]
Bases: BaseModelController
Controller for AutoGluon TabularPredictor.
This controller handles AutoGluon models with automatic model selection, ensembling, and integration with the nirs4all pipeline.
AutoGluon automatically:
- Trains multiple models (LightGBM, CatBoost, XGBoost, Neural Networks, etc.)
- Performs cross-validation
- Creates weighted ensembles
- Handles hyperparameter tuning internally
Uses lazy loading - AutoGluon is only imported when training starts.
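The lazy-loading idea can be sketched with a small wrapper that imports a module only on first use. This is an illustration of the general technique, not the controller's actual code: `LazyBackend` is a hypothetical helper, and the standard-library `json` module stands in for a heavy dependency so that the example runs without AutoGluon installed.

```python
import importlib


class LazyBackend:
    """Hypothetical helper: defers importing a heavy module until first use."""

    def __init__(self, module_name: str):
        self.module_name = module_name
        self._module = None

    @property
    def loaded(self) -> bool:
        return self._module is not None

    def get(self):
        # The import cost is paid here, on first access, not at construction.
        if self._module is None:
            self._module = importlib.import_module(self.module_name)
        return self._module


# "json" is a stand-in for a heavy optional dependency.
backend = LazyBackend("json")
before = backend.loaded
encoded = backend.get().dumps({"mode": "train"})
after = backend.loaded
```

Constructing the wrapper is cheap, so a controller written this way can be registered and matched even when the optional dependency is slow to import.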
- priority
Controller priority (5). Lower values take precedence, so this ranks ahead of sklearn (6) and AutoGluon is selected when explicitly requested.
- Type:
- execute(step_info: ParsedStep, dataset: SpectroDataset, context: ExecutionContext, runtime_context: RuntimeContext, source: int = -1, mode: str = 'train', loaded_binaries: List[Tuple[str, bytes]] | None = None, prediction_store: Any | None = None) Tuple[ExecutionContext, List[ArtifactMeta]][source]
Execute AutoGluon model controller.
Main entry point for AutoGluon model execution in the pipeline.
- Parameters:
step_info – Parsed step containing model configuration.
dataset (SpectroDataset) – Dataset containing features and targets.
context (ExecutionContext) – Pipeline execution context.
runtime_context (RuntimeContext) – Runtime context.
source (int) – Source index. Defaults to -1.
mode (str) – Execution mode. Defaults to ‘train’.
loaded_binaries – Pre-loaded model binaries for prediction.
prediction_store – Store for managing predictions.
- Returns:
Updated context and list of model binaries.
- Return type:
Tuple[ExecutionContext, List[ArtifactMeta]]
- get_preferred_layout() str[source]
Return the preferred data layout for AutoGluon.
- Returns:
Data layout preference, ‘2d’ for AutoGluon.
- Return type:
str
- load_model(filepath: str) Any[source]
Load AutoGluon model from disk.
- Parameters:
filepath (str) – Path to the saved model directory.
- Returns:
Loaded AutoGluon predictor.
- Return type:
TabularPredictor
- classmethod matches(step: Any, operator: Any, keyword: str) bool[source]
Match AutoGluon TabularPredictor configurations.
- save_model(model: Any, filepath: str) None[source]
Save AutoGluon model to disk.
AutoGluon models are saved as directories. This method moves the model’s directory to the specified filepath.
- Parameters:
model (TabularPredictor) – Trained AutoGluon predictor.
filepath (str) – Target path for saving.
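A minimal sketch of the "move the model directory" behaviour described for save_model, using only the standard library. The helper name `move_model_dir`, the choice to overwrite an existing target, and the placeholder file `predictor.pkl` are assumptions made for the example; they do not describe AutoGluon's on-disk layout or the controller's real implementation.

```python
import shutil
import tempfile
from pathlib import Path


def move_model_dir(current_dir: str, target_dir: str) -> Path:
    """Move a saved-model directory to target_dir, replacing any stale copy."""
    src, dst = Path(current_dir), Path(target_dir)
    if dst.exists():
        shutil.rmtree(dst)  # assumption: an existing target is overwritten
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.move(str(src), str(dst))
    return dst


with tempfile.TemporaryDirectory() as tmp:
    # Simulate a directory that a training run wrote to a temporary location.
    model_dir = Path(tmp) / "autogluon_tmp"
    model_dir.mkdir()
    (model_dir / "predictor.pkl").write_bytes(b"stub")  # placeholder content
    saved = move_model_dir(str(model_dir), str(Path(tmp) / "artifacts" / "model_0"))
    moved_ok = (saved / "predictor.pkl").exists() and not model_dir.exists()
```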
- class nirs4all.controllers.models.BaseModelController[source]
Bases: OperatorController, ABC
Abstract base controller for machine learning model training and prediction.
This controller provides a unified interface for training, finetuning, and predicting with machine learning models across different frameworks (scikit-learn, TensorFlow, PyTorch). It implements cross-validation, fold averaging, hyperparameter optimization, and comprehensive prediction tracking.
- The controller delegates framework-specific operations to subclasses while handling:
Cross-validation fold management
Model identification and naming
Prediction storage and tracking
Score calculation and aggregation
Fold-averaged predictions (simple and weighted)
- optuna_manager
Manager for hyperparameter optimization.
- Type:
OptunaManager
- identifier_generator
Component for model naming.
- Type:
- prediction_transformer
Component for prediction scaling.
- Type:
- prediction_assembler
Component for assembling prediction records.
- Type:
- score_calculator
Component for calculating evaluation scores.
- Type:
- index_normalizer
Component for normalizing sample indices.
- Type:
- prediction_store
External storage for predictions.
- Type:
- execute(step_info: ParsedStep, dataset: SpectroDataset, context: ExecutionContext, runtime_context: RuntimeContext, source: int = -1, mode: str = 'train', loaded_binaries: List[Tuple[str, bytes]] | None = None, prediction_store: Predictions = None) Tuple[ExecutionContext, List[ArtifactMeta]][source]
Execute model training, finetuning, or prediction.
- This is the main entry point for model execution. It handles:
Extracting model configuration
Restoring task type in predict/explain modes
Delegating to finetune() or train() based on configuration
Managing prediction storage
- Parameters:
step_info – Parsed step containing model configuration and operator.
dataset – SpectroDataset with features and targets.
context – Execution context with step_id, partition info, etc.
runtime_context – Runtime context managing execution state.
source – Data source index (default: -1).
mode – Execution mode (‘train’, ‘finetune’, ‘predict’, ‘explain’).
loaded_binaries – Optional list of (name, bytes) tuples for prediction mode.
prediction_store – External Predictions storage instance.
- Returns:
Tuple of (updated_context, list_of_artifact_metadata).
- finetune(dataset: SpectroDataset, model_config: Dict[str, Any], X_train: Any, y_train: Any, X_test: Any, y_test: Any, folds: List | None, finetune_params: Dict[str, Any], predictions: Dict, context: ExecutionContext, runtime_context: RuntimeContext) Dict[str, Any] | List[Dict[str, Any]][source]
Optimize hyperparameters using Optuna.
Delegates to OptunaManager for Bayesian hyperparameter optimization. Returns optimized parameters that will be used in subsequent training.
- Parameters:
dataset – SpectroDataset for optimization.
model_config – Base model configuration.
X_train – Training features.
y_train – Training targets.
X_test – Test features.
y_test – Test targets.
folds – List of (train_idx, val_idx) tuples for cross-validation.
finetune_params – Optuna configuration with search space and trials.
predictions – Prediction storage dictionary.
context – Execution context.
runtime_context – Runtime context.
- Returns:
Dictionary of optimized parameters (single model) or list of dicts (per-fold).
- get_effective_layout(step_info: ParsedStep | None = None) str[source]
Get effective data layout, respecting force_layout if specified.
This method checks if the step configuration has a force_layout override. If not, it falls back to the controller’s preferred layout.
- Parameters:
step_info – ParsedStep containing potential force_layout override.
- Returns:
Data layout string to use for this step.
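The fallback described above can be sketched as follows. `FakeStep` is a hypothetical stand-in for ParsedStep, and storing `force_layout` in a plain `config` dict is an assumption; only the override-then-default order comes from the description.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class FakeStep:
    """Hypothetical stand-in for ParsedStep; only carries a config dict."""
    config: Dict[str, Any] = field(default_factory=dict)


class LayoutAwareController:
    def get_preferred_layout(self) -> str:
        return "2d"

    def get_effective_layout(self, step_info: Optional[FakeStep] = None) -> str:
        # An explicit force_layout wins; otherwise use the controller default.
        if step_info is not None:
            forced = step_info.config.get("force_layout")
            if forced:
                return forced
        return self.get_preferred_layout()


ctrl = LayoutAwareController()
default_layout = ctrl.get_effective_layout()
forced_layout = ctrl.get_effective_layout(FakeStep({"force_layout": "3d"}))
```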
- get_preferred_layout() str[source]
Get preferred data layout for the framework.
- Returns:
Data layout string (‘2d’ for NumPy arrays, ‘3d’ for TensorFlow, etc.).
Note
Override in subclasses for framework-specific layouts.
- get_xy(dataset: SpectroDataset, context: ExecutionContext) Tuple[Any, Any, Any, Any, Any, Any][source]
Extract train/test splits with scaled and unscaled targets.
For classification tasks, both scaled and unscaled targets are transformed. For regression tasks, scaled targets are used for training while unscaled (numeric) targets are used for evaluation.
In prediction mode, uses all available data (partition=None) instead of splitting.
Also handles sample_partitioner branches, which restrict data to a subset of samples.
- Parameters:
dataset – SpectroDataset with partitioned data.
context – Execution context with partition and preprocessing info.
- Returns:
Tuple of (X_train, y_train, X_test, y_test, y_train_unscaled, y_test_unscaled).
- launch_training(dataset, model_config, context, runtime_context, prediction_store, X_train, y_train, X_val, y_val, X_test, y_train_unscaled, y_val_unscaled, y_test_unscaled, train_indices=None, val_indices=None, fold_idx=None, best_params=None, loaded_binaries=None, mode='train', test_sample_ids=None)[source]
Execute single model training or prediction.
This refactored method uses modular components to handle:
- Model identification and naming
- Model loading for predict/explain modes
- Training execution
- Prediction transformation
- Score calculation
- Prediction data assembly
- Parameters:
dataset – SpectroDataset instance
model_config – Model configuration dictionary
context – Execution context with step_id, y processing, etc.
runtime_context – Runtime context.
prediction_store – Predictions storage instance
X_train – Training features (scaled)
y_train – Training targets (scaled)
X_val – Validation features (scaled)
y_val – Validation targets (scaled)
X_test – Test features (scaled)
y_train_unscaled – True training values (unscaled)
y_val_unscaled – True validation values (unscaled)
y_test_unscaled – True test values (unscaled)
train_indices – Sample indices for the training partition
val_indices – Sample indices for the validation partition
fold_idx – Optional fold index for CV
best_params – Optional hyperparameters from optimization
loaded_binaries – Optional binaries for predict/explain mode
mode – Execution mode (‘train’, ‘finetune’, ‘predict’, ‘explain’)
test_sample_ids – List of actual sample IDs for test partition (for stacking).
- Returns:
Tuple of (trained_model, model_id, val_score, model_name, prediction_data)
- load_model(filepath: str) Any[source]
Optional: Load model from framework-specific format.
Default implementation delegates to artifact_serialization.load(). Subclasses can override to use framework-specific loading.
- Parameters:
filepath – Path to load from.
- Returns:
Loaded model instance.
- process_hyperparameters(params: Dict[str, Any]) Dict[str, Any][source]
Process hyperparameters before use.
Can be overridden by subclasses to structure parameters (e.g. nesting for TensorFlow).
- Parameters:
params – Flat dictionary of sampled parameters.
- Returns:
Processed dictionary of parameters.
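One way a subclass might structure a flat dictionary of sampled parameters is to split keys on a separator and build nested dictionaries. The sketch below assumes a double-underscore separator in the style of scikit-learn parameter names; the function `nest_params` and the example keys are hypothetical and are not taken from the TensorFlow controller.

```python
from typing import Any, Dict


def nest_params(params: Dict[str, Any], sep: str = "__") -> Dict[str, Any]:
    """Turn {'compile__lr': 0.01} into {'compile': {'lr': 0.01}}."""
    nested: Dict[str, Any] = {}
    for key, value in params.items():
        *parents, leaf = key.split(sep)
        node = nested
        for part in parents:
            # Create intermediate dictionaries on demand.
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


flat = {"units": 64, "compile__lr": 0.01, "fit__epochs": 20, "fit__batch_size": 32}
structured = nest_params(flat)
```

Keys without the separator stay at the top level, so a flat sklearn-style dictionary passes through unchanged.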
- save_model(model: Any, filepath: str) None[source]
Optional: Save model in framework-specific format.
Default implementation delegates to artifact_serialization.persist(). Subclasses can override to use framework-specific formats:
- TensorFlow: .h5 or .keras format
- PyTorch: .ckpt or .pt format
- sklearn: .joblib format
- Parameters:
model – Trained model to save.
filepath – Path to save (without extension, will be added by implementation).
- classmethod supports_prediction_mode() bool[source]
Check if the controller should execute during prediction mode.
- Returns:
True if the controller should execute in prediction mode, False if it should be skipped (e.g., chart controllers)
- train(dataset, model_config, context, runtime_context, prediction_store, X_train, y_train, X_test, y_test, y_train_unscaled, y_test_unscaled, folds, best_params=None, loaded_binaries=None, mode='train', train_sample_ids=None, test_sample_ids=None) List[ArtifactMeta][source]
Orchestrate model training across folds with prediction tracking.
- Manages the complete training workflow:
Iterates through cross-validation folds
Delegates to launch_training() for each fold
Creates fold-averaged predictions for regression tasks
Persists trained models as artifacts
Stores all predictions with weights
- Parameters:
dataset – SpectroDataset with features and targets.
model_config – Model configuration dictionary.
context – Execution context with step_id and preprocessing info.
runtime_context – Runtime context.
prediction_store – External Predictions storage.
X_train – Training features (all folds).
y_train – Training targets (scaled).
X_test – Test features.
y_test – Test targets (scaled).
y_train_unscaled – Training targets (unscaled for evaluation).
y_test_unscaled – Test targets (unscaled for evaluation).
folds – List of (train_idx, val_idx) tuples or empty list.
best_params – Optional hyperparameters from finetuning.
loaded_binaries – Optional model binaries for prediction mode.
mode – Execution mode (‘train’, ‘finetune’, ‘predict’, ‘explain’).
train_sample_ids – List of actual sample IDs for train partition.
test_sample_ids – List of actual sample IDs for test partition.
- Returns:
List of ArtifactMeta objects for persisted models.
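Fold averaging, in its simple and weighted forms, can be sketched in a few lines. How nirs4all derives the weights is not stated on this page, so the inverse-RMSE weighting below is an assumption chosen only to make the example concrete, and `average_folds` is a hypothetical helper.

```python
from typing import List, Optional, Sequence


def average_folds(
    fold_preds: Sequence[Sequence[float]],
    weights: Optional[Sequence[float]] = None,
) -> List[float]:
    """Average per-fold predictions made for the same samples."""
    n_folds = len(fold_preds)
    if weights is None:
        weights = [1.0] * n_folds  # simple mean
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    n_samples = len(fold_preds[0])
    return [
        sum(w * preds[i] for w, preds in zip(weights, fold_preds)) / total
        for i in range(n_samples)
    ]


# Two fold models predicting the same three test samples.
preds = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
simple = average_folds(preds)
# Assumed weighting: inverse of each fold's validation RMSE.
val_rmse = [0.5, 1.5]
weighted = average_folds(preds, weights=[1.0 / s for s in val_rmse])
```

With these numbers the weighted average is pulled toward the first fold, whose validation error is lower.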
- class nirs4all.controllers.models.FoldAlignmentValidator(prediction_store: Predictions, config: ReconstructorConfig | None = None)[source]
Bases: object
Validates fold structure consistency across source models.
Ensures that all source models have compatible fold structures for proper out-of-fold reconstruction.
Checks performed:
1. All models have the same number of folds.
2. Fold indices are sequential (0, 1, 2, …, K-1).
3. No sample appears in multiple validation sets within a model.
4. Sample indices are consistent across folds.
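The four checks can be approximated on plain dictionaries, as in the sketch below. The input shape (model name mapped to fold id mapped to validation indices) and the function `check_fold_alignment` are hypothetical. The fourth check is interpreted here as "every model validates each fold on the same samples", which is only one possible reading of the description.

```python
from typing import Dict, List


def check_fold_alignment(folds_by_model: Dict[str, Dict[int, List[int]]]) -> List[str]:
    """folds_by_model maps model name -> {fold_id: validation sample indices}."""
    errors: List[str] = []
    # 1. Same number of folds for every model.
    if len({len(f) for f in folds_by_model.values()}) > 1:
        errors.append("models have different numbers of folds")
    reference = None
    for name, folds in folds_by_model.items():
        # 2. Fold ids are exactly 0..K-1.
        if sorted(folds) != list(range(len(folds))):
            errors.append(f"{name}: fold ids are not sequential")
        # 3. No sample is validated in more than one fold.
        seen = [i for idx in folds.values() for i in idx]
        if len(seen) != len(set(seen)):
            errors.append(f"{name}: a sample appears in several validation sets")
        # 4. Each fold holds out the same samples as in the first model
        #    (assumed meaning of "consistent").
        layout = {k: frozenset(v) for k, v in folds.items()}
        if reference is None:
            reference = layout
        elif layout != reference:
            errors.append(f"{name}: validation samples differ from the first model")
    return errors


good = {"pls": {0: [0, 1], 1: [2, 3]}, "rf": {0: [1, 0], 1: [3, 2]}}
bad = {"pls": {0: [0, 1], 1: [2, 3]}, "rf": {0: [0, 1], 2: [1, 3]}}
good_errors = check_fold_alignment(good)
bad_errors = check_fold_alignment(bad)
```

In `bad`, the second model skips fold id 1, validates sample 1 twice, and holds out different samples, so three messages are reported.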
- prediction_store
Predictions storage for accessing fold data.
- config
Reconstructor configuration.
- validate(source_model_names: List[str], context: ExecutionContext, branch_id_override: int | None = -1) ValidationResult[source]
Validate fold alignment across source models.
- Parameters:
source_model_names – List of source model names to validate.
context – Execution context with branch info.
branch_id_override – Optional branch_id override. If -1 (default), use context’s branch_id. If None, don’t filter by branch (for ALL_BRANCHES scope).
- Returns:
ValidationResult with any errors or warnings.
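The `branch_id_override` argument uses two distinct sentinels: -1 for "take the branch from the context" and None for "do not filter". The sketch below shows that convention on a list of dictionaries; `resolve_branch_filter`, `filter_by_branch`, and the record layout are hypothetical.

```python
from typing import Dict, List, Optional


def resolve_branch_filter(context_branch_id: Optional[int],
                          branch_id_override: Optional[int] = -1) -> Optional[int]:
    """-1 means 'use the context'; None means 'do not filter by branch'."""
    if branch_id_override == -1:
        return context_branch_id
    return branch_id_override


def filter_by_branch(records: List[Dict], branch_id: Optional[int]) -> List[Dict]:
    if branch_id is None:
        return list(records)  # ALL_BRANCHES-style scope: keep everything
    return [r for r in records if r["branch_id"] == branch_id]


records = [{"model": "pls", "branch_id": 0}, {"model": "rf", "branch_id": 1}]
from_context = filter_by_branch(records, resolve_branch_filter(1))
all_branches = filter_by_branch(records, resolve_branch_filter(1, None))
explicit = filter_by_branch(records, resolve_branch_filter(1, 0))
```

Using -1 rather than None as the default keeps None available as a meaningful value.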
- class nirs4all.controllers.models.JaxModelController[source]
Bases: BaseModelController
Controller for JAX/Flax models.
Uses lazy loading pattern - JAX is only imported when training or prediction is actually performed.
- execute(step_info: ParsedStep, dataset: SpectroDataset, context: ExecutionContext, runtime_context: RuntimeContext, source: int = -1, mode: str = 'train', loaded_binaries: List[Tuple[str, bytes]] | None = None, prediction_store: Predictions = None) Tuple[ExecutionContext, List[Tuple[str, bytes]]][source]
Execute JAX model controller.
- get_preferred_layout() str[source]
Return the preferred data layout for JAX models.
Flax Dense layers expect (batch, features). Flax Conv layers expect (batch, length, features) i.e. (N, L, C). So ‘3d_transpose’ is suitable for Conv1D.
- class nirs4all.controllers.models.MetaModelController[source]
Bases: SklearnModelController
Controller for meta-model stacking using pipeline predictions.
This controller handles MetaModel operators, constructing training features from out-of-fold predictions of previous models. It extends SklearnModelController since the meta-learner is always sklearn-compatible.
The key difference from regular model controllers is that get_xy() returns features constructed from predictions rather than the original dataset features.
- Key Behavior:
Works INDEPENDENTLY of branches (no branch awareness required for basic case)
Queries prediction_store for ALL models from previous steps
Does NOT modify execution context (unlike MergeController)
For branch-aware stacking, uses BranchScope configuration
- priority
Controller priority (5). Lower values take precedence, so this ranks ahead of SklearnModelController (6) and ensures MetaModel operators are handled by this controller.
- Type:
- execute(step_info: ParsedStep, dataset: SpectroDataset, context: ExecutionContext, runtime_context: RuntimeContext, source: int = -1, mode: str = 'train', loaded_binaries: List[Tuple[str, bytes]] | None = None, prediction_store: Any | None = None) Tuple[ExecutionContext, List[Tuple[str, bytes]]][source]
Execute meta-model controller.
Stores MetaModel operator and prediction_store in context for use by get_xy(). Also stores source models for artifact persistence in Phase 3.
- Parameters:
step_info – Parsed step with MetaModel operator.
dataset – SpectroDataset.
context – Execution context.
runtime_context – Runtime context.
source – Data source index.
mode – Execution mode.
loaded_binaries – Pre-loaded model binaries.
prediction_store – Predictions store.
- Returns:
Tuple of (updated_context, list_of_binaries).
- get_xy(dataset: SpectroDataset, context: ExecutionContext) Tuple[Any, Any, Any, Any, Any, Any][source]
Extract train/test splits using meta-features from predictions.
Instead of using the original dataset features, this constructs features from out-of-fold predictions of source models.
- For training:
X_train: OOF predictions from source models (n_train_samples, n_source_models)
y_train: Original target values
- For test:
X_test: Aggregated source model test predictions
y_test: Original target values
- Parameters:
dataset – SpectroDataset with partitioned data.
context – Execution context with partition and branch info.
- Returns:
Tuple of (X_train, y_train, X_test, y_test, y_train_unscaled, y_test_unscaled) where X_train and X_test are meta-features from predictions.
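To make the shape of these meta-features concrete, here is a small sketch in plain Python. Each training row is filled by the fold that held that sample out, giving one column per source model. Aggregating test predictions with a plain mean across folds is an assumption, and the helpers `build_oof_matrix` and `build_test_matrix` are hypothetical rather than part of the controller.

```python
from typing import Dict, List, Tuple

Fold = Tuple[List[int], List[float]]  # (validation indices, predictions)


def build_oof_matrix(oof: Dict[str, List[Fold]], n_train: int) -> List[List[float]]:
    """One column per source model; each row filled by the fold that held it out."""
    models = list(oof)
    # NaN marks samples that no fold covered.
    X = [[float("nan")] * len(models) for _ in range(n_train)]
    for col, name in enumerate(models):
        for val_idx, preds in oof[name]:
            for i, p in zip(val_idx, preds):
                X[i][col] = p
    return X


def build_test_matrix(test_preds: Dict[str, List[List[float]]]) -> List[List[float]]:
    """Aggregate each model's per-fold test predictions by a plain mean (assumed)."""
    models = list(test_preds)
    n_test = len(test_preds[models[0]][0])
    return [
        [sum(fold[i] for fold in test_preds[m]) / len(test_preds[m]) for m in models]
        for i in range(n_test)
    ]


oof = {
    "pls": [([0, 1], [1.0, 2.0]), ([2, 3], [3.0, 4.0])],
    "rf": [([0, 1], [1.5, 2.5]), ([2, 3], [3.5, 4.5])],
}
X_train_meta = build_oof_matrix(oof, n_train=4)
X_test_meta = build_test_matrix({"pls": [[5.0], [7.0]], "rf": [[6.0], [8.0]]})
```

Because every training row comes from a model that never saw that sample, the meta-learner is not trained on in-sample predictions.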
- class nirs4all.controllers.models.PyTorchModelController[source]
Bases: BaseModelController
Controller for PyTorch models.
Uses lazy loading pattern - PyTorch is only imported when training or prediction is actually performed.
- execute(step_info: ParsedStep, dataset: SpectroDataset, context: ExecutionContext, runtime_context: RuntimeContext, source: int = -1, mode: str = 'train', loaded_binaries: List[Tuple[str, bytes]] | None = None, prediction_store: Predictions = None) Tuple[ExecutionContext, List[Tuple[str, bytes]]][source]
Execute PyTorch model controller.
- get_preferred_layout() str[source]
Return the preferred data layout for PyTorch models.
PyTorch typically expects (samples, channels, features) for 1D convs. We use ‘3d’ which gives (samples, processings, features) -> (N, C, L).
- class nirs4all.controllers.models.ReconstructionResult(X_train_meta: ndarray, X_test_meta: ndarray, y_train: ndarray, y_test: ndarray, feature_names: List[str], source_models: List[str], valid_train_mask: ndarray, valid_test_mask: ndarray, validation_result: ValidationResult, n_folds: int, coverage_ratio: float, meta_feature_info: Any | None = None, classification_info: Any | None = None)[source]
Bases: object
Container for reconstructed training set data.
Holds the meta-feature matrices for training and test sets, along with metadata about the reconstruction process.
- X_train_meta
Training meta-features (n_train_samples, n_features).
- Type:
ndarray
- X_test_meta
Test meta-features (n_test_samples, n_features).
- Type:
ndarray
- y_train
Training targets (n_train_samples,).
- Type:
ndarray
- y_test
Test targets (n_test_samples,).
- Type:
ndarray
- valid_train_mask
Boolean mask of valid training samples (after coverage handling).
- Type:
ndarray
- valid_test_mask
Boolean mask of valid test samples.
- Type:
ndarray
- validation_result
Validation result from fold alignment.
- meta_feature_info
Optional MetaFeatureInfo for feature importance tracking.
- Type:
Any | None
- classification_info
Optional ClassificationInfo for task type metadata.
- Type:
Any | None
Example
>>> result = reconstructor.reconstruct(dataset, context)
>>> X_train = result.X_train_meta[result.valid_train_mask]
>>> y_train = result.y_train[result.valid_train_mask]
>>> # For feature importance tracking
>>> if result.meta_feature_info:
...     model_importance = result.meta_feature_info.aggregate_importance_by_model(
...         feature_importances
...     )
- validation_result: ValidationResult
- class nirs4all.controllers.models.ReconstructorConfig(validate_fold_alignment: bool = True, validate_sample_coverage: bool = True, log_warnings: bool = True, max_missing_fold_ratio: float = 0.0, allow_partial_sources: bool = False, feature_name_pattern: str = '{model_name}_pred', excluded_fold_ids: ~typing.Set[str] = <factory>)[source]
Bases: object
Configuration for TrainingSetReconstructor.
Controls internal behavior of the reconstruction process, separate from the user-facing StackingConfig.
- feature_name_pattern
Pattern for generating feature column names. Supports: {model_name}, {fold_id}, {classname}, {step_idx}
- Type:
str
Example
>>> config = ReconstructorConfig(
...     validate_fold_alignment=True,
...     log_warnings=True,
...     feature_name_pattern="{model_name}_pred"
... )
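The placeholders supported by feature_name_pattern behave like ordinary `str.format` fields. The sketch below shows how such a pattern could be expanded; the helper `feature_name` and the sample values are hypothetical, and whether the library itself uses `str.format` internally is an assumption.

```python
def feature_name(pattern: str, model_name: str, fold_id: str,
                 classname: str, step_idx: int) -> str:
    """Fill the placeholders; str.format ignores keyword arguments it does not need."""
    return pattern.format(model_name=model_name, fold_id=fold_id,
                          classname=classname, step_idx=step_idx)


default_name = feature_name("{model_name}_pred", "PLS_10", "0", "PLSRegression", 3)
detailed_name = feature_name("{classname}_s{step_idx}_f{fold_id}",
                             "PLS_10", "0", "PLSRegression", 3)
```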
- class nirs4all.controllers.models.SklearnModelController[source]
Bases: BaseModelController
Controller for scikit-learn models.
This controller handles sklearn models with support for training on 2D data, cross-validation, hyperparameter tuning with Optuna, model persistence, and integration with the nirs4all pipeline.
- priority
Controller priority (6) - higher than TransformerMixin to prioritize supervised models over transformers.
- Type:
- execute(step_info: ParsedStep, dataset: SpectroDataset, context: ExecutionContext, runtime_context: RuntimeContext, source: int = -1, mode: str = 'train', loaded_binaries: List[Tuple[str, bytes]] | None = None, prediction_store: Any | None = None) Tuple[ExecutionContext, List[Tuple[str, bytes]]][source]
Execute sklearn model controller with score management.
Main entry point for sklearn model execution in the pipeline. Sets the preferred data layout to ‘2d’ and delegates to parent execute method.
- Parameters:
step_info – Parsed step containing model configuration and operator.
dataset (SpectroDataset) – Dataset containing features and targets.
context (ExecutionContext) – Pipeline execution context with state info.
runtime_context (RuntimeContext) – Runtime context managing execution state.
source (int) – Source index for multi-source pipelines. Defaults to -1.
mode (str) – Execution mode (‘train’ or ‘predict’). Defaults to ‘train’.
loaded_binaries (Optional[List[Tuple[str, bytes]]]) – Pre-loaded model binaries for prediction mode. Defaults to None.
prediction_store (Optional[Any]) – Store for managing predictions. Defaults to None.
- Returns:
Updated context and list of model binaries (name, serialized_model) for persistence.
- Return type:
Tuple[ExecutionContext, List[Tuple[str, bytes]]]
Note
Automatically sets context[‘layout’] = ‘2d’ for sklearn compatibility
Inherits full training, evaluation, and prediction logic from BaseModelController
Respects force_layout if specified in step configuration
- get_preferred_layout() str[source]
Return the preferred data layout for sklearn models.
- Returns:
Data layout preference, always ‘2d’ for sklearn models, which expect input in (n_samples, n_features) format.
- Return type:
str
- classmethod matches(step: Any, operator: Any, keyword: str) bool[source]
Match sklearn estimators and model dictionaries with sklearn models.
Prioritizes supervised models (regressors and classifiers) over transformers by checking for predict methods and using sklearn’s is_regressor/is_classifier.
- Parameters:
step (Any) – Pipeline step to check, can be a dict with ‘model’ key or BaseEstimator instance.
operator (Any) – Optional operator object to check if it’s a BaseEstimator.
keyword (str) – Pipeline keyword (unused in this implementation).
- Returns:
True if the step matches a sklearn estimator (regressor, classifier, or any object with a predict method), False otherwise.
- Return type:
bool
- class nirs4all.controllers.models.TensorFlowModelController[source]
Bases: BaseModelController
Controller for TensorFlow/Keras models.
This controller manages the complete lifecycle of TensorFlow/Keras models, including:
- Model instantiation from various configuration formats
- Data preparation with proper tensor formatting (2D/3D)
- Model compilation with task-appropriate loss functions and metrics
- Training with callbacks (early stopping, model checkpointing)
- Hyperparameter tuning via Optuna integration
- Model evaluation and prediction
- Binary serialization for model persistence
The controller automatically detects TensorFlow models and functions decorated with @framework(‘tensorflow’). It uses lazy loading to avoid importing TensorFlow until actually needed.
- execute(step_info: ParsedStep, dataset: SpectroDataset, context: ExecutionContext, runtime_context: RuntimeContext, source: int = -1, mode: str = 'train', loaded_binaries: List[Tuple[str, bytes]] | None = None, prediction_store: Predictions = None) Tuple[ExecutionContext, List[Tuple[str, bytes]]][source]
Execute TensorFlow model training, finetuning, or prediction.
Sets the preferred data layout to ‘3d_transpose’ for TensorFlow Conv1D models, then delegates to the base class execute method.
- Parameters:
step_info – Parsed step containing model configuration and operator.
dataset – SpectroDataset with features, targets, and fold information.
context – Execution context with step_id, processing history, partition info.
runtime_context – Runtime context managing execution state.
source – Data source index (default: -1 for primary source).
mode – Execution mode - ‘train’, ‘finetune’, ‘predict’, or ‘explain’.
loaded_binaries – Optional list of (name, bytes) tuples for prediction mode, containing serialized model and preprocessing artifacts.
prediction_store – External Predictions storage instance for managing prediction results across pipeline steps.
- Returns:
Tuple of (updated_context, list_of_artifact_metadata), where:
updated_context: Context dict with added model information
artifact_metadata: List of serialized binary artifacts for persistence
- Return type:
Tuple[ExecutionContext, List[Tuple[str, bytes]]]
- Raises:
ImportError – If TensorFlow is not installed.
- get_preferred_layout() str[source]
Return the preferred data layout for TensorFlow models.
TensorFlow Conv1D expects a per-sample input shape of (features, channels), where:
- features = number of wavelengths/spectral points (timesteps for convolution)
- channels = number of preprocessing methods
The ‘3d_transpose’ layout returns (samples, features, processings) which is correct for Conv1D.
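The axis order described above can be sketched with plain nested lists. The starting order (samples, processings, features) is an assumption made for illustration, and to_3d_transpose and shape are hypothetical helpers, not nirs4all functions.

```python
from typing import List, Tuple

Cube = List[List[List[float]]]


def to_3d_transpose(data: Cube) -> Cube:
    """Swap the last two axes.

    Assumed input order:  (samples, processings, features)
    Resulting order:      (samples, features, processings)
    """
    return [[list(column) for column in zip(*sample)] for sample in data]


def shape(data: Cube) -> Tuple[int, int, int]:
    """Return the three axis lengths of a regular nested list."""
    return (len(data), len(data[0]), len(data[0][0]))


# 2 samples, 3 preprocessing views, 4 wavelengths each (made-up values).
raw = [
    [[float(s * 100 + p * 10 + f) for f in range(4)] for p in range(3)]
    for s in range(2)
]
transposed = to_3d_transpose(raw)

print(shape(raw))
print(shape(transposed))
```

After the swap, each sample is a sequence of wavelengths, and each wavelength carries one value per preprocessing method, which is the (steps, channels) arrangement a 1D convolution consumes.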
- classmethod matches(step: Any, operator: Any, keyword: str) bool[source]
Determine if this controller should handle the given step.
Matches TensorFlow/Keras models, functions decorated with @framework(‘tensorflow’), and serialized model configurations containing TensorFlow components.
- Parameters:
step – Pipeline step configuration (dict, model instance, or function).
operator – Optional operator instance extracted from step.
keyword – Optional keyword identifier for the step.
- Returns:
True if this controller should handle the step, False otherwise. Returns False immediately if TensorFlow is not installed.
- process_hyperparameters(params: Dict[str, Any]) Dict[str, Any][source]
Process hyperparameters for TensorFlow model tuning.
Supports TensorFlow-specific parameter organization:
- Parameters prefixed with ‘compile_’ are grouped under ‘compile’ key (e.g., ‘compile_learning_rate’ → compile[‘learning_rate’])
- Parameters prefixed with ‘fit_’ are grouped under ‘fit’ key (e.g., ‘fit_batch_size’ → fit[‘batch_size’])
- Other parameters are treated as model architecture parameters
- Parameters:
params – Dictionary of sampled parameters.
- Returns:
Dictionary of processed hyperparameters with proper nesting for TensorFlow compilation and fitting.
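The prefix-based grouping described for process_hyperparameters can be sketched as follows. This is a minimal standalone illustration, not the controller's code: group_hyperparameters is a hypothetical name, and leaving architecture parameters at the top level of the result is an assumption.

```python
from typing import Any, Dict


def group_hyperparameters(params: Dict[str, Any]) -> Dict[str, Any]:
    """Nest 'compile_*' and 'fit_*' parameters under their own keys.

    Assumption: all other parameters stay at the top level as
    model architecture parameters.
    """
    grouped: Dict[str, Any] = {}
    for name, value in params.items():
        if name.startswith("compile_"):
            grouped.setdefault("compile", {})[name[len("compile_"):]] = value
        elif name.startswith("fit_"):
            grouped.setdefault("fit", {})[name[len("fit_"):]] = value
        else:
            grouped[name] = value
    return grouped


# Flat parameters, as a tuner might sample them (made-up values).
sampled = {
    "compile_learning_rate": 1e-3,
    "fit_batch_size": 32,
    "fit_epochs": 50,
    "filters": 16,
}
processed = group_hyperparameters(sampled)
print(processed)
```

Keeping the sampled names flat and nesting them afterwards lets a single search space cover architecture, compilation, and fitting settings.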
- class nirs4all.controllers.models.TrainingSetReconstructor(prediction_store: Predictions, source_model_names: List[str], stacking_config: StackingConfig | None = None, reconstructor_config: ReconstructorConfig | None = None, source_model_branch_map: Dict[str, int | None] | None = None)[source]
Bases: object
Reconstructs meta-model training set from out-of-fold predictions.
This is the core class for Phase 2 of the meta-model stacking implementation. It handles the critical task of collecting OOF predictions from source models and constructing feature matrices for the meta-learner.
The fundamental invariant is: No sample sees predictions from a model trained on that sample. This prevents data leakage.
- prediction_store
Predictions storage for accessing source predictions.
- source_model_names
List of source model names to use.
- stacking_config
Configuration for coverage and aggregation strategies.
- reconstructor_config
Internal configuration for reconstruction.
- fold_validator
Validator for fold alignment.
Example
>>> reconstructor = TrainingSetReconstructor(
...     prediction_store=predictions,
...     source_model_names=["PLS", "RF", "XGB"],
...     stacking_config=StackingConfig(
...         coverage_strategy=CoverageStrategy.DROP_INCOMPLETE,
...         test_aggregation=TestAggregation.MEAN
...     )
... )
>>> result = reconstructor.reconstruct(dataset, context)
>>> print(f"Coverage: {result.coverage_ratio:.1%}")
>>> print(f"Features: {result.feature_names}")
- reconstruct(dataset: SpectroDataset, context: ExecutionContext, y_train: ndarray | None = None, y_test: ndarray | None = None, use_proba: bool = False) ReconstructionResult[source]
Reconstruct meta-model training and test sets from predictions.
Collects out-of-fold predictions for training samples and aggregated predictions for test samples.
Phase 5 Enhancement: Supports classification tasks with probability features for binary and multiclass classification.
- Parameters:
dataset – SpectroDataset for sample indices.
context – Execution context with partition and branch info.
y_train – Optional pre-computed training targets.
y_test – Optional pre-computed test targets.
use_proba – If True, use probability predictions for classification.
- Returns:
ReconstructionResult containing meta-feature matrices and metadata.
- Raises:
ValueError – If no source models found or critical validation fails.
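The out-of-fold idea behind reconstruct can be illustrated with a toy example in plain Python. It is a sketch of the general technique only: out_of_fold_predictions and mean_model are hypothetical names, and the "model" is a target-mean predictor chosen so the leakage rule is easy to check by hand.

```python
from typing import Callable, List, Optional, Sequence


def out_of_fold_predictions(
    y: Sequence[float],
    folds: Sequence[Sequence[int]],
    fit_predict: Callable[[List[float], int], List[float]],
) -> List[Optional[float]]:
    """Predict each sample with a model that never saw that sample."""
    oof: List[Optional[float]] = [None] * len(y)
    for val_idx in folds:
        held_out = set(val_idx)
        # Train only on samples outside the current validation fold.
        train_targets = [y[i] for i in range(len(y)) if i not in held_out]
        preds = fit_predict(train_targets, len(val_idx))
        for i, pred in zip(val_idx, preds):
            oof[i] = pred
    return oof


def mean_model(train_targets: List[float], n_predict: int) -> List[float]:
    """Toy model: always predicts the mean of its training targets."""
    mean = sum(train_targets) / len(train_targets)
    return [mean] * n_predict


y = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
folds = [[0, 1], [2, 3], [4, 5]]  # validation indices per fold
oof = out_of_fold_predictions(y, folds, mean_model)

# Samples never placed in a validation fold would stay None; a mask like
# this plays a role similar to valid_train_mask in ReconstructionResult.
valid_mask = [pred is not None for pred in oof]
print(oof)
```

Stacking one such column per source model gives a meta-feature matrix in which no entry comes from a model trained on the sample it describes.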
- validate_branch_compatibility(context: ExecutionContext) ValidationResult[source]
Validate branch compatibility for stacking.
Checks that the current branch context is compatible with stacking based on the configured BranchScope.
- Parameters:
context – Execution context with branch info.
- Returns:
ValidationResult with any errors or warnings.
- class nirs4all.controllers.models.ValidationResult(errors: ~typing.List[~nirs4all.controllers.models.stacking.reconstructor.ValidationError] = <factory>, warnings: ~typing.List[~nirs4all.controllers.models.stacking.reconstructor.ValidationWarning] = <factory>)[source]
Bases: object
Container for validation errors and warnings.
Accumulates validation issues during fold alignment and coverage checks.
- errors
List of validation errors (critical issues).
- warnings
List of validation warnings (non-critical issues).
- is_valid
True if no errors (warnings are allowed).
Example
>>> result = ValidationResult()
>>> result.add_error("FOLD_MISMATCH", "Folds don't align")
>>> result.add_warning("PARTIAL_COVERAGE", "80% coverage")
>>> if not result.is_valid:
...     raise ValueError(result.format_errors())
- add_error(code: str, message: str, details: Dict[str, Any] | None = None) None[source]
Add a validation error.
- add_warning(code: str, message: str, details: Dict[str, Any] | None = None) None[source]
Add a validation warning.
- errors: List[ValidationError]
- merge(other: ValidationResult) None[source]
Merge another validation result into this one.
- warnings: List[ValidationWarning]