Source code for nirs4all.controllers.models.stacking.multilevel

"""
Multi-Level Stacking Validator (Phase 7).

This module provides validation and level detection for multi-level stacking,
where meta-models can use predictions from other meta-models as sources.

The validator ensures:
1. No circular dependencies exist in the stacking hierarchy
2. Stacking levels don't exceed configured maximum
3. Level detection works correctly for AUTO mode
4. Source models from appropriate levels are selected

Stacking Hierarchy:
    Level 0: Base models (PLS, RF, XGBoost, Neural Networks, etc.)
    Level 1: First meta-models (stack on Level 0 only)
    Level 2: Second meta-models (stack on Level 0 + Level 1)
    Level 3: Third meta-models (stack on Level 0 + Level 1 + Level 2)

Example:
    >>> validator = MultiLevelValidator(prediction_store)
    >>> result = validator.validate_sources(
    ...     meta_model_name="FinalMeta",
    ...     source_candidates=candidates,
    ...     context=context
    ... )
    >>> if not result.is_valid:
    ...     raise CircularDependencyError(...)
"""

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
import warnings

from .exceptions import (
    CircularDependencyError,
    MaxStackingLevelExceededError,
    InconsistentLevelError,
)

if TYPE_CHECKING:
    from nirs4all.data.predictions import Predictions
    from nirs4all.pipeline.config.context import ExecutionContext
    from nirs4all.operators.models.selection import ModelCandidate


