# Source code for nirs4all.controllers.models.stacking.branch_validator

"""
Branch Validator for Meta-Model Stacking (Phase 4).

This module provides validation logic for stacking in branched pipelines,
including support for:
- Preprocessing branches (same samples, different features)
- Sample partitioner branches (different sample subsets)
- Outlier excluder branches (same samples, different exclusions)
- Generator syntax branches (same samples, model variants)

The validator ensures that stacking is only performed in compatible
scenarios and provides clear error messages for unsupported cases.
"""

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
import warnings

from .exceptions import (
    IncompatibleBranchTypeError,
    CrossPartitionStackingError,
    NestedBranchStackingError,
    FoldMismatchAcrossBranchesError,
    DisjointSampleSetsError,
    GeneratorSyntaxStackingWarning,
)

if TYPE_CHECKING:
    from nirs4all.data.predictions import Predictions
    from nirs4all.data.dataset import SpectroDataset
    from nirs4all.pipeline.config.context import ExecutionContext


class BranchType(Enum):
    """Kinds of branching that can surround a stacking step in a pipeline.

    Each member's value is its lowercase string tag.
    """

    # Pipeline has no branching at all.
    NONE = "none"
    # Plain branch: {"branch": {...}}
    PREPROCESSING = "preprocessing"
    # {"branch": {"by": "sample_partitioner"}}
    SAMPLE_PARTITIONER = "sample_partitioner"
    # {"branch": {"by": "metadata_partitioner", "column": ...}}
    METADATA_PARTITIONER = "metadata_partitioner"
    # {"branch": {"by": "outlier_excluder"}}
    OUTLIER_EXCLUDER = "outlier_excluder"
    # {"_or_": [...]} or other generator syntax.
    GENERATOR = "generator"
    # More than one level of branching.
    NESTED = "nested"
    # A branch context exists but its type was not recognized.
    UNKNOWN = "unknown"
class StackingCompatibility(Enum):
    """How well stacking is supported for a given branch type."""

    # Stacking works without caveats.
    COMPATIBLE = "compatible"
    # Stacking works, but the validator attaches warnings.
    COMPATIBLE_WITH_WARNINGS = "compatible_with_warnings"
    # Sources must come from the same partition as the meta-model.
    WITHIN_PARTITION_ONLY = "within_partition_only"
    # Stacking is not currently supported in this context.
    NOT_SUPPORTED = "not_supported"
@dataclass
class BranchInfo:
    """Snapshot of the branch context used when validating stacking.

    Only ``branch_type`` is required; every other field defaults to an
    "absent" value so a bare ``BranchInfo(BranchType.NONE)`` describes an
    unbranched pipeline.
    """

    # Detected kind of branching.
    branch_type: BranchType
    # Identifier and display name of the current branch, if any.
    branch_id: Optional[int] = None
    branch_name: Optional[str] = None
    # Branch identifiers from the outermost to the current level.
    branch_path: List[int] = field(default_factory=list)
    # Partition metadata for partitioner branches (schema defined by the producer).
    partition_info: Optional[Dict[str, Any]] = None
    # Exclusion metadata for outlier-excluder branches.
    exclusion_info: Optional[Dict[str, Any]] = None
    # Sample indices belonging to the current partition, when recorded.
    sample_indices: Optional[List[int]] = None
    n_samples: Optional[int] = None
    # Nesting flags; the validator derives them from len(branch_path).
    is_nested: bool = False
    nesting_depth: int = 0
@dataclass
class BranchValidationResult:
    """Outcome of checking a branch context for stacking compatibility.

    ``is_valid`` starts as whatever the caller passes and is forced to False
    the first time an error is recorded. Warnings never affect validity.
    """

    is_valid: bool
    compatibility: StackingCompatibility
    branch_info: BranchInfo
    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)
    # Optional filtering hint (e.g. branch/partition keys) for source selection.
    source_filter_hint: Optional[Dict[str, Any]] = None

    def add_error(self, message: str) -> None:
        """Record ``message`` as an error and mark the result invalid."""
        self.is_valid = False
        self.errors.append(message)

    def add_warning(self, message: str) -> None:
        """Record ``message`` as a non-fatal warning."""
        self.warnings.append(message)
class BranchValidator:
    """Validates branch contexts for meta-model stacking.

    This validator checks that the current branch context is compatible
    with stacking and provides clear error messages for unsupported cases.

    Supported scenarios:
        - No branching: Fully compatible
        - Preprocessing branches: Stack within branch
        - Outlier excluder branches: Stack within branch (all samples have predictions)
        - Sample partitioner branches: Stack within partition only

    Unsupported or limited scenarios:
        - Cross-partition stacking with sample_partitioner
        - Deeply nested branching (depth > 2)
        - Generator syntax with large variant counts

    Example:
        >>> validator = BranchValidator(prediction_store)
        >>> result = validator.validate(context, source_model_names)
        >>> if not result.is_valid:
        ...     raise ValueError(result.errors[0])
    """

    # Maximum supported nesting depth
    MAX_NESTING_DEPTH = 2

    # Maximum variants before warning for generator syntax
    MAX_GENERATOR_VARIANTS_WARNING = 10

    # A source prediction from another branch that shares less than this
    # fraction of the current partition's samples counts as cross-partition.
    MIN_PARTITION_OVERLAP = 0.1

    # validate_sample_alignment() flags a source model whose predictions
    # cover less than this fraction of the expected samples.
    MIN_SAMPLE_ALIGNMENT_OVERLAP = 0.5

    # Name prefixes that, when shared by every branch, indicate generator syntax.
    GENERATOR_NAME_PREFIXES = ('variant_', 'or_', 'gen_', 'config_')

    # Stringified fold_id values that are aggregates (avg/w_avg) or missing,
    # and so are not counted as folds.
    _NON_FOLD_IDS = frozenset({'avg', 'w_avg', 'None', ''})

    def __init__(
        self,
        prediction_store: 'Predictions',
        log_warnings: bool = True
    ):
        """Initialize branch validator.

        Args:
            prediction_store: Predictions storage queried via
                ``filter_predictions`` for source-model predictions.
            log_warnings: If True, ``validate`` re-emits every collected
                warning through :func:`warnings.warn`.
        """
        self.prediction_store = prediction_store
        self.log_warnings = log_warnings

    def validate(
        self,
        context: 'ExecutionContext',
        source_model_names: List[str],
        dataset: Optional['SpectroDataset'] = None
    ) -> BranchValidationResult:
        """Validate branch context for stacking compatibility.

        Args:
            context: Current execution context with branch info.
            source_model_names: List of source model names to validate.
            dataset: Accepted for API compatibility; not consulted by the
                current implementation.

        Returns:
            BranchValidationResult with validation status, errors and warnings.
        """
        branch_info = self._extract_branch_info(context)
        result = BranchValidationResult(
            is_valid=True,
            compatibility=StackingCompatibility.COMPATIBLE,
            branch_info=branch_info
        )

        # No branching - fully compatible
        if branch_info.branch_type == BranchType.NONE:
            return result

        # The depth limit applies to every branch type, not only NESTED.
        if branch_info.nesting_depth > self.MAX_NESTING_DEPTH:
            result.add_error(
                f"Deeply nested branching (depth={branch_info.nesting_depth}) "
                f"is not fully supported for stacking. "
                f"Maximum supported depth is {self.MAX_NESTING_DEPTH}."
            )
            result.compatibility = StackingCompatibility.NOT_SUPPORTED
            return result

        # Per-type checks. UNKNOWN has no dedicated validator.
        type_validators = {
            BranchType.SAMPLE_PARTITIONER: self._validate_sample_partitioner,
            BranchType.METADATA_PARTITIONER: self._validate_metadata_partitioner,
            BranchType.OUTLIER_EXCLUDER: self._validate_outlier_excluder,
            BranchType.PREPROCESSING: self._validate_preprocessing_branch,
            BranchType.GENERATOR: self._validate_generator_syntax,
            BranchType.NESTED: self._validate_nested_branching,
        }
        type_validator = type_validators.get(branch_info.branch_type)
        if type_validator is not None:
            type_validator(context, source_model_names, result)

        # Validate fold alignment across source models
        if result.is_valid and source_model_names:
            self._validate_fold_alignment_in_branch(
                context, source_model_names, result
            )

        if self.log_warnings:
            for message in result.warnings:
                # stacklevel=2 attributes the warning to validate()'s caller.
                warnings.warn(message, stacklevel=2)

        return result

    def _extract_branch_info(self, context: 'ExecutionContext') -> BranchInfo:
        """Extract branch information from execution context.

        Detection precedence: metadata partitioner, sample partitioner,
        outlier excluder, generic branch mode (generator / nested /
        preprocessing), then NONE when no ``branch_id`` is set, else UNKNOWN.
        Keys present in ``context.custom`` but set to ``None`` are treated
        as absent.

        Args:
            context: Execution context with ``selector`` and ``custom`` data.

        Returns:
            BranchInfo with detected branch type and metadata.
        """
        selector = context.selector
        custom = context.custom

        branch_id = getattr(selector, 'branch_id', None)
        branch_name = getattr(selector, 'branch_name', None)
        branch_path = getattr(selector, 'branch_path', None) or []
        depth = len(branch_path)

        def make_info(branch_type, **extra):
            # Shared constructor so every branch type reports path/depth identically.
            return BranchInfo(
                branch_type=branch_type,
                branch_id=branch_id,
                branch_name=branch_name,
                branch_path=list(branch_path),
                is_nested=depth > 1,
                nesting_depth=depth,
                **extra
            )

        partitioners = (
            ('metadata_partitioner_active', 'metadata_partition',
             BranchType.METADATA_PARTITIONER),
            ('sample_partitioner_active', 'sample_partition',
             BranchType.SAMPLE_PARTITIONER),
        )
        for flag, key, branch_type in partitioners:
            if not custom.get(flag):
                continue
            partition_info = custom.get(key) or {}
            sample_indices = partition_info.get('sample_indices')
            if sample_indices is None:
                sample_indices = []
            n_samples = partition_info.get('n_samples')
            if n_samples is None:
                n_samples = len(sample_indices)
            return make_info(
                branch_type,
                partition_info=partition_info,
                sample_indices=sample_indices,
                n_samples=n_samples
            )

        if custom.get('outlier_excluder_active'):
            exclusion_info = (
                custom.get('outlier_exclusion')
                or custom.get('exclusion_info')
                or {}
            )
            return make_info(BranchType.OUTLIER_EXCLUDER, exclusion_info=exclusion_info)

        if custom.get('in_branch_mode'):
            # Generic branching (preprocessing or generator)
            branch_contexts = custom.get('branch_contexts') or []
            if self._looks_like_generator_syntax(branch_contexts):
                return make_info(BranchType.GENERATOR)
            if depth > 1:
                return make_info(BranchType.NESTED)
            return make_info(BranchType.PREPROCESSING)

        if branch_id is None:
            return BranchInfo(branch_type=BranchType.NONE)

        # Branch context present but type unknown
        return make_info(BranchType.UNKNOWN)

    def _looks_like_generator_syntax(
        self,
        branch_contexts: List[Dict[str, Any]]
    ) -> bool:
        """Check if branch contexts look like generator syntax.

        Returns True when there are at least two branches and every branch
        name starts with the same prefix from ``GENERATOR_NAME_PREFIXES``
        (e.g. 'variant_0', 'variant_1', ...).

        Args:
            branch_contexts: List of branch context dictionaries.

        Returns:
            True if this looks like generator syntax.
        """
        if len(branch_contexts) <= 1:
            return False
        # str(...) guards against None or non-string names.
        names = [str(bc.get('name') or '') for bc in branch_contexts]
        return any(
            all(name.startswith(prefix) for name in names)
            for prefix in self.GENERATOR_NAME_PREFIXES
        )

    @staticmethod
    def _to_index_set(indices: Any) -> Set[Any]:
        """Convert a sample-index container (list, array-like or None) to a set."""
        if indices is None:
            return set()
        if hasattr(indices, 'tolist'):
            indices = indices.tolist()
        return set(indices)

    def _find_cross_partition_source(
        self,
        source_model_names: List[str],
        current_branch_id: Optional[int],
        current_samples: Set[Any]
    ) -> Optional[Tuple[str, float]]:
        """Find the first source prediction that comes from a disjoint partition.

        A prediction qualifies when it carries a non-None ``branch_id``
        different from ``current_branch_id``, records sample indices, and
        shares less than ``MIN_PARTITION_OVERLAP`` of ``current_samples``.

        Returns:
            ``(model_name, overlap_ratio)`` of the first offender, or None.
            Always None when ``current_samples`` is empty (overlap is undefined).
        """
        if not current_samples:
            return None
        for model_name in source_model_names:
            model_preds = self.prediction_store.filter_predictions(
                model_name=model_name,
                load_arrays=False
            )
            for pred in model_preds:
                pred_branch_id = pred.get('branch_id')
                if pred_branch_id is None or pred_branch_id == current_branch_id:
                    continue
                pred_samples = self._to_index_set(pred.get('sample_indices'))
                if not pred_samples:
                    continue
                overlap_ratio = len(current_samples & pred_samples) / len(current_samples)
                if overlap_ratio < self.MIN_PARTITION_OVERLAP:
                    return model_name, overlap_ratio
        return None

    def _validate_sample_partitioner(
        self,
        context: 'ExecutionContext',
        source_model_names: List[str],
        result: BranchValidationResult
    ) -> None:
        """Validate sample partitioner branching for stacking.

        Sample partitioner creates disjoint sample sets, so stacking is only
        valid within the same partition.

        Args:
            context: Execution context.
            source_model_names: Source model names.
            result: Validation result to update.
        """
        result.compatibility = StackingCompatibility.WITHIN_PARTITION_ONLY

        current_branch_id = result.branch_info.branch_id
        current_partition = result.branch_info.partition_info or {}
        partition_type = current_partition.get('partition_type', 'unknown')

        current_samples = self._to_index_set(result.branch_info.sample_indices)
        if not current_samples:
            result.add_warning(
                f"Sample partitioner branch '{partition_type}' has no sample "
                f"indices recorded. Stacking will proceed but may have issues."
            )

        offender = self._find_cross_partition_source(
            source_model_names, current_branch_id, current_samples
        )
        if offender is not None:
            model_name, overlap_ratio = offender
            result.add_error(
                f"Source model '{model_name}' is from a different "
                f"partition with disjoint samples (only {100*overlap_ratio:.1f}% overlap). "
                f"Cross-partition stacking is not supported with sample_partitioner. "
                f"Stack only with models from the current '{partition_type}' partition."
            )
            return

        # Add hint for source filtering
        result.source_filter_hint = {
            'branch_id': current_branch_id,
            'partition_type': partition_type
        }
        result.add_warning(
            f"Stacking within sample_partitioner partition '{partition_type}'. "
            f"Only models from the same partition will be used as sources."
        )

    def _validate_metadata_partitioner(
        self,
        context: 'ExecutionContext',
        source_model_names: List[str],
        result: BranchValidationResult
    ) -> None:
        """Validate metadata partitioner branching for stacking.

        Metadata partitioner creates disjoint sample sets based on metadata
        column values. Stacking is only valid within the same partition.

        Args:
            context: Execution context.
            source_model_names: Source model names.
            result: Validation result to update.
        """
        result.compatibility = StackingCompatibility.WITHIN_PARTITION_ONLY

        current_branch_id = result.branch_info.branch_id
        current_partition = result.branch_info.partition_info or {}
        partition_value = current_partition.get('partition_value', 'unknown')
        column_name = current_partition.get('column', 'unknown')

        current_samples = self._to_index_set(result.branch_info.sample_indices)
        if not current_samples:
            result.add_warning(
                f"Metadata partitioner branch '{partition_value}' has no sample "
                f"indices recorded. Stacking will proceed but may have issues."
            )

        offender = self._find_cross_partition_source(
            source_model_names, current_branch_id, current_samples
        )
        if offender is not None:
            model_name, overlap_ratio = offender
            result.add_error(
                f"Source model '{model_name}' is from a different "
                f"metadata partition with disjoint samples "
                f"(only {100*overlap_ratio:.1f}% overlap). "
                f"Cross-partition stacking is not supported with metadata_partitioner. "
                f"Stack only with models from the current '{column_name}={partition_value}' partition."
            )
            return

        # Add hint for source filtering
        result.source_filter_hint = {
            'branch_id': current_branch_id,
            'partition_value': partition_value,
            'column': column_name
        }
        result.add_warning(
            f"Stacking within metadata_partitioner partition '{column_name}={partition_value}'. "
            f"Only models from the same partition will be used as sources."
        )

    def _validate_outlier_excluder(
        self,
        context: 'ExecutionContext',
        source_model_names: List[str],
        result: BranchValidationResult
    ) -> None:
        """Validate outlier excluder branching for stacking.

        Outlier excluder creates branches where all samples have predictions
        (some just weren't used in training). This is fully compatible with
        stacking; a warning is added when samples were excluded.

        Args:
            context: Execution context.
            source_model_names: Source model names.
            result: Validation result to update.
        """
        result.compatibility = StackingCompatibility.COMPATIBLE

        exclusion_info = result.branch_info.exclusion_info or {}
        # `or 0` guards against an explicit None count.
        n_excluded = exclusion_info.get('n_excluded') or 0
        strategy = exclusion_info.get('strategy') or {}
        # An empty/missing strategy denotes the baseline (no-exclusion) branch.
        method = strategy.get('method', 'unknown') if strategy else 'baseline'

        if n_excluded > 0:
            result.add_warning(
                f"Stacking with outlier_excluder branch (method='{method}', "
                f"excluded={n_excluded} samples from training). "
                f"All samples have predictions, but some were not used in training."
            )

    def _validate_preprocessing_branch(
        self,
        context: 'ExecutionContext',
        source_model_names: List[str],
        result: BranchValidationResult
    ) -> None:
        """Validate preprocessing branching for stacking.

        Preprocessing branches have the same samples but different features.
        This is fully compatible with stacking within the branch.

        Args:
            context: Execution context.
            source_model_names: Source model names.
            result: Validation result to update.
        """
        result.compatibility = StackingCompatibility.COMPATIBLE
        branch_name = result.branch_info.branch_name or f"branch_{result.branch_info.branch_id}"
        result.add_warning(
            f"Stacking within preprocessing branch '{branch_name}'. "
            f"Only models from this branch will be used as sources."
        )

    def _validate_generator_syntax(
        self,
        context: 'ExecutionContext',
        source_model_names: List[str],
        result: BranchValidationResult
    ) -> None:
        """Validate generator syntax for stacking.

        Generator syntax creates multiple model variants. This is compatible
        but may result in high-dimensional meta-features.

        Args:
            context: Execution context.
            source_model_names: Source model names.
            result: Validation result to update.
        """
        result.compatibility = StackingCompatibility.COMPATIBLE_WITH_WARNINGS

        n_sources = len(source_model_names)
        if n_sources > self.MAX_GENERATOR_VARIANTS_WARNING:
            result.add_warning(
                f"Generator syntax created {n_sources} model variants for stacking. "
                f"This may result in high-dimensional meta-features. "
                f"Consider using TopKByMetricSelector to limit sources."
            )

    def _validate_nested_branching(
        self,
        context: 'ExecutionContext',
        source_model_names: List[str],
        result: BranchValidationResult
    ) -> None:
        """Validate nested branching for stacking.

        Nested branching may have limited support depending on the
        combination of branch types.

        Args:
            context: Execution context.
            source_model_names: Source model names.
            result: Validation result to update.
        """
        depth = result.branch_info.nesting_depth
        path = result.branch_info.branch_path

        # Defensive: validate() already rejects excessive depth before dispatch.
        if depth > self.MAX_NESTING_DEPTH:
            result.add_error(
                f"Nested branching depth {depth} exceeds maximum supported ({self.MAX_NESTING_DEPTH}). "
                f"Consider simplifying the pipeline structure."
            )
            result.compatibility = StackingCompatibility.NOT_SUPPORTED
            return

        result.compatibility = StackingCompatibility.COMPATIBLE_WITH_WARNINGS
        result.add_warning(
            f"Stacking with nested branching (depth={depth}, path={path}). "
            f"Only models from the current branch path will be used as sources."
        )

    def _source_val_predictions(
        self,
        model_name: str,
        branch_id: Optional[int],
        current_step: Any
    ) -> List[Dict[str, Any]]:
        """Return validation predictions of ``model_name`` from steps before ``current_step``.

        Filters by ``branch_id`` when it is not None. A missing or None
        ``step_idx`` is treated as step 0.
        """
        filter_kwargs = {
            'model_name': model_name,
            'partition': 'val',
            'load_arrays': False,
        }
        if branch_id is not None:
            filter_kwargs['branch_id'] = branch_id
        preds = self.prediction_store.filter_predictions(**filter_kwargs)
        earlier = []
        for pred in preds:
            step_idx = pred.get('step_idx')
            if (0 if step_idx is None else step_idx) < current_step:
                earlier.append(pred)
        return earlier

    def _validate_fold_alignment_in_branch(
        self,
        context: 'ExecutionContext',
        source_model_names: List[str],
        result: BranchValidationResult
    ) -> None:
        """Validate that fold structures are consistent within branch.

        Counts distinct fold IDs (ignoring aggregates such as avg/w_avg) in
        each source model's validation predictions and warns when the counts
        differ between models.

        Args:
            context: Execution context.
            source_model_names: Source model names.
            result: Validation result to update.
        """
        current_branch_id = result.branch_info.branch_id
        current_step = context.state.step_number

        fold_counts: Dict[str, int] = {}
        for model_name in source_model_names:
            preds = self._source_val_predictions(model_name, current_branch_id, current_step)
            fold_ids = {str(p.get('fold_id')) for p in preds} - self._NON_FOLD_IDS
            if fold_ids:
                fold_counts[model_name] = len(fold_ids)

        if len(set(fold_counts.values())) > 1:
            result.add_warning(
                f"Source models have different fold counts: "
                f"{fold_counts}. This may affect OOF reconstruction."
            )

    def validate_sample_alignment(
        self,
        source_model_names: List[str],
        expected_sample_indices: List[int],
        context: 'ExecutionContext'
    ) -> BranchValidationResult:
        """Validate that source models have predictions for expected samples.

        This is particularly important for sample_partitioner branches where
        different partitions have different samples. Models with no recorded
        sample indices are not flagged.

        Args:
            source_model_names: List of source model names.
            expected_sample_indices: Expected sample indices (from current partition).
            context: Execution context.

        Returns:
            Validation result with any sample alignment issues.
        """
        branch_info = self._extract_branch_info(context)
        result = BranchValidationResult(
            is_valid=True,
            compatibility=StackingCompatibility.COMPATIBLE,
            branch_info=branch_info
        )

        expected_set = self._to_index_set(expected_sample_indices)
        current_branch_id = getattr(context.selector, 'branch_id', None)
        current_step = context.state.step_number

        for model_name in source_model_names:
            preds = self._source_val_predictions(model_name, current_branch_id, current_step)

            found_samples: Set[Any] = set()
            for pred in preds:
                found_samples |= self._to_index_set(pred.get('sample_indices'))

            if not found_samples:
                continue

            overlap = expected_set & found_samples
            overlap_ratio = len(overlap) / max(len(expected_set), 1)
            if overlap_ratio < self.MIN_SAMPLE_ALIGNMENT_OVERLAP:
                result.add_error(
                    f"Source model '{model_name}' has low sample overlap "
                    f"({100*overlap_ratio:.1f}%). Expected {len(expected_set)} samples, "
                    f"found predictions for {len(found_samples)} samples with "
                    f"{len(overlap)} overlapping."
                )

        return result
def detect_branch_type(context: 'ExecutionContext') -> BranchType:
    """Classify the branching in ``context`` without building a full BranchInfo.

    Checks the partitioner/excluder flags in ``context.custom`` first (in
    that priority order), then generic branch mode. This quick check does not
    distinguish generator syntax: such branches are reported as
    PREPROCESSING, or NESTED when the branch path is deeper than one level.

    Args:
        context: Execution context with branch info.

    Returns:
        Detected BranchType enum value.
    """
    custom = context.custom

    flag_checks = (
        ('metadata_partitioner_active', BranchType.METADATA_PARTITIONER),
        ('sample_partitioner_active', BranchType.SAMPLE_PARTITIONER),
        ('outlier_excluder_active', BranchType.OUTLIER_EXCLUDER),
    )
    for flag, flagged_type in flag_checks:
        if custom.get(flag):
            return flagged_type

    selector = context.selector
    if custom.get('in_branch_mode'):
        path = getattr(selector, 'branch_path', None) or []
        return BranchType.NESTED if len(path) > 1 else BranchType.PREPROCESSING

    if getattr(selector, 'branch_id', None) is None:
        return BranchType.NONE
    return BranchType.UNKNOWN
def is_stacking_compatible(context: 'ExecutionContext') -> bool:
    """Quick check if stacking is compatible with current context.

    Mirrors the structural gate in ``BranchValidator.validate``. Stacking is
    always allowed without branching. Otherwise the branch path must be no
    deeper than ``BranchValidator.MAX_NESTING_DEPTH``, whatever the branch
    type. Partitioner branches report True here but only support stacking
    within the current partition; the full validator enforces that.

    Args:
        context: Execution context.

    Returns:
        True if stacking is likely compatible.
    """
    if detect_branch_type(context) == BranchType.NONE:
        return True

    # validate() rejects nesting_depth > MAX_NESTING_DEPTH for every branch
    # type, so the quick check must apply the same limit across the board.
    branch_path = getattr(context.selector, 'branch_path', None) or []
    return len(branch_path) <= BranchValidator.MAX_NESTING_DEPTH
def is_disjoint_branch(context: 'ExecutionContext') -> bool:
    """Tell whether the current branch holds a disjoint subset of samples.

    In a disjoint branch every sample belongs to exactly ONE branch. In a
    copy branch, by contrast, every branch sees all samples.

    Disjoint branch types:
        - METADATA_PARTITIONER: Branches by metadata column value
        - SAMPLE_PARTITIONER: Branches by sample filter (e.g. outliers vs inliers)

    Copy branch types:
        - PREPROCESSING: All branches see all samples
        - GENERATOR: All branches see all samples (model variants)

    Besides the detected branch type, a non-None ``metadata_partition`` or
    ``sample_partition`` entry in ``context.custom`` also marks the branch
    as disjoint.

    Args:
        context: Execution context with branch info.

    Returns:
        True if current context represents a disjoint sample branch.
    """
    disjoint_types = (BranchType.METADATA_PARTITIONER, BranchType.SAMPLE_PARTITIONER)
    if detect_branch_type(context) in disjoint_types:
        return True

    custom = context.custom
    partition_keys = ('metadata_partition', 'sample_partition')
    return any(custom.get(key) is not None for key in partition_keys)
def get_disjoint_branch_info(context: 'ExecutionContext') -> Optional[Dict[str, Any]]:
    """Summarize the current disjoint partition, if there is one.

    ``metadata_partition`` in ``context.custom`` takes priority over
    ``sample_partition``. If the context counts as disjoint but neither
    entry is present, None is returned.

    Args:
        context: Execution context with branch info.

    Returns:
        None when the branch is not disjoint. Otherwise a dict whose keys
        may include:
            - partition_type: "metadata" or "sample"
            - column: Metadata column name (for metadata partitioner)
            - partition_value: Value(s) for this partition
            - sample_indices: List of sample indices in this partition
            - n_samples: Number of samples in this partition
    """
    if not is_disjoint_branch(context):
        return None

    custom = context.custom

    meta = custom.get('metadata_partition')
    if meta is not None:
        return {
            'partition_type': 'metadata',
            'column': meta.get('column'),
            'partition_value': meta.get('partition_value'),
            'partition_values': meta.get('partition_values'),
            'sample_indices': meta.get('sample_indices', []),
            'train_sample_indices': meta.get('train_sample_indices', []),
            'n_samples': meta.get('n_samples', 0),
            'n_train_samples': meta.get('n_train_samples', 0),
        }

    sample = custom.get('sample_partition')
    if sample is None:
        return None

    return {
        'partition_type': 'sample',
        # The sample partition's own type tag, e.g. "outliers" or "inliers".
        'partition_value': sample.get('partition_type'),
        'sample_indices': sample.get('sample_indices', []),
        'n_samples': sample.get('n_samples', 0),
        'filter_config': sample.get('filter_config'),
    }