# Source code for nirs4all.operators.models.selection

"""Source model selection strategies for meta-model stacking.

This module provides flexible strategies for selecting which models
to use as sources in a stacking ensemble.

Available selectors:
    - AllPreviousModelsSelector: Use all models from previous steps (default)
    - ExplicitModelSelector: Use explicitly named models
    - TopKByMetricSelector: Use top K models by validation metric
    - DiversitySelector: Select diverse models by class type

Example:
    >>> from nirs4all.operators.models.selection import (
    ...     AllPreviousModelsSelector,
    ...     TopKByMetricSelector,
    ...     SelectorFactory
    ... )
    >>>
    >>> # Default: all previous models
    >>> selector = AllPreviousModelsSelector()
    >>>
    >>> # Top 3 by RMSE
    >>> selector = TopKByMetricSelector(k=3, metric="rmse")
    >>>
    >>> # Using factory
    >>> selector = SelectorFactory.create("top_k", k=5, metric="r2")
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    from nirs4all.data.predictions import Predictions
    from nirs4all.pipeline.config.context import ExecutionContext


@dataclass
class ModelCandidate:
    """Descriptor of one model that may feed a stacking meta-model.

    Groups the identifying metadata of a trained model together with an
    optional validation score and, optionally, its prediction arrays.

    Attributes:
        model_name: Name of the model.
        model_classname: Class name of the model (e.g., "PLSRegression").
        step_idx: Index of the pipeline step that trained the model.
        fold_id: Fold identifier; "avg" / "w_avg" denote averaged models.
        branch_id: Identifier of the pipeline branch, if any.
        branch_name: Human-readable name of that branch.
        val_score: Validation score of the model.
        metric: Name of the metric that produced ``val_score``.
        predictions: Optional mapping of names to prediction arrays.
    """

    # --- identity (required) ---
    model_name: str
    model_classname: str
    step_idx: int

    # --- location in the pipeline (optional) ---
    fold_id: Optional[str] = None
    branch_id: Optional[int] = None
    branch_name: Optional[str] = None

    # --- scoring and payload (optional) ---
    val_score: Optional[float] = None
    metric: Optional[str] = None
    predictions: Optional[Dict[str, np.ndarray]] = None
class SourceModelSelector(ABC):
    """Interface every source-model selection strategy implements.

    A selector decides which candidate models become the sources
    (meta-features) of a stacking ensemble. Concrete strategies provide
    ``select()``; ``validate()`` offers a default sanity check that
    subclasses may override.

    Example:
        >>> class CustomSelector(SourceModelSelector):
        ...     def select(self, candidates, context, prediction_store):
        ...         # Custom selection logic
        ...         return [c for c in candidates if c.val_score > 0.9]
    """

    @abstractmethod
    def select(
        self,
        candidates: List[ModelCandidate],
        context: 'ExecutionContext',
        prediction_store: 'Predictions'
    ) -> List[ModelCandidate]:
        """Pick the source models out of ``candidates``.

        Args:
            candidates: Candidate models to choose from.
            context: Execution context carrying current step and branch info.
            prediction_store: Predictions store for accessing model data.

        Returns:
            The chosen ModelCandidate objects, in the order they should be
            used as features (this fixes the meta-feature column order).
        """
        ...

    def validate(
        self,
        selected: List[ModelCandidate],
        context: 'ExecutionContext'
    ) -> None:
        """Check a selection; the default only rejects an empty one.

        Args:
            selected: Model candidates returned by ``select()``.
            context: Execution context (unused by the default check).

        Raises:
            ValueError: If ``selected`` is empty.
        """
        if selected:
            return
        raise ValueError(
            "No source models selected for stacking. "
            "Ensure previous models exist in the pipeline."
        )
class AllPreviousModelsSelector(SourceModelSelector):
    """Default strategy: keep every model trained before the current step.

    Within a branch, models of the same branch and models without a branch
    id (trained before the split) are both kept.

    Attributes:
        include_averaged: If True, fold-averaged entries ('avg', 'w_avg')
            are kept. Default False (individual fold models only).
        exclude_classnames: Model class names that must never be selected.

    Example:
        >>> selector = AllPreviousModelsSelector(include_averaged=True)
    """

    def __init__(
        self,
        include_averaged: bool = False,
        exclude_classnames: Optional[Set[str]] = None
    ):
        """Store the filtering options.

        Args:
            include_averaged: If True, include 'avg' and 'w_avg' fold entries.
                Default False to use individual fold predictions for OOF.
            exclude_classnames: Optional set of model class names to drop
                from the selection.
        """
        self.exclude_classnames = exclude_classnames if exclude_classnames else set()
        self.include_averaged = include_averaged

    def select(
        self,
        candidates: List[ModelCandidate],
        context: 'ExecutionContext',
        prediction_store: 'Predictions'
    ) -> List[ModelCandidate]:
        """Return all eligible earlier models, ordered for stable features.

        Args:
            candidates: Candidate models.
            context: Execution context with step and branch info.
            prediction_store: Predictions store (unused in this selector).

        Returns:
            Eligible candidates sorted by (step index, model name, fold id).
        """
        step_limit = context.state.step_number
        active_branch = getattr(context.selector, 'branch_id', None)
        averaged_folds = ('avg', 'w_avg')

        def _eligible(cand: ModelCandidate) -> bool:
            # Only models trained strictly before the current step qualify.
            if cand.step_idx >= step_limit:
                return False
            # Inside a branch, reject models that belong to another branch;
            # models without a branch id are kept.
            if (
                active_branch is not None
                and cand.branch_id is not None
                and cand.branch_id != active_branch
            ):
                return False
            # Averaged entries are opt-in.
            if not self.include_averaged and cand.fold_id in averaged_folds:
                return False
            return cand.model_classname not in self.exclude_classnames

        return sorted(
            (cand for cand in candidates if _eligible(cand)),
            key=lambda cand: (cand.step_idx, cand.model_name, cand.fold_id or ''),
        )
class ExplicitModelSelector(SourceModelSelector):
    """Select explicitly named models.

    Uses a predefined list of model names to select sources. Model names
    must match exactly (case-sensitive). Every candidate carrying a
    requested name is returned (all folds, including averaged entries).

    Attributes:
        model_names: List of model names to select.
        strict: If True, raise error if any named model is not found.

    Example:
        >>> selector = ExplicitModelSelector(
        ...     model_names=["PLS", "RandomForest", "XGBoost"],
        ...     strict=True
        ... )
    """

    def __init__(
        self,
        model_names: List[str],
        strict: bool = True
    ):
        """Initialize selector with explicit model names.

        Args:
            model_names: List of model names to select. A single string is
                accepted and treated as a one-element list.
            strict: If True, raise ValueError if any model is not found
                among candidates. Default True.

        Raises:
            ValueError: If ``model_names`` is empty.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(model_names, str):
            model_names = [model_names]
        if not model_names:
            raise ValueError("model_names cannot be empty")
        # Copy so later mutation of the caller's list cannot alter the selector.
        self.model_names = list(model_names)
        self.strict = strict

    def select(
        self,
        candidates: List[ModelCandidate],
        context: 'ExecutionContext',
        prediction_store: 'Predictions'
    ) -> List[ModelCandidate]:
        """Select models matching the specified names.

        Args:
            candidates: List of candidate models.
            context: Execution context.
            prediction_store: Predictions store (unused).

        Returns:
            List of candidates matching specified names, in the order
            specified by model_names (each name contributes its candidates
            once, sorted by fold id).

        Raises:
            ValueError: If strict=True and any model name is not found.
        """
        current_step = context.state.step_number
        current_branch_id = getattr(context.selector, 'branch_id', None)

        # Build a map of model_name -> eligible candidates
        name_to_candidates: Dict[str, List[ModelCandidate]] = {}
        for candidate in candidates:
            # Skip models from current or later steps
            if candidate.step_idx >= current_step:
                continue
            # In a branch, reject only models owned by a *different* branch.
            # Models without a branch id stay eligible, matching
            # AllPreviousModelsSelector; otherwise strict mode reports them
            # as missing although they exist.
            if (
                current_branch_id is not None
                and candidate.branch_id is not None
                and candidate.branch_id != current_branch_id
            ):
                continue
            name_to_candidates.setdefault(candidate.model_name, []).append(candidate)

        selected: List[ModelCandidate] = []
        missing: List[str] = []
        # dict.fromkeys de-duplicates while keeping first-seen order, so a
        # repeated name does not yield duplicated meta-feature columns.
        for name in dict.fromkeys(self.model_names):
            matches = name_to_candidates.get(name)
            if matches:
                # All folds of this model, in a deterministic fold order
                selected.extend(sorted(matches, key=lambda c: c.fold_id or ''))
            else:
                missing.append(name)

        if self.strict and missing:
            raise ValueError(
                f"Source models not found in prediction store: {missing}. "
                f"Available models: {list(name_to_candidates.keys())}"
            )

        return selected
