nirs4all.controllers.shared.model_selector module
ModelSelector - Utility class for selecting models based on validation metrics.
This module provides model ranking and selection strategies for stacking and branch-merging operations, choosing models according to their validation metrics.
- Phase 2 Implementation (Stacking Restoration):
Extracted from MergeController to provide shared model selection logic for both MergeController and MetaModelController.
- Selection Strategies:
ALL: Use all available models
BEST: Use single best model by metric
TOP_K: Use top K models by metric
EXPLICIT: Use explicitly named models
REGEX: Use models matching pattern (future)
THRESHOLD: Use models above/below metric threshold (future)
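The strategy list above mixes available and planned options. As a minimal sketch, the strategies could be modelled as an enum with a guard that rejects the two marked "(future)". The names `SelectionStrategy`, `IMPLEMENTED` and `require_implemented` are hypothetical illustrations, not part of the nirs4all API.

```python
from enum import Enum


class SelectionStrategy(Enum):
    """Hypothetical mirror of the documented strategies (not the library's own type)."""

    ALL = "all"
    BEST = "best"
    TOP_K = "top_k"
    EXPLICIT = "explicit"
    REGEX = "regex"          # documented as future work
    THRESHOLD = "threshold"  # documented as future work


# Strategies the module describes as available today.
IMPLEMENTED = {
    SelectionStrategy.ALL,
    SelectionStrategy.BEST,
    SelectionStrategy.TOP_K,
    SelectionStrategy.EXPLICIT,
}


def require_implemented(strategy: SelectionStrategy) -> SelectionStrategy:
    """Return the strategy unchanged, or raise if it is only planned."""
    if strategy not in IMPLEMENTED:
        raise NotImplementedError(f"{strategy.value!r} selection is not available yet")
    return strategy
```

For example, `require_implemented(SelectionStrategy("best"))` returns `SelectionStrategy.BEST`, while passing `SelectionStrategy.REGEX` raises `NotImplementedError`.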
Example
>>> from nirs4all.controllers.shared import ModelSelector
>>> from nirs4all.operators.data.merge import BranchPredictionConfig
>>>
>>> selector = ModelSelector(prediction_store, context)
>>> config = BranchPredictionConfig(branch=0, select="best", metric="rmse")
>>> selected = selector.select_models(["PLS", "RF", "XGB"], config, branch_id=0)
>>> print(selected) # ["PLS"] (assuming PLS has best RMSE)
- class nirs4all.controllers.shared.model_selector.ModelSelector(prediction_store: Predictions, context: ExecutionContext)
Bases: object
Utility class for selecting models based on validation metrics.
Handles model ranking and selection strategies (all, best, top_k, explicit) for per-branch prediction collection and stacking operations.
This class is shared between MergeController and MetaModelController to avoid code duplication.
- prediction_store
Prediction storage instance.
- context
Execution context.
- LOWER_IS_BETTER_METRICS
Set of metrics where lower values are better.
- LOWER_IS_BETTER_METRICS = {'log_loss', 'mae', 'mape', 'mse', 'nmae', 'nmse', 'nrmse', 'rmse'}
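`LOWER_IS_BETTER_METRICS` determines the sort direction when models are ranked. The sketch below shows one way to rank by that rule, assuming (not confirmed by this page) that any metric outside the set is treated as higher-is-better. `rank_models` is a hypothetical helper, not a `ModelSelector` method.

```python
from typing import Dict, List

# Copied from ModelSelector.LOWER_IS_BETTER_METRICS.
LOWER_IS_BETTER_METRICS = {"log_loss", "mae", "mape", "mse", "nmae", "nmse", "nrmse", "rmse"}


def rank_models(scores: Dict[str, float], metric: str) -> List[str]:
    """Return model names ordered best-first for the given metric.

    Error metrics in LOWER_IS_BETTER_METRICS sort ascending; any other
    metric is assumed higher-is-better and sorts descending.
    """
    lower_is_better = metric in LOWER_IS_BETTER_METRICS
    return sorted(scores, key=scores.get, reverse=not lower_is_better)
```

With `scores = {"PLS": 0.12, "RF": 0.15, "XGB": 0.10}`, `rank_models(scores, "rmse")` returns `["XGB", "PLS", "RF"]` and `rank_models(scores, "r2")` returns `["RF", "PLS", "XGB"]`.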
- get_model_scores(model_names: List[str], metric: str, branch_id: int) → Dict[str, float]
Get validation scores for multiple models.
Used for weighted aggregation.
- Parameters:
model_names – List of model names.
metric – Metric name.
branch_id – Branch identifier.
- Returns:
Dictionary mapping model name to score.
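The returned score dictionary is described as input for weighted aggregation, but this page does not specify the weighting rule. As a hedged sketch of one common choice, inverse-score weighting for error metrics and proportional weighting otherwise, the hypothetical `scores_to_weights` below turns such a dictionary into weights that sum to 1. Treat it as an illustration, not the library's aggregation logic.

```python
from typing import Dict

# Copied from ModelSelector.LOWER_IS_BETTER_METRICS.
LOWER_IS_BETTER_METRICS = {"log_loss", "mae", "mape", "mse", "nmae", "nmse", "nrmse", "rmse"}


def scores_to_weights(scores: Dict[str, float], metric: str, eps: float = 1e-12) -> Dict[str, float]:
    """Convert {model: score} into {model: weight} with weights summing to 1.

    Hypothetical scheme: error metrics use 1 / score (smaller error, larger
    weight); other metrics use the score itself, with negative values
    (e.g. a negative r2) clamped to zero.
    """
    if not scores:
        return {}
    if metric in LOWER_IS_BETTER_METRICS:
        raw = {name: 1.0 / (score + eps) for name, score in scores.items()}
    else:
        raw = {name: max(score, 0.0) for name, score in scores.items()}
    total = sum(raw.values())
    if total == 0:
        # No model has a positive score: fall back to uniform weights.
        return {name: 1.0 / len(scores) for name in scores}
    return {name: value / total for name, value in raw.items()}
```

For `{"PLS": 0.1, "RF": 0.2}` under `"rmse"`, PLS receives roughly two thirds of the weight because its error is half that of RF.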
- select_models(available_models: List[str], config: BranchPredictionConfig, branch_id: int) → List[str]
Select models from the available models according to the per-branch configuration.
- Parameters:
available_models – List of available model names in the branch.
config – Per-branch prediction configuration.
branch_id – Branch identifier.
- Returns:
List of selected model names.
- Raises:
ValueError – If explicit model selection references unknown models.
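The per-branch selection flow can be sketched as follows, under stated assumptions. Scores are passed in directly instead of being read from a prediction store, and the accepted `select` shapes are assumed to mirror those documented for `select_models_global`. `BranchSelectionSketch` and `select_models_sketch` are hypothetical stand-ins for `BranchPredictionConfig` and `select_models`. The explicit-list check reproduces the documented `ValueError` for unknown models.

```python
from dataclasses import dataclass
from typing import Dict, List, Union

# Copied from ModelSelector.LOWER_IS_BETTER_METRICS.
LOWER_IS_BETTER_METRICS = {"log_loss", "mae", "mape", "mse", "nmae", "nmse", "nrmse", "rmse"}


@dataclass
class BranchSelectionSketch:
    """Hypothetical stand-in for the selection fields of BranchPredictionConfig."""

    select: Union[str, Dict[str, int], List[str]] = "all"
    metric: str = "rmse"


def select_models_sketch(
    available: List[str], scores: Dict[str, float], config: BranchSelectionSketch
) -> List[str]:
    if isinstance(config.select, list):
        # Explicit selection: every requested name must exist in the branch.
        unknown = [name for name in config.select if name not in available]
        if unknown:
            raise ValueError(f"Unknown models in explicit selection: {unknown}")
        return list(config.select)
    if config.select == "all":
        return list(available)
    # Rank only models that are available and have a validation score.
    ranked = sorted(
        (name for name in available if name in scores),
        key=scores.get,
        reverse=config.metric not in LOWER_IS_BETTER_METRICS,
    )
    if config.select == "best":
        return ranked[:1]
    if isinstance(config.select, dict) and "top_k" in config.select:
        return ranked[: config.select["top_k"]]
    raise ValueError(f"Unsupported selection: {config.select!r}")
```

With RMSE scores `{"PLS": 0.10, "RF": 0.15, "XGB": 0.12}`, `select="best"` yields `["PLS"]`, matching the module example, and `{"top_k": 2}` yields `["PLS", "XGB"]`.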
- select_models_global(available_models: List[str], selection: Any, metric: str | None = None) → List[str]
Select models globally (without branch context).
This is used by MetaModelController for pipelines without branches.
- Parameters:
available_models – List of available model names.
selection – Selection configuration, one of:
  - "all": use all models
  - "best": use the best model
  - {"top_k": N}: use the top N models
  - ["model1", "model2"]: use an explicit list of models
metric – Optional metric for ranking.
- Returns:
List of selected model names.
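Because `selection` is typed `Any`, it helps to see how the four documented forms could be normalized before any ranking happens. The hypothetical `parse_selection` below maps each form to a `(strategy, argument)` pair and rejects anything else. The positive-integer check on `top_k` is an added assumption, not documented behaviour.

```python
from typing import Any, Optional, Tuple


def parse_selection(selection: Any) -> Tuple[str, Optional[Any]]:
    """Normalize a documented selection form to (strategy, argument)."""
    if selection == "all":
        return "all", None
    if selection == "best":
        return "best", None
    if isinstance(selection, dict) and set(selection) == {"top_k"}:
        k = selection["top_k"]
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(k, bool) or not isinstance(k, int) or k < 1:
            raise ValueError(f"top_k must be a positive integer, got {k!r}")
        return "top_k", k
    if isinstance(selection, (list, tuple)) and all(isinstance(n, str) for n in selection):
        return "explicit", list(selection)
    raise ValueError(f"Unsupported selection: {selection!r}")
```

For instance, `parse_selection({"top_k": 3})` returns `("top_k", 3)` and `parse_selection(["PLS", "RF"])` returns `("explicit", ["PLS", "RF"])`.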