nirs4all.controllers.shared.model_selector module

ModelSelector - Utility class for selecting models based on validation metrics.

This module provides the model ranking and selection logic used by stacking and branch-merging operations. Models are ranked by a validation metric and then filtered according to one of the selection strategies listed below.

Phase 2 Implementation (Stacking Restoration):

Extracted from MergeController to provide shared model selection logic for both MergeController and MetaModelController.

Selection Strategies:
  • ALL: Use all available models

  • BEST: Use single best model by metric

  • TOP_K: Use top K models by metric

  • EXPLICIT: Use explicitly named models

  • REGEX: Use models matching a pattern (planned, not yet available)

  • THRESHOLD: Use models above/below a metric threshold (planned, not yet available)
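
The four implemented strategies can be sketched as a small dispatch over a list of model names that is already ranked best-first. This is an illustrative sketch only, not the nirs4all implementation: pick_models, its parameters, and the lowercase strategy strings are hypothetical names chosen for the example.

```python
from typing import List, Optional, Sequence


def pick_models(
    ranked: Sequence[str],
    strategy: str,
    k: Optional[int] = None,
    names: Optional[Sequence[str]] = None,
) -> List[str]:
    """Hypothetical helper: apply a selection strategy to models ranked best-first."""
    if strategy == "all":
        return list(ranked)
    if strategy == "best":
        return list(ranked[:1])
    if strategy == "top_k":
        if k is None or k < 1:
            raise ValueError("top_k requires a positive k")
        return list(ranked[:k])
    if strategy == "explicit":
        wanted = list(names or [])
        # Reject names that are not among the available models.
        unknown = [n for n in wanted if n not in ranked]
        if unknown:
            raise ValueError(f"Unknown models: {unknown}")
        return wanted
    raise ValueError(f"Unsupported strategy: {strategy!r}")


ranked = ["PLS", "XGB", "RF"]  # assumed ranking, best first
print(pick_models(ranked, "best"))        # ['PLS']
print(pick_models(ranked, "top_k", k=2))  # ['PLS', 'XGB']
```

Keeping ranking separate from selection means ALL and EXPLICIT need no metric at all, while BEST and TOP_K reduce to slicing the ranked list.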

Example

>>> from nirs4all.controllers.shared import ModelSelector
>>> from nirs4all.operators.data.merge import BranchPredictionConfig
>>>
>>> selector = ModelSelector(prediction_store, context)
>>> config = BranchPredictionConfig(branch=0, select="best", metric="rmse")
>>> selected = selector.select_models(["PLS", "RF", "XGB"], config, branch_id=0)
>>> print(selected)  # ['PLS'], assuming PLS has the best (lowest) RMSE
class nirs4all.controllers.shared.model_selector.ModelSelector(prediction_store: Predictions, context: ExecutionContext)

Bases: object

Utility class for selecting models based on validation metrics.

Handles model ranking and selection strategies (all, best, top_k, explicit) for per-branch prediction collection and stacking operations.

This class is shared between MergeController and MetaModelController to avoid code duplication.

prediction_store

Prediction storage instance.

context

Execution context.

LOWER_IS_BETTER_METRICS

Set of metrics where lower values are better.

LOWER_IS_BETTER_METRICS = {'log_loss', 'mae', 'mape', 'mse', 'nmae', 'nmse', 'nrmse', 'rmse'}
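
As a rough illustration of how such a set can drive ranking, the sketch below sorts models ascending for metrics in the set and descending otherwise. rank_models is a hypothetical helper, the scores are made-up values, and treating every metric outside the set as higher-is-better is an assumption of this example rather than documented behaviour.

```python
from typing import Dict, List

# Copied from the attribute value documented above.
LOWER_IS_BETTER_METRICS = {"log_loss", "mae", "mape", "mse", "nmae", "nmse", "nrmse", "rmse"}


def rank_models(scores: Dict[str, float], metric: str) -> List[str]:
    """Hypothetical helper: return model names ordered best-first for `metric`."""
    lower_is_better = metric in LOWER_IS_BETTER_METRICS
    # Ascending order for error-like metrics, descending for score-like ones.
    return sorted(scores, key=scores.get, reverse=not lower_is_better)


rmse = {"PLS": 0.42, "RF": 0.55, "XGB": 0.47}  # made-up scores
print(rank_models(rmse, "rmse"))  # ['PLS', 'XGB', 'RF']

r2 = {"PLS": 0.91, "RF": 0.84, "XGB": 0.93}  # made-up scores
print(rank_models(r2, "r2"))  # ['XGB', 'PLS', 'RF']
```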
get_model_scores(model_names: List[str], metric: str, branch_id: int) → Dict[str, float]

Get validation scores for multiple models.

Used for weighted aggregation.

Parameters:
  • model_names – List of model names.

  • metric – Metric name.

  • branch_id – Branch identifier.

Returns:

Dictionary mapping model name to score.
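
One way a name-to-score dictionary like this could feed weighted aggregation is to normalise the scores into weights. The sketch below is an assumption about downstream use, not the library's method: scores_to_weights is a hypothetical helper, inverse-score weighting is just one common choice for error metrics, and all scores are assumed to be strictly positive.

```python
from typing import Dict


def scores_to_weights(scores: Dict[str, float], lower_is_better: bool) -> Dict[str, float]:
    """Hypothetical helper: turn per-model scores into weights that sum to 1.

    Assumes every score is strictly positive.
    """
    if lower_is_better:
        # Inverse-score weighting: a smaller error yields a larger weight.
        raw = {name: 1.0 / score for name, score in scores.items()}
    else:
        raw = dict(scores)
    total = sum(raw.values())
    return {name: value / total for name, value in raw.items()}


# Made-up RMSE values; PLS has half the error of RF, so it gets twice the weight.
weights = scores_to_weights({"PLS": 0.5, "RF": 1.0}, lower_is_better=True)
print(weights)
```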

select_models(available_models: List[str], config: BranchPredictionConfig, branch_id: int) → List[str]

Select a subset of the available models according to the selection settings in config.

Parameters:
  • available_models – List of available model names in the branch.

  • config – Per-branch prediction configuration.

  • branch_id – Branch identifier.

Returns:

List of selected model names.

Raises:

ValueError – If explicit model selection references unknown models.

select_models_global(available_models: List[str], selection: Any, metric: str | None = None) → List[str]

Select models globally (without branch context).

This is used by MetaModelController for pipelines without branches.

Parameters:
  • available_models – List of available model names.

  • selection – Selection configuration. One of:
      - "all": Use all models
      - "best": Use best model
      - {"top_k": N}: Use top N models
      - ["model1", "model2"]: Explicit list

  • metric – Optional metric for ranking.

Returns:

List of selected model names.
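
Because selection accepts several shapes (a string, a dict, or a list), a caller may find it useful to normalise it first. The following is a minimal sketch under that assumption: parse_selection and its (strategy, argument) return convention are hypothetical and are not part of the documented API.

```python
from typing import Any, Optional, Tuple


def parse_selection(selection: Any) -> Tuple[str, Optional[Any]]:
    """Hypothetical helper: map a `selection` value to (strategy, argument)."""
    if isinstance(selection, str):
        if selection in ("all", "best"):
            return selection, None
        raise ValueError(f"Unknown selection string: {selection!r}")
    if isinstance(selection, dict) and "top_k" in selection:
        return "top_k", int(selection["top_k"])
    if isinstance(selection, (list, tuple)):
        return "explicit", list(selection)
    raise ValueError(f"Unsupported selection: {selection!r}")


print(parse_selection("best"))         # ('best', None)
print(parse_selection({"top_k": 3}))   # ('top_k', 3)
print(parse_selection(["PLS", "RF"]))  # ('explicit', ['PLS', 'RF'])
```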