class TopKByMetricSelector(SourceModelSelector):
    """Select top K models by a validation metric.

    Ranks models by their stored validation score (``val_score``) and
    selects the top K performers. The ``metric`` name is only used to infer
    the sort direction when ``ascending`` is not given; the ranking itself
    always uses ``val_score``.

    Attributes:
        k: Number of top models to select.
        metric: Metric to rank by (e.g., "rmse", "r2", "accuracy").
        ascending: Sort direction. If None, inferred from metric.
        per_class: If True, select top K per model class (for diversity).

    Example:
        >>> selector = TopKByMetricSelector(k=3, metric="rmse", ascending=True)
    """

    # Metrics where lower is better
    _LOWER_IS_BETTER = {'rmse', 'mse', 'mae', 'mape', 'loss'}

    def __init__(
        self,
        k: int,
        metric: str = "val_score",
        ascending: Optional[bool] = None,
        per_class: bool = False
    ):
        """Initialize top-K selector.

        Args:
            k: Number of top models to select.
            metric: Metric name to rank by. If "val_score", uses the stored
                validation score.
            ascending: Sort order. If True, lower scores rank higher.
                If None, inferred from metric name.
            per_class: If True, select top K per unique model class name.
                This ensures diversity in the ensemble.

        Raises:
            ValueError: If ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError(f"k must be at least 1, got {k}")

        self.k = k
        self.metric = metric
        self.ascending = ascending
        self.per_class = per_class

    def _infer_ascending(self, metric: str) -> bool:
        """Infer sort direction from metric name.

        Args:
            metric: Metric name.

        Returns:
            True if lower is better, False if higher is better. A metric
            counts as lower-is-better when its lowercased name contains one
            of the entries of ``_LOWER_IS_BETTER``.
        """
        metric_lower = metric.lower()
        return any(m in metric_lower for m in self._LOWER_IS_BETTER)

    def _ranking_metric(self, candidates: List[ModelCandidate]) -> str:
        """Return the metric name used to infer the sort direction.

        With the generic default "val_score" the configured name says
        nothing about direction, so the metric recorded on the candidates
        (first non-empty one) is used when available.
        """
        if self.metric != "val_score":
            return self.metric
        for candidate in candidates:
            if isinstance(candidate.metric, str) and candidate.metric:
                return candidate.metric
        return self.metric

    def select(
        self,
        candidates: List[ModelCandidate],
        context: 'ExecutionContext',
        prediction_store: 'Predictions'
    ) -> List[ModelCandidate]:
        """Select top K models by metric.

        Args:
            candidates: List of candidate models.
            context: Execution context.
            prediction_store: Predictions store (unused).

        Returns:
            Top K models sorted by metric, best first. With
            ``per_class=True`` up to K models per class are returned.
        """
        current_step = context.state.step_number
        current_branch_id = getattr(context.selector, 'branch_id', None)

        # Filter candidates
        valid_candidates: List[ModelCandidate] = []
        for candidate in candidates:
            if candidate.step_idx >= current_step:
                continue
            # In a branch, reject only models owned by a different branch;
            # models without a branch id stay eligible, matching
            # AllPreviousModelsSelector.
            if (
                current_branch_id is not None
                and candidate.branch_id is not None
                and candidate.branch_id != current_branch_id
            ):
                continue
            # Skip averaged models for individual fold handling
            if candidate.fold_id in ('avg', 'w_avg'):
                continue
            # Unscored candidates cannot be ranked
            if candidate.val_score is not None:
                valid_candidates.append(candidate)

        # Determine sort direction
        ascending = self.ascending
        if ascending is None:
            ascending = self._infer_ascending(self._ranking_metric(valid_candidates))
        descending = not ascending

        # Every remaining candidate has a score, so sort on it directly.
        # (`val_score or inf` would misplace a legitimate score of 0.0.)
        def score_of(candidate: ModelCandidate) -> float:
            return candidate.val_score

        if not self.per_class:
            valid_candidates.sort(key=score_of, reverse=descending)
            return valid_candidates[:self.k]

        # Select top K per class
        class_groups: Dict[str, List[ModelCandidate]] = {}
        for candidate in valid_candidates:
            class_groups.setdefault(candidate.model_classname, []).append(candidate)

        selected: List[ModelCandidate] = []
        for group in class_groups.values():
            group.sort(key=score_of, reverse=descending)
            selected.extend(group[:self.k])

        # Final sort by score
        selected.sort(key=score_of, reverse=descending)
        return selected
class DiversitySelector(SourceModelSelector):
    """Select diverse models by class type to maximize ensemble diversity.

    Ensures the stacking ensemble includes different types of models
    rather than multiple similar models.

    Attributes:
        max_per_class: Maximum models per class type.
        preferred_classes: Optional list of preferred class names.

    Example:
        >>> selector = DiversitySelector(
        ...     max_per_class=2,
        ...     preferred_classes=["PLSRegression", "RandomForestRegressor"]
        ... )
    """

    def __init__(
        self,
        max_per_class: int = 1,
        preferred_classes: Optional[List[str]] = None
    ):
        """Initialize diversity selector.

        Args:
            max_per_class: Maximum number of models per class type.
            preferred_classes: Optional list of preferred model class names.
                These are selected first if available.

        Raises:
            ValueError: If ``max_per_class`` is smaller than 1.
        """
        if max_per_class < 1:
            raise ValueError(f"max_per_class must be at least 1, got {max_per_class}")

        self.max_per_class = max_per_class
        self.preferred_classes = preferred_classes or []

    def select(
        self,
        candidates: List[ModelCandidate],
        context: 'ExecutionContext',
        prediction_store: 'Predictions'
    ) -> List[ModelCandidate]:
        """Select diverse models by class type.

        Args:
            candidates: List of candidate models.
            context: Execution context.
            prediction_store: Predictions store (unused).

        Returns:
            List of diverse models with at most max_per_class per type,
            ordered by (step index, model name).
        """
        current_step = context.state.step_number
        current_branch_id = getattr(context.selector, 'branch_id', None)

        # Filter and group by class
        class_groups: Dict[str, List[ModelCandidate]] = {}
        for candidate in candidates:
            if candidate.step_idx >= current_step:
                continue
            # In a branch, reject only models owned by a different branch;
            # models without a branch id stay eligible, matching
            # AllPreviousModelsSelector.
            if (
                current_branch_id is not None
                and candidate.branch_id is not None
                and candidate.branch_id != current_branch_id
            ):
                continue
            if candidate.fold_id in ('avg', 'w_avg'):
                continue
            class_groups.setdefault(candidate.model_classname, []).append(candidate)

        def score_key(candidate: ModelCandidate) -> float:
            # Only a missing score sorts last; `val_score or inf` would
            # also push a legitimate score of 0.0 to the end.
            if candidate.val_score is None:
                return float('inf')
            return candidate.val_score

        def best_of(group: List[ModelCandidate]) -> List[ModelCandidate]:
            # NOTE(review): ascending order treats the lowest score as best,
            # which fits error metrics only - confirm for higher-is-better
            # metrics such as r2 or accuracy.
            return sorted(group, key=score_key)[:self.max_per_class]

        selected: List[ModelCandidate] = []

        # First, add preferred classes (pop so they are not visited again)
        for classname in self.preferred_classes:
            if classname in class_groups:
                selected.extend(best_of(class_groups.pop(classname)))

        # Then add remaining classes, in class-name order
        for _classname, group in sorted(class_groups.items()):
            selected.extend(best_of(group))

        # Sort final selection by step index for consistent feature ordering
        selected.sort(key=lambda c: (c.step_idx, c.model_name))

        return selected
class SelectorFactory:
    """Name-based factory for source model selectors.

    Maps short, case-insensitive type names to selector classes and
    instantiates them on request. Additional selector classes can be added
    with ``register()``.

    Example:
        >>> selector = SelectorFactory.create("all")
        >>> selector = SelectorFactory.create("explicit", model_names=["PLS", "RF"])
        >>> selector = SelectorFactory.create("top_k", k=5, metric="rmse")
    """

    # Lowercase type name -> selector class; several aliases share a class.
    _REGISTRY: Dict[str, type] = {
        'all': AllPreviousModelsSelector,
        'all_previous': AllPreviousModelsSelector,
        'explicit': ExplicitModelSelector,
        'top_k': TopKByMetricSelector,
        'topk': TopKByMetricSelector,
        'diversity': DiversitySelector,
        'diverse': DiversitySelector,
    }

    @classmethod
    def create(cls, selector_type: str, **kwargs) -> SourceModelSelector:
        """Instantiate the selector registered under ``selector_type``.

        Args:
            selector_type: Type name (e.g., "all", "explicit", "top_k",
                "diversity"); matched case-insensitively.
            **kwargs: Forwarded unchanged to the selector constructor.

        Returns:
            SourceModelSelector instance.

        Raises:
            ValueError: If selector_type is not recognized.
        """
        selector_class = cls._REGISTRY.get(selector_type.lower())
        if selector_class is None:
            known = list(cls._REGISTRY)
            raise ValueError(
                f"Unknown selector type: {selector_type}. Available: {known}"
            )
        return selector_class(**kwargs)

    @classmethod
    def register(cls, name: str, selector_class: type) -> None:
        """Add a custom selector class to the registry.

        Args:
            name: Name to register under (stored lowercased).
            selector_class: Selector class (must inherit from
                SourceModelSelector).

        Raises:
            TypeError: If selector_class doesn't inherit from
                SourceModelSelector.
        """
        if issubclass(selector_class, SourceModelSelector):
            cls._REGISTRY[name.lower()] = selector_class
            return
        raise TypeError(
            "Selector class must inherit from SourceModelSelector, "
            f"got {selector_class.__name__}"
        )