[docs] @dataclass class ModelLevelInfo: """Information about a model's stacking level. Attributes: model_name: Name of the model. level: Stacking level (0 for base models, 1+ for meta-models). is_meta_model: Whether this is a meta-model. source_models: List of source model names (for meta-models). step_idx: Pipeline step index. """ model_name: str level: int is_meta_model: bool source_models: List[str] = field(default_factory=list) step_idx: int = 0
[docs] @dataclass class LevelValidationResult: """Result of multi-level stacking validation. Attributes: is_valid: Whether the validation passed. detected_level: The detected stacking level for the meta-model. source_levels: Dict mapping source model names to their levels. circular_dependencies: List of detected circular dependencies. warnings: List of warning messages. errors: List of error messages. """ is_valid: bool = True detected_level: int = 1 source_levels: Dict[str, int] = field(default_factory=dict) circular_dependencies: List[List[str]] = field(default_factory=list) warnings: List[str] = field(default_factory=list) errors: List[str] = field(default_factory=list)
[docs] def add_warning(self, message: str) -> None: """Add a warning message.""" self.warnings.append(message)
[docs] def add_error(self, message: str) -> None: """Add an error and mark as invalid.""" self.errors.append(message) self.is_valid = False
[docs] class MultiLevelValidator: """Validates multi-level stacking configurations. Ensures that stacking hierarchies are valid, detects circular dependencies, and computes appropriate stacking levels for meta-models. Attributes: prediction_store: Predictions storage for analyzing model metadata. max_level: Maximum allowed stacking level. log_warnings: Whether to emit Python warnings. """ # Meta-model class name patterns to detect META_MODEL_PATTERNS = {'MetaModel', 'StackingRegressor', 'StackingClassifier'} def __init__( self, prediction_store: 'Predictions', max_level: int = 3, log_warnings: bool = True ): """Initialize multi-level validator. Args: prediction_store: Predictions storage. max_level: Maximum allowed stacking level (default 3). log_warnings: Whether to emit Python warnings. """ self.prediction_store = prediction_store self.max_level = max_level self.log_warnings = log_warnings # Cache for model level info self._level_cache: Dict[str, ModelLevelInfo] = {}
[docs] def validate_sources( self, meta_model_name: str, source_candidates: List['ModelCandidate'], context: 'ExecutionContext', allow_meta_sources: bool = True ) -> LevelValidationResult: """Validate source models for a meta-model. Checks for circular dependencies and computes the appropriate stacking level based on source model levels. Args: meta_model_name: Name of the meta-model being validated. source_candidates: List of candidate source models. context: Execution context. allow_meta_sources: Whether to allow other meta-models as sources. Returns: LevelValidationResult with validation status and detected level. """ result = LevelValidationResult() # Get unique source model names source_names = list(dict.fromkeys(c.model_name for c in source_candidates)) if not source_names: result.add_warning("No source models provided for multi-level validation") return result # Build level info for all sources for name in source_names: level_info = self._get_model_level_info(name, context) result.source_levels[name] = level_info.level # Check if meta-model sources are allowed if level_info.is_meta_model and not allow_meta_sources: result.add_error( f"Source model '{name}' is a meta-model but allow_meta_sources=False" ) continue # Check for circular dependencies for name in source_names: cycle = self._detect_circular_dependency( meta_model_name, name, context, visited=set() ) if cycle: result.circular_dependencies.append(cycle) result.add_error( f"Circular dependency detected: {' -> '.join(cycle)}" ) if not result.is_valid: return result # Compute detected level max_source_level = max(result.source_levels.values()) if result.source_levels else 0 result.detected_level = max_source_level + 1 # Check against max level if result.detected_level > self.max_level: result.add_error( f"Detected level {result.detected_level} exceeds maximum {self.max_level}" ) return result # Add info about detected level if max_source_level > 0: meta_sources = [n for n, l in result.source_levels.items() if l > 0] result.add_warning( f"Multi-level stacking detected (level {result.detected_level}). " f"Using meta-model sources: {meta_sources}" ) # Emit warnings if configured if self.log_warnings: for warning in result.warnings: warnings.warn(warning) return result
[docs] def detect_level( self, source_candidates: List['ModelCandidate'], context: 'ExecutionContext' ) -> int: """Detect the appropriate stacking level based on source models. Args: source_candidates: List of candidate source models. context: Execution context. Returns: Detected stacking level (1 if no meta-model sources, 2+ otherwise). """ if not source_candidates: return 1 max_level = 0 for candidate in source_candidates: level_info = self._get_model_level_info(candidate.model_name, context) max_level = max(max_level, level_info.level) return max_level + 1
[docs] def filter_by_level( self, candidates: List['ModelCandidate'], context: 'ExecutionContext', max_source_level: Optional[int] = None, exclude_meta_models: bool = False ) -> List['ModelCandidate']: """Filter source candidates by stacking level. Args: candidates: List of candidate source models. context: Execution context. max_source_level: Maximum allowed source level (None = no limit). exclude_meta_models: If True, exclude all meta-models from sources. Returns: Filtered list of candidates. """ filtered = [] for candidate in candidates: level_info = self._get_model_level_info(candidate.model_name, context) # Check meta-model exclusion if exclude_meta_models and level_info.is_meta_model: continue # Check level limit if max_source_level is not None and level_info.level > max_source_level: continue filtered.append(candidate) return filtered
def _get_model_level_info( self, model_name: str, context: 'ExecutionContext' ) -> ModelLevelInfo: """Get or compute level info for a model. Args: model_name: Name of the model. context: Execution context. Returns: ModelLevelInfo with level and metadata. """ # Check cache if model_name in self._level_cache: return self._level_cache[model_name] # Get model predictions to determine type preds = self.prediction_store.filter_predictions( model_name=model_name, load_arrays=False ) if not preds: # Unknown model - assume base level info = ModelLevelInfo( model_name=model_name, level=0, is_meta_model=False, source_models=[], step_idx=0 ) self._level_cache[model_name] = info return info # Check if this is a meta-model based on class name first_pred = preds[0] classname = first_pred.get('model_classname', '') step_idx = first_pred.get('step_idx', 0) is_meta = any( pattern in classname for pattern in self.META_MODEL_PATTERNS ) # Also check model name for MetaModel prefix if 'MetaModel' in model_name or 'Meta_' in model_name: is_meta = True if not is_meta: # Base model - level 0 info = ModelLevelInfo( model_name=model_name, level=0, is_meta_model=False, source_models=[], step_idx=step_idx ) self._level_cache[model_name] = info return info # Meta-model - need to find its source models source_models = self._find_source_models(model_name, step_idx, context) # Compute level based on source model levels if not source_models: level = 1 # Meta-model with unknown sources else: max_source_level = 0 for source_name in source_models: source_info = self._get_model_level_info(source_name, context) max_source_level = max(max_source_level, source_info.level) level = max_source_level + 1 info = ModelLevelInfo( model_name=model_name, level=level, is_meta_model=True, source_models=source_models, step_idx=step_idx ) self._level_cache[model_name] = info return info def _find_source_models( self, meta_model_name: str, meta_step_idx: int, context: 'ExecutionContext' ) -> List[str]: """Find source models for a meta-model. Attempts to identify which models were used as sources for a meta-model by looking at predictions from earlier steps in the same branch. Args: meta_model_name: Name of the meta-model. meta_step_idx: Step index of the meta-model. context: Execution context. Returns: List of source model names. """ branch_id = getattr(context.selector, 'branch_id', None) # Get all models from earlier steps all_preds = self.prediction_store.filter_predictions( load_arrays=False, branch_id=branch_id ) # Filter to earlier steps earlier_preds = [ p for p in all_preds if p.get('step_idx', 0) < meta_step_idx ] # Get unique model names source_models = list(dict.fromkeys( p.get('model_name', '') for p in earlier_preds if p.get('model_name') )) return source_models def _detect_circular_dependency( self, meta_model_name: str, source_name: str, context: 'ExecutionContext', visited: Set[str], path: Optional[List[str]] = None ) -> Optional[List[str]]: """Detect circular dependencies in the stacking hierarchy. Uses DFS to detect cycles in the dependency graph. Args: meta_model_name: Name of the meta-model being built. source_name: Name of the source model being checked. context: Execution context. visited: Set of already visited model names. path: Current path in the dependency graph. Returns: List of model names forming a cycle, or None if no cycle. """ if path is None: path = [meta_model_name] # Direct self-reference if source_name == meta_model_name: return path + [source_name] # Already visited in this path if source_name in visited: return None # Not a cycle for us # Check if source is in current path (cycle) if source_name in path: cycle_start = path.index(source_name) return path[cycle_start:] + [source_name] # Get source model info source_info = self._get_model_level_info(source_name, context) # If not a meta-model, no further dependencies if not source_info.is_meta_model: return None # Check source's sources recursively new_path = path + [source_name] new_visited = visited | {source_name} for sub_source in source_info.source_models: cycle = self._detect_circular_dependency( meta_model_name, sub_source, context, new_visited, new_path ) if cycle: return cycle return None
[docs] def clear_cache(self) -> None: """Clear the level info cache.""" self._level_cache.clear()
[docs] def get_all_levels( self, context: 'ExecutionContext' ) -> Dict[str, int]: """Get levels for all models in the prediction store. Args: context: Execution context. Returns: Dict mapping model names to their stacking levels. """ all_preds = self.prediction_store.filter_predictions(load_arrays=False) model_names = list(dict.fromkeys( p.get('model_name', '') for p in all_preds if p.get('model_name') )) levels = {} for name in model_names: info = self._get_model_level_info(name, context) levels[name] = info.level return levels
[docs] def validate_multi_level_stacking( prediction_store: 'Predictions', meta_model_name: str, source_candidates: List['ModelCandidate'], context: 'ExecutionContext', max_level: int = 3, allow_meta_sources: bool = True ) -> LevelValidationResult: """Convenience function for validating multi-level stacking. Args: prediction_store: Predictions storage. meta_model_name: Name of the meta-model. source_candidates: List of candidate source models. context: Execution context. max_level: Maximum allowed stacking level. allow_meta_sources: Whether to allow meta-model sources. Returns: LevelValidationResult with validation status. """ validator = MultiLevelValidator( prediction_store=prediction_store, max_level=max_level, log_warnings=True ) return validator.validate_sources( meta_model_name=meta_model_name, source_candidates=source_candidates, context=context, allow_meta_sources=allow_meta_sources )
[docs] def detect_stacking_level( prediction_store: 'Predictions', source_candidates: List['ModelCandidate'], context: 'ExecutionContext' ) -> int: """Convenience function for detecting stacking level. Args: prediction_store: Predictions storage. source_candidates: List of candidate source models. context: Execution context. Returns: Detected stacking level. """ validator = MultiLevelValidator( prediction_store=prediction_store, log_warnings=False ) return validator.detect_level(source_candidates, context)