"""
Cross-Branch Stacking Support (Phase 7).
This module provides support for stacking across multiple branches,
allowing meta-models to use predictions from models in different
preprocessing branches.
Key Features:
1. BranchScope.ALL_BRANCHES support - stack across all branches
2. Feature alignment validation - ensure samples match across branches
3. Branch compatibility detection - identify compatible vs incompatible branches
4. Cross-branch prediction aggregation - combine predictions from multiple branches
Compatibility Matrix:
| Branch Type | Cross-Branch Stacking | Notes |
|---------------------|----------------------|----------------------------------|
| Preprocessing | ✅ Supported | Same samples, different features |
| Generator (_or_) | ✅ Supported | Same samples, model variants |
| Outlier Excluder | ✅ Supported | Same samples, different training |
| Sample Partitioner | ❌ Not Supported | Different samples per partition |
Example:
>>> validator = CrossBranchValidator(prediction_store)
>>> result = validator.validate_cross_branch_stacking(
... source_candidates=candidates,
... context=context
... )
>>> if result.is_compatible:
... aligned_features = validator.align_branch_features(...)
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
import warnings
import numpy as np
from .branch_validator import BranchType, detect_branch_type
from .exceptions import (
IncompatibleBranchSamplesError,
BranchFeatureAlignmentError,
CrossPartitionStackingError,
)
if TYPE_CHECKING:
from nirs4all.data.predictions import Predictions
from nirs4all.data.dataset import SpectroDataset
from nirs4all.pipeline.config.context import ExecutionContext
from nirs4all.operators.models.selection import ModelCandidate
[docs]
class CrossBranchCompatibility(Enum):
    """Compatibility level for cross-branch stacking.

    The level is assigned by
    ``CrossBranchValidator.validate_cross_branch_stacking`` and stored on
    ``CrossBranchValidationResult.compatibility``.
    """

    # Sample overlap across branches is at least 90%, or no sample indices
    # were available to compare and stacking proceeds regardless.
    COMPATIBLE = "compatible"
    # Sample overlap is between 50% and 90%; some samples may be dropped.
    COMPATIBLE_WITH_ALIGNMENT = "compatible_with_alignment"
    # Sample overlap is below 50%; reported as an error.
    INCOMPATIBLE_SAMPLES = "incompatible_samples"
    # At least one branch is a sample-partitioner branch; reported as an error.
    INCOMPATIBLE_PARTITIONS = "incompatible_partitions"
    # Default level; kept when at most one branch is present.
    NOT_APPLICABLE = "not_applicable"
[docs]
@dataclass
class BranchPredictionInfo:
    """Information about predictions from a specific branch.

    Instances are built by ``CrossBranchValidator._collect_branch_info`` and
    stored in ``CrossBranchValidationResult.branches``.

    Attributes:
        branch_id: Unique branch identifier.
        branch_name: Human-readable branch name, or None when no candidate
            in the branch carries one.
        model_names: List of model names in this branch (de-duplicated,
            first-seen order).
        sample_indices: Set of sample indices with predictions.
        n_samples: Number of samples (size of ``sample_indices``).
        n_folds: Number of distinct fold ids seen, excluding the aggregate
            ``'avg'`` / ``'w_avg'`` entries.
        branch_type: Type of branching; defaults to ``BranchType.UNKNOWN``.
    """

    branch_id: int
    branch_name: Optional[str]
    model_names: List[str]
    sample_indices: Set[int]
    n_samples: int
    n_folds: int
    branch_type: BranchType = BranchType.UNKNOWN
[docs]
@dataclass
class CrossBranchValidationResult:
    """Outcome of a cross-branch stacking validation run.

    Attributes:
        is_compatible: True while no error has been recorded; flipped to
            False by :meth:`add_error`.
        compatibility: Detailed compatibility level.
        branches: BranchPredictionInfo objects keyed by branch_id.
        common_samples: Samples present in all branches.
        alignment_issues: Alignment problems found during validation.
        warnings: Warning messages collected so far.
        errors: Error messages collected so far.
    """

    is_compatible: bool = True
    compatibility: CrossBranchCompatibility = CrossBranchCompatibility.NOT_APPLICABLE
    branches: Dict[int, BranchPredictionInfo] = field(default_factory=dict)
    common_samples: Set[int] = field(default_factory=set)
    alignment_issues: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)
    errors: List[str] = field(default_factory=list)

    def add_warning(self, message: str) -> None:
        """Record a warning; compatibility is left untouched."""
        self.warnings.append(message)

    def add_error(self, message: str) -> None:
        """Record an error and flag the result as incompatible."""
        self.errors.append(message)
        self.is_compatible = False

    @property
    def total_models(self) -> int:
        """Number of model names summed over every recorded branch."""
        count = 0
        for info in self.branches.values():
            count += len(info.model_names)
        return count
[docs]
class CrossBranchValidator:
    """Validates and supports cross-branch stacking.

    This validator checks that stacking across multiple branches is feasible
    and provides utilities for aligning predictions from different branches.

    Attributes:
        prediction_store: Predictions storage, queried through
            ``filter_predictions`` when collecting per-branch metadata.
        log_warnings: Whether to emit collected warnings via
            ``warnings.warn`` at the end of a full validation pass.
    """

    # Fold labels that denote aggregated predictions rather than a real fold.
    _AGGREGATE_FOLD_IDS = frozenset({'avg', 'w_avg'})

    def __init__(
        self,
        prediction_store: 'Predictions',
        log_warnings: bool = True
    ):
        """Initialize cross-branch validator.

        Args:
            prediction_store: Predictions storage.
            log_warnings: Whether to emit Python warnings.
        """
        self.prediction_store = prediction_store
        self.log_warnings = log_warnings

    def validate_cross_branch_stacking(
        self,
        source_candidates: List['ModelCandidate'],
        context: 'ExecutionContext',
        dataset: Optional['SpectroDataset'] = None
    ) -> 'CrossBranchValidationResult':
        """Validate cross-branch stacking feasibility.

        Checks that all branches have compatible samples and that predictions
        can be aligned for stacking. The checks run in this order:

        1. At most one branch: ``NOT_APPLICABLE``.
        2. Any sample-partitioner branch: ``INCOMPATIBLE_PARTITIONS`` (error).
        3. No sample indices available: ``COMPATIBLE`` with a warning.
        4. Sample overlap ratio (common / union): below 50% is
           ``INCOMPATIBLE_SAMPLES`` (error), below 90% is
           ``COMPATIBLE_WITH_ALIGNMENT`` (warning), otherwise ``COMPATIBLE``.
        5. Differing fold counts are recorded as an alignment issue.

        Warnings are forwarded to ``warnings.warn`` only when the full pass
        completes; on the early-return paths they are stored on the result
        but not emitted.

        Args:
            source_candidates: List of candidate source models.
            context: Execution context.
            dataset: Optional dataset for sample validation. Currently not
                used by this method.

        Returns:
            CrossBranchValidationResult with compatibility info.
        """
        result = CrossBranchValidationResult()

        # Group candidates by branch
        branch_groups = self._group_by_branch(source_candidates)
        if len(branch_groups) <= 1:
            result.compatibility = CrossBranchCompatibility.NOT_APPLICABLE
            result.add_warning("Only one branch found, cross-branch stacking not needed")
            return result

        # Collect branch info
        for branch_id, candidates in branch_groups.items():
            # Skip None branch_id (shouldn't happen with ALL_BRANCHES scope)
            if branch_id is None:
                continue
            result.branches[branch_id] = self._collect_branch_info(
                branch_id, candidates, context
            )

        # Sample partitioners produce disjoint sample sets: nothing to stack.
        has_sample_partitioner = any(
            b.branch_type == BranchType.SAMPLE_PARTITIONER
            for b in result.branches.values()
        )
        if has_sample_partitioner:
            result.compatibility = CrossBranchCompatibility.INCOMPATIBLE_PARTITIONS
            result.add_error(
                "Cross-branch stacking not supported with sample_partitioner branches. "
                "Different partitions have disjoint sample sets."
            )
            return result

        # Check sample alignment across branches
        all_sample_sets = [b.sample_indices for b in result.branches.values()]
        if not all_sample_sets:
            result.add_warning("No sample indices found in any branch")
            result.compatibility = CrossBranchCompatibility.COMPATIBLE
            return result

        # All sets empty: sample_indices are probably not stored in the
        # prediction metadata. Proceed, but tell the caller.
        if not any(all_sample_sets):
            result.add_warning(
                "No sample indices available for cross-branch validation. "
                "Proceeding with cross-branch stacking."
            )
            result.compatibility = CrossBranchCompatibility.COMPATIBLE
            return result

        # Common samples are intersected over branches that reported indices;
        # at least one such branch exists at this point.
        non_empty_sets = [s for s in all_sample_sets if s]
        result.common_samples = set.intersection(*non_empty_sets)

        # Check sample coverage
        total_samples = set().union(*all_sample_sets)
        coverage_ratio = len(result.common_samples) / max(len(total_samples), 1)
        if coverage_ratio < 0.5:  # Less than 50% - likely incompatible
            result.compatibility = CrossBranchCompatibility.INCOMPATIBLE_SAMPLES
            result.add_error(
                f"Branches have low sample overlap ({coverage_ratio:.1%}). "
                f"Cross-branch stacking requires compatible sample sets."
            )
            return result
        if coverage_ratio < 0.9:  # Less than 90% common
            result.compatibility = CrossBranchCompatibility.COMPATIBLE_WITH_ALIGNMENT
            result.add_warning(
                f"Branches have {coverage_ratio:.1%} sample overlap. "
                f"Some samples may be dropped during cross-branch stacking."
            )
        else:
            result.compatibility = CrossBranchCompatibility.COMPATIBLE

        # Check fold alignment
        fold_counts = {b.branch_id: b.n_folds for b in result.branches.values()}
        if len(set(fold_counts.values())) > 1:
            result.alignment_issues.append(
                f"Different fold counts across branches: {fold_counts}"
            )
            result.add_warning(
                f"Branches have different fold counts: {fold_counts}. "
                f"OOF reconstruction may be affected."
            )

        # Emit warnings
        if self.log_warnings:
            for warning in result.warnings:
                warnings.warn(warning)
        return result

    def get_cross_branch_sources(
        self,
        source_candidates: List['ModelCandidate'],
        context: 'ExecutionContext'
    ) -> List['ModelCandidate']:
        """Get source models from all branches for cross-branch stacking.

        Keeps candidates from steps strictly before the current step, keeps
        the first candidate seen for each ``(model_name, branch_id)`` pair,
        and orders the result by branch, step and model name.

        Args:
            source_candidates: All candidate source models.
            context: Execution context; ``context.state.step_number`` is the
                current step.

        Returns:
            Filtered and ordered list of candidates for cross-branch stacking.
        """
        current_step = context.state.step_number

        # Drop later steps and duplicates (same model appearing in multiple folds).
        seen = set()
        unique_candidates = []
        for c in source_candidates:
            if not c.step_idx < current_step:
                continue
            key = (c.model_name, c.branch_id)
            if key in seen:
                continue
            seen.add(key)
            unique_candidates.append(c)

        # A None branch_id sorts together with branch 0.
        unique_candidates.sort(key=lambda c: (
            c.branch_id or 0,
            c.step_idx,
            c.model_name
        ))
        return unique_candidates

    def align_branch_features(
        self,
        branch_features: Dict[int, np.ndarray],
        branch_sample_indices: Dict[int, List[int]],
        target_sample_indices: List[int]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Align features from multiple branches to common sample order.

        Branches are laid out side by side in ascending ``branch_id`` order.
        Each row of the output corresponds to one entry of
        ``target_sample_indices``; cells for which a branch has no matching
        sample stay NaN.

        Args:
            branch_features: Dict mapping branch_id to feature matrix (1-D
                arrays are treated as a single column).
            branch_sample_indices: Dict mapping branch_id to sample indices,
                one per row of the matching feature matrix.
            target_sample_indices: Target sample order for output.

        Returns:
            Tuple of (aligned_features, valid_mask). ``valid_mask`` is False
            for every row that contains at least one NaN.

        Raises:
            BranchFeatureAlignmentError: If any branch has a different number
                of sample indices than feature rows.
        """
        n_samples = len(target_sample_indices)
        target_id_to_pos = {
            int(sid): pos for pos, sid in enumerate(target_sample_indices)
        }

        widths = {
            bid: self._feature_width(feats)
            for bid, feats in branch_features.items()
        }
        n_total_features = sum(widths.values())

        aligned = np.full((n_samples, n_total_features), np.nan)
        feat_col = 0
        alignment_issues: List[str] = []

        for branch_id in sorted(branch_features):
            features = branch_features[branch_id]
            indices = branch_sample_indices.get(branch_id, [])
            n_branch_features = widths[branch_id]

            if len(indices) != len(features):
                alignment_issues.append(
                    f"Branch {branch_id}: sample count mismatch "
                    f"({len(indices)} indices, {len(features)} features)"
                )
                continue

            # Map features to target positions; samples absent from the
            # target order are ignored.
            for row, sample_idx in enumerate(indices):
                pos = target_id_to_pos.get(int(sample_idx))
                if pos is None:
                    continue
                if features.ndim == 1:
                    aligned[pos, feat_col] = features[row]
                else:
                    aligned[pos, feat_col:feat_col + n_branch_features] = features[row]
            feat_col += n_branch_features

        if alignment_issues:
            raise BranchFeatureAlignmentError(
                expected_features=n_total_features,
                branch_features=dict(widths),
                alignment_issues=alignment_issues
            )

        # Rows with any missing value are flagged as invalid.
        valid_mask = ~np.isnan(aligned).any(axis=1)
        return aligned, valid_mask

    @staticmethod
    def _feature_width(features: np.ndarray) -> int:
        """Return the number of output columns a feature array occupies."""
        return features.shape[1] if features.ndim > 1 else 1

    @classmethod
    def _normalize_fold_id(cls, fold_id: Any) -> Optional[str]:
        """Return the fold id as a string, or None if it is not a real fold.

        Missing ids, empty labels and the aggregate ``'avg'`` / ``'w_avg'``
        entries yield None. An integer fold id of ``0`` is a real fold, so
        the check is ``is None`` rather than truthiness.
        """
        if fold_id is None:
            return None
        label = str(fold_id)
        if not label or label in cls._AGGREGATE_FOLD_IDS:
            return None
        return label

    def _group_by_branch(
        self,
        candidates: List['ModelCandidate']
    ) -> Dict[Optional[int], List['ModelCandidate']]:
        """Group candidates by branch_id.

        Args:
            candidates: List of model candidates.

        Returns:
            Dict mapping branch_id to list of candidates, in first-seen order.
        """
        groups: Dict[Optional[int], List['ModelCandidate']] = {}
        for c in candidates:
            groups.setdefault(c.branch_id, []).append(c)
        return groups

    def _collect_branch_info(
        self,
        branch_id: int,
        candidates: List['ModelCandidate'],
        context: 'ExecutionContext'
    ) -> 'BranchPredictionInfo':
        """Collect information about predictions in a branch.

        Queries the prediction store for the ``'val'`` partition of every
        model in the branch and gathers the sample indices and fold ids.

        Args:
            branch_id: Branch identifier.
            candidates: Candidates from this branch.
            context: Execution context. Currently not used by this method.

        Returns:
            BranchPredictionInfo with branch metadata.
        """
        # De-duplicate while keeping first-seen order.
        model_names = list(dict.fromkeys(c.model_name for c in candidates))
        branch_name = next((c.branch_name for c in candidates if c.branch_name), None)

        sample_indices: Set[int] = set()
        fold_ids: Set[str] = set()

        # Get sample indices from predictions
        for model_name in model_names:
            preds = self.prediction_store.filter_predictions(
                model_name=model_name,
                branch_id=branch_id,
                partition='val',
                load_arrays=False
            )
            for pred in preds:
                indices = pred.get('sample_indices', [])
                if indices is not None:
                    if hasattr(indices, 'tolist'):
                        indices = indices.tolist()
                    sample_indices.update(int(i) for i in indices)

                fold_label = self._normalize_fold_id(pred.get('fold_id'))
                if fold_label is not None:
                    fold_ids.add(fold_label)

        # Detect branch type from the candidates.
        # NOTE(review): hasattr() is also True when the attribute exists but
        # is None - confirm ModelCandidate does not declare partition_type
        # as an optional field.
        branch_type = BranchType.PREPROCESSING
        if candidates and hasattr(candidates[0], 'partition_type'):
            branch_type = BranchType.SAMPLE_PARTITIONER

        return BranchPredictionInfo(
            branch_id=branch_id,
            branch_name=branch_name,
            model_names=model_names,
            sample_indices=sample_indices,
            n_samples=len(sample_indices),
            n_folds=len(fold_ids),
            branch_type=branch_type
        )
[docs]
def validate_all_branches_scope(
    prediction_store: 'Predictions',
    source_candidates: List['ModelCandidate'],
    context: 'ExecutionContext'
) -> CrossBranchValidationResult:
    """Validate stacking for ``BranchScope.ALL_BRANCHES`` in a single call.

    Builds a :class:`CrossBranchValidator` with warning emission enabled and
    runs its cross-branch validation without a dataset.

    Args:
        prediction_store: Predictions storage.
        source_candidates: List of candidate source models.
        context: Execution context.

    Returns:
        CrossBranchValidationResult with compatibility info.
    """
    return CrossBranchValidator(
        prediction_store=prediction_store,
        log_warnings=True,
    ).validate_cross_branch_stacking(
        source_candidates=source_candidates,
        context=context,
    )