"""Merge Controller for branch combination and exit.
This controller is the CORE PRIMITIVE for all branch combination operations.
It handles:
1. Exiting branch mode (always, unconditionally)
2. Collecting features and/or predictions from branches
3. Enforcing OOF (out-of-fold) safety when predictions are involved
4. Creating a unified dataset for subsequent steps
Phase 1 Implementation:
- Controller registration and matching
- Configuration parsing for all syntax variants
- Branch validation utilities
Phase 3 Implementation:
- Feature collection and concatenation
- Shape mismatch handling
Phase 4 Implementation:
- Model discovery from prediction store
- OOF prediction reconstruction via TrainingSetReconstructor
- Unsafe mode with prominent warnings
- Simple prediction merge syntax
Phase 5 Implementation:
- Per-branch model selection strategies (all, best, top_k, explicit)
- Per-branch aggregation strategies (separate, mean, weighted_mean, proba_mean)
- Model ranking by validation metrics
- Advanced per-branch prediction configuration
Phase 6 Implementation:
- Mixed merging (features from some branches, predictions from others)
- Asymmetric branch detection and handling (models in some, not others)
- Different feature dimensions per branch handling
- Different model counts per branch handling
- Improved error messages with resolution suggestions (MERGE-E010, MERGE-E011)
Phase 8 Implementation:
- Prediction mode support for merge steps
- Bundle export support
- Full train/predict cycle
Phase 9 Implementation:
- Source merge (merge_sources keyword) for multi-source datasets
- Source merge strategies: concat, stack, dict
- Source incompatibility handling: error, flatten, pad, truncate
- Prediction merge (merge_predictions keyword) for late fusion
- Error codes: MERGE-E024, MERGE-E030, MERGE-E031
Example:
>>> # Simple feature merge
>>> pipeline = [
... {"branch": [[SNV()], [MSC()]]},
... {"merge": "features"},
... PLSRegression(n_components=10)
... ]
>>>
>>> # Prediction stacking
>>> pipeline = [
... {"branch": [[SNV(), PLS()], [MSC(), RF()]]},
... {"merge": "predictions"},
... {"model": Ridge()}
... ]
>>>
>>> # Source merge for multi-source datasets
>>> pipeline = [
... SNV(), # Applied to all sources
... {"merge_sources": "concat"}, # Combine NIR + markers
... {"model": PLS()}
... ]
>>>
>>> # Late fusion without branches
>>> pipeline = [
... SNV(),
... {"model": PLS()},
... {"model": RF()},
... {"merge_predictions": "all"}, # Combine predictions
... {"model": Ridge()}
... ]
Keywords: "merge", "merge_sources", "merge_predictions"
Priority: 5 (same as BranchController)
"""
import copy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING
import numpy as np
from nirs4all.controllers.controller import OperatorController
from nirs4all.controllers.registry import register_controller
from nirs4all.controllers.shared import ModelSelector, PredictionAggregator
from nirs4all.core.logging import get_logger
from nirs4all.operators.data.merge import (
MergeConfig,
MergeMode,
BranchPredictionConfig,
BranchType,
DisjointSelectionCriterion,
DisjointBranchInfo,
DisjointMergeMetadata,
SelectionStrategy,
AggregationStrategy,
ShapeMismatchStrategy,
SourceMergeConfig,
SourceMergeStrategy,
SourceIncompatibleStrategy,
)
from nirs4all.pipeline.execution.result import StepOutput
if TYPE_CHECKING:
from nirs4all.data.dataset import SpectroDataset
from nirs4all.data._features.feature_source import FeatureSource
from nirs4all.data.predictions import Predictions
from nirs4all.pipeline.config.context import ExecutionContext, RuntimeContext
from nirs4all.pipeline.steps.parser import ParsedStep
logger = get_logger(__name__)
# =============================================================================
# Phase 5: Model Selection and Prediction Aggregation Utilities
# NOTE: ModelSelector and PredictionAggregator have been moved to
# nirs4all.controllers.shared to avoid code duplication between
# MergeController and MetaModelController (Phase 2 Stacking Restoration).
# They are imported from the shared module and re-exported here for
# backward compatibility.
# =============================================================================
# =============================================================================
# Phase 6: Asymmetric Branch Detection and Handling
# =============================================================================
[docs]
@dataclass
class BranchAnalysisResult:
    """Characteristics of one branch, as collected by AsymmetricBranchAnalyzer.

    Attributes:
        branch_id: Numeric identifier of the branch.
        branch_name: Branch name, or None when the branch is unnamed.
        has_models: True when at least one trained model was found.
        model_names: Names of the models discovered in the branch.
        model_count: Number of models discovered in the branch.
        feature_dim: Estimated feature dimension, or None when not extracted.
        has_features: True when the branch carries a non-empty feature snapshot.
    """

    branch_id: int  # numeric branch identifier
    branch_name: Union[str, None]  # None for unnamed branches
    has_models: bool  # any trained model present
    model_names: List[str]  # discovered model names
    model_count: int  # number of discovered models
    feature_dim: Union[int, None]  # None when the dimension is unknown
    has_features: bool  # non-empty feature snapshot present
[docs]
@dataclass
class AsymmetryReport:
    """Summary of how a set of branches differ from one another.

    Produced by AsymmetricBranchAnalyzer.analyze_all() to help users
    diagnose and fix merge configuration problems.

    Attributes:
        is_asymmetric: True when any of the three asymmetry kinds was found.
        has_model_asymmetry: Some branches have models while others have none.
        has_model_count_asymmetry: Branches report differing model counts.
        has_feature_dim_asymmetry: Branches report differing feature dimensions.
        branches_with_models: IDs of branches that have models.
        branches_without_models: IDs of branches without models.
        model_counts: Mapping branch_id -> number of models.
        feature_dims: Mapping branch_id -> feature dimension (None if unknown).
        summary: Human-readable description of the detected asymmetry.
    """

    is_asymmetric: bool  # any asymmetry detected
    has_model_asymmetry: bool  # model presence differs
    has_model_count_asymmetry: bool  # model counts differ
    has_feature_dim_asymmetry: bool  # feature dimensions differ
    branches_with_models: List[int]  # branch IDs with models
    branches_without_models: List[int]  # branch IDs without models
    model_counts: Dict[int, int]  # branch_id -> model count
    feature_dims: Dict[int, Union[int, None]]  # branch_id -> feature dim or None
    summary: str  # human-readable summary
# =============================================================================
# Phase 2 (Disjoint Sample Branch Merging): Detection and Analysis
# =============================================================================
[docs]
@dataclass
class DisjointBranchAnalysis:
    """Outcome of checking whether branches partition the samples.

    Attributes:
        is_disjoint: True when the branches hold disjoint sample sets.
        branch_type: Kind of partitioning detected, or None when undetermined.
        branch_sample_counts: Mapping branch_id -> number of samples.
        branch_sample_indices: Mapping branch_id -> sample indices.
        total_samples: Number of unique samples across all branches.
        partition_column: Metadata column used for partitioning, when known
            (only set for metadata-based partitions). Defaults to None.
    """

    is_disjoint: bool  # branches hold disjoint samples
    branch_type: Union[BranchType, None]  # partitioning kind, if any
    branch_sample_counts: Dict[int, int]  # branch_id -> sample count
    branch_sample_indices: Dict[int, List[int]]  # branch_id -> sample indices
    total_samples: int  # unique samples over all branches
    partition_column: Union[str, None] = None  # metadata column, if any
[docs]
@dataclass
class DisjointMergeResult:
    """Outcome of merging disjoint sample branches.

    Attributes:
        merged_array: Merged prediction or feature array of shape
            (n_total_samples, n_columns).
        n_columns: Number of output columns.
        select_by: Name of the selection criterion that was applied.
        branch_info: Per-branch details about selection and merging.
        column_mapping: Output column index -> per-branch model mapping.
    """

    merged_array: np.ndarray  # (n_total_samples, n_columns)
    n_columns: int  # number of output columns
    select_by: str  # selection criterion used
    branch_info: Dict[str, Any]  # per-branch selection/merge details
    column_mapping: Dict[int, Dict[str, str]]  # column -> per-branch model
[docs]
def is_disjoint_branch(branch_context: Dict[str, Any]) -> bool:
    """Check if a branch context indicates disjoint sample branching.

    A branch counts as disjoint when either:
      - its ``"context"`` object's ``custom`` mapping contains
        ``"sample_partition"`` or ``"metadata_partition"``, or
      - its ``"partition_info"`` entry contains ``"sample_indices"``.

    A missing context, a context without ``custom``, a ``custom`` of None
    and a ``partition_info`` of None are all treated as "no partition
    marker".

    Args:
        branch_context: A single branch context dictionary.

    Returns:
        True if this branch is part of a disjoint sample partition.
    """
    context = branch_context.get("context")
    # ``or {}`` also covers contexts that carry ``custom=None``.
    custom = (getattr(context, "custom", None) or {}) if context else {}

    # Markers written by the sample / metadata partitioners.
    if "sample_partition" in custom or "metadata_partition" in custom:
        return True

    # Fallback marker shared by both partitioners.
    partition_info = branch_context.get("partition_info")
    return bool(partition_info) and "sample_indices" in partition_info
[docs]
def detect_disjoint_branches(
    branch_contexts: List[Dict[str, Any]]
) -> DisjointBranchAnalysis:
    """Detect if branches represent disjoint sample partitions.

    Each branch context is checked, in this order, for:
      1. ``context.custom["sample_partition"]`` -> SAMPLE_PARTITIONER
      2. ``context.custom["metadata_partition"]`` -> METADATA_PARTITIONER
         (its ``"column"`` entry is recorded as ``partition_column``)
      3. ``partition_info["sample_indices"]`` -> SAMPLE_PARTITIONER when
         ``partition_info["type"]`` is "outliers" or "inliers", otherwise
         METADATA_PARTITIONER.
    When several branches report a partition, the type of the last one wins.

    A missing context, a context without ``custom``, a ``custom`` of None
    and a ``partition_info`` of None are all treated as "no partition
    marker" (same tolerance as ``is_disjoint_branch``).

    Args:
        branch_contexts: List of branch context dictionaries. Each must have
            a ``"branch_id"``; ``"context"`` and ``"partition_info"`` are
            optional.

    Returns:
        DisjointBranchAnalysis with detection results. With no contexts,
        ``branch_type`` is None; with contexts but no partition markers,
        ``branch_type`` is ``BranchType.COPY`` and the counts are empty.
    """
    if not branch_contexts:
        return DisjointBranchAnalysis(
            is_disjoint=False,
            branch_type=None,
            branch_sample_counts={},
            branch_sample_indices={},
            total_samples=0,
        )

    has_disjoint = False
    branch_type = None
    branch_sample_counts: Dict[int, int] = {}
    branch_sample_indices: Dict[int, List[int]] = {}
    partition_column = None
    all_sample_indices: Set[int] = set()

    for bc in branch_contexts:
        branch_id = bc["branch_id"]
        context = bc.get("context")
        # ``or {}`` guards against an explicit partition_info=None / custom=None.
        partition_info = bc.get("partition_info") or {}
        custom = (getattr(context, "custom", None) or {}) if context else {}
        sample_indices = None

        if "sample_partition" in custom:
            # Written by SamplePartitionerController.
            has_disjoint = True
            branch_type = BranchType.SAMPLE_PARTITIONER
            sample_indices = custom["sample_partition"].get("sample_indices", [])
        elif "metadata_partition" in custom:
            # Written by MetadataPartitionerController.
            has_disjoint = True
            branch_type = BranchType.METADATA_PARTITIONER
            sample_indices = custom["metadata_partition"].get("sample_indices", [])
            partition_column = custom["metadata_partition"].get("column")
        elif "sample_indices" in partition_info:
            # Fallback marker shared by both partitioners.
            has_disjoint = True
            if partition_info.get("type") in ("outliers", "inliers"):
                branch_type = BranchType.SAMPLE_PARTITIONER
            else:
                branch_type = BranchType.METADATA_PARTITIONER
            sample_indices = partition_info.get("sample_indices", [])

        if sample_indices is not None:
            branch_sample_counts[branch_id] = len(sample_indices)
            branch_sample_indices[branch_id] = sample_indices
            all_sample_indices.update(sample_indices)

    if not has_disjoint:
        return DisjointBranchAnalysis(
            is_disjoint=False,
            branch_type=BranchType.COPY,
            branch_sample_counts={},
            branch_sample_indices={},
            total_samples=0,
        )

    return DisjointBranchAnalysis(
        is_disjoint=True,
        branch_type=branch_type,
        branch_sample_counts=branch_sample_counts,
        branch_sample_indices=branch_sample_indices,
        total_samples=len(all_sample_indices),
        partition_column=partition_column,
    )
[docs]
class AsymmetricBranchAnalyzer:
    """Utility class for analyzing branch asymmetry.

    Detects and reports on asymmetry across branches, providing
    detailed information for error messages and resolution suggestions.

    Phase 6 Features:
        - Detect model presence asymmetry (some have models, some don't)
        - Detect model count asymmetry (different numbers of models)
        - Detect feature dimension asymmetry
        - Generate resolution suggestions for mixed merge
    """

    def __init__(
        self,
        branch_contexts: List[Dict[str, Any]],
        prediction_store: Optional[Any],
        context: "ExecutionContext",
    ):
        """Initialize the analyzer.

        Args:
            branch_contexts: List of branch context dictionaries; each is read
                for its ``"branch_id"``, ``"name"`` and ``"features_snapshot"``
                entries.
            prediction_store: Prediction storage used for model discovery via
                ``filter_predictions``, or None to skip model discovery.
            context: Execution context; ``context.state.step_number`` (when
                available) bounds which predictions count as upstream.
        """
        self.branch_contexts = branch_contexts
        self.prediction_store = prediction_store
        self.context = context
        # Memoizes analyze_branch() results per requested branch index.
        self._analysis_cache: Dict[int, BranchAnalysisResult] = {}

    def analyze_branch(self, branch_idx: int) -> BranchAnalysisResult:
        """Analyze a single branch for its characteristics.

        Results are cached per ``branch_idx``. An index that does not appear
        in ``branch_contexts`` yields an empty result (no models, no features)
        rather than an error.

        Args:
            branch_idx: Branch index to analyze.

        Returns:
            BranchAnalysisResult with branch characteristics.
        """
        if branch_idx in self._analysis_cache:
            return self._analysis_cache[branch_idx]

        branch_ctx = next(
            (bc for bc in self.branch_contexts if bc["branch_id"] == branch_idx),
            None,
        )
        if branch_ctx is None:
            result = BranchAnalysisResult(
                branch_id=branch_idx,
                branch_name=None,
                has_models=False,
                model_names=[],
                model_count=0,
                feature_dim=None,
                has_features=False,
            )
            self._analysis_cache[branch_idx] = result
            return result

        branch_id = branch_ctx["branch_id"]
        snapshot = branch_ctx.get("features_snapshot")
        has_features = snapshot is not None and len(snapshot) > 0
        feature_dim = (
            self._estimate_feature_dim(snapshot, branch_id) if has_features else None
        )
        model_names = self._discover_model_names(branch_id)

        result = BranchAnalysisResult(
            branch_id=branch_id,
            branch_name=branch_ctx.get("name"),
            has_models=len(model_names) > 0,
            model_names=model_names,
            model_count=len(model_names),
            feature_dim=feature_dim,
            has_features=has_features,
        )
        self._analysis_cache[branch_idx] = result
        return result

    @staticmethod
    def _estimate_feature_dim(snapshot: Any, branch_id: int) -> Optional[int]:
        """Best-effort total feature count over a branch's feature sources.

        Sources exposing ``num_2d_features`` contribute that value; other
        sources with a ``shape`` of rank >= 2 contribute the product of all
        non-leading axes. Returns None when the total is 0 or estimation fails.
        """
        try:
            total_features = 0
            for feature_source in snapshot:
                if hasattr(feature_source, "num_2d_features"):
                    total_features += feature_source.num_2d_features
                elif hasattr(feature_source, "shape"):
                    shape = feature_source.shape
                    if len(shape) >= 2:
                        # (samples, processings, features) or (samples, features)
                        total_features += int(np.prod(shape[1:]))
            return int(total_features) if total_features > 0 else None
        except Exception as exc:
            # Deliberate best-effort: unknown source types must not abort
            # the analysis, but the failure is no longer silent.
            logger.debug(
                "Could not estimate feature dimension for branch %s: %s",
                branch_id,
                exc,
            )
            return None

    def _discover_model_names(self, branch_id: int) -> List[str]:
        """Return sorted, de-duplicated model names with 'val' predictions in
        ``branch_id`` recorded at steps before the current step."""
        if self.prediction_store is None:
            return []

        # Tolerate contexts without ``state`` or with step_number=None:
        # in that case no step bound is applied.
        state = getattr(self.context, "state", None)
        current_step = getattr(state, "step_number", None)
        if current_step is None:
            current_step = float("inf")

        predictions = self.prediction_store.filter_predictions(
            branch_id=branch_id,
            partition="val",
            load_arrays=False,
        )
        # A missing or None step_idx is treated as step 0 (always upstream).
        upstream = [
            p for p in predictions
            if (p.get("step_idx") or 0) < current_step
        ]
        return sorted({p.get("model_name") for p in upstream if p.get("model_name")})

    def analyze_all(self) -> AsymmetryReport:
        """Analyze all branches for asymmetry.

        Returns:
            AsymmetryReport with comprehensive asymmetry analysis. Note that
            model presence asymmetry also implies model count asymmetry,
            because branches without models count as 0 models.
        """
        analyses = [self.analyze_branch(bc["branch_id"]) for bc in self.branch_contexts]

        if not analyses:
            return AsymmetryReport(
                is_asymmetric=False,
                has_model_asymmetry=False,
                has_model_count_asymmetry=False,
                has_feature_dim_asymmetry=False,
                branches_with_models=[],
                branches_without_models=[],
                model_counts={},
                feature_dims={},
                summary="No branches to analyze.",
            )

        # Model presence asymmetry
        branches_with_models = [a.branch_id for a in analyses if a.has_models]
        branches_without_models = [a.branch_id for a in analyses if not a.has_models]
        has_model_asymmetry = bool(branches_with_models) and bool(branches_without_models)

        # Model count asymmetry
        model_counts = {a.branch_id: a.model_count for a in analyses}
        has_model_count_asymmetry = len(set(model_counts.values())) > 1

        # Feature dimension asymmetry (branches with unknown dims are ignored)
        feature_dims = {a.branch_id: a.feature_dim for a in analyses}
        known_dims = [d for d in feature_dims.values() if d is not None]
        has_feature_dim_asymmetry = len(set(known_dims)) > 1 if known_dims else False

        is_asymmetric = (
            has_model_asymmetry or has_model_count_asymmetry or has_feature_dim_asymmetry
        )

        summary_parts = []
        if has_model_asymmetry:
            summary_parts.append(
                f"Model presence asymmetry: branches {branches_with_models} have models, "
                f"branches {branches_without_models} have only features"
            )
        if has_model_count_asymmetry:
            counts_str = ", ".join(f"branch {k}: {v} models" for k, v in model_counts.items())
            summary_parts.append(f"Model count asymmetry: {counts_str}")
        if has_feature_dim_asymmetry:
            dims_str = ", ".join(
                f"branch {k}: {v} features" for k, v in feature_dims.items() if v is not None
            )
            summary_parts.append(f"Feature dimension asymmetry: {dims_str}")
        summary = "; ".join(summary_parts) if summary_parts else "Branches are symmetric"

        return AsymmetryReport(
            is_asymmetric=is_asymmetric,
            has_model_asymmetry=has_model_asymmetry,
            has_model_count_asymmetry=has_model_count_asymmetry,
            has_feature_dim_asymmetry=has_feature_dim_asymmetry,
            branches_with_models=branches_with_models,
            branches_without_models=branches_without_models,
            model_counts=model_counts,
            feature_dims=feature_dims,
            summary=summary,
        )

    def suggest_mixed_merge(self) -> Optional[str]:
        """Suggest a mixed merge configuration for asymmetric branches.

        Returns:
            Suggested merge configuration string, or None when there is no
            model presence asymmetry.
        """
        report = self.analyze_all()
        if not report.has_model_asymmetry:
            return None

        predictions_part = f'"predictions": {report.branches_with_models}'
        features_part = f'"features": {report.branches_without_models}'
        return (
            f'Consider mixed merge: {{"merge": {{{predictions_part}, {features_part}}}}}\n'
            f"This collects OOF predictions from branches with models and features from branches without."
        )
[docs]
class MergeConfigParser:
    """Parser for merge step configurations.

    Handles all syntax variants and normalizes them to MergeConfig.

    Supported syntaxes:
        - Simple string: "features", "predictions", "all"
        - Dict with keys: {"features": ..., "predictions": ..., ...}
        - Legacy format: {"predictions": [0, 1]}
        - Per-branch format: {"predictions": [{"branch": 0, ...}]}
    """

    @classmethod
    def parse(cls, raw_config: Any) -> MergeConfig:
        """Parse raw merge configuration into MergeConfig.

        Args:
            raw_config: The value from {"merge": raw_config}

        Returns:
            Normalized MergeConfig instance (a MergeConfig input is returned
            as-is).

        Raises:
            ValueError: If configuration format is invalid.
        """
        if isinstance(raw_config, str):
            return cls._parse_simple_string(raw_config)
        elif isinstance(raw_config, dict):
            return cls._parse_dict(raw_config)
        elif isinstance(raw_config, MergeConfig):
            return raw_config
        else:
            raise ValueError(
                f"Invalid merge config type: {type(raw_config).__name__}. "
                f"Expected string, dict, or MergeConfig."
            )

    @classmethod
    def _parse_simple_string(cls, mode_str: str) -> MergeConfig:
        """Parse simple string mode: "features", "predictions", or "all".

        Args:
            mode_str: One of "features", "predictions", "all"

        Returns:
            MergeConfig for the specified mode.

        Raises:
            ValueError: If mode_str is not recognized.
        """
        # Simple string syntax uses output_as="features" by default (legacy
        # behavior): everything is concatenated horizontally into one matrix.
        if mode_str == "features":
            return MergeConfig(collect_features=True, output_as="features")
        elif mode_str == "predictions":
            return MergeConfig(collect_predictions=True, output_as="features")
        elif mode_str == "all":
            return MergeConfig(collect_features=True, collect_predictions=True, output_as="features")
        else:
            raise ValueError(
                f"Unknown merge mode: '{mode_str}'. "
                f"Expected 'features', 'predictions', or 'all'."
            )

    @classmethod
    def _parse_dict(cls, config_dict: Dict[str, Any]) -> MergeConfig:
        """Parse dictionary configuration.

        Handles:
            - {"features": ...}: Feature collection config
            - {"predictions": ...}: Prediction collection config
            - Global options: include_original, on_missing, on_shape_mismatch,
              unsafe, output_as, source_names
            - Disjoint-branch options: n_columns, select_by
            - Per-branch prediction configs

        Args:
            config_dict: Dictionary configuration

        Returns:
            MergeConfig for the specified configuration.

        Raises:
            ValueError: If neither 'features' nor 'predictions' is given, or a
                branch/prediction specification is invalid.
        """
        config = MergeConfig()

        if "features" in config_dict:
            config.collect_features = True
            config.feature_branches = cls._parse_branch_spec(config_dict["features"])

        if "predictions" in config_dict:
            config.collect_predictions = True
            config = cls._parse_predictions_spec(config, config_dict["predictions"])

        # Global options
        config.include_original = config_dict.get("include_original", False)
        config.on_missing = config_dict.get("on_missing", "error")
        config.on_shape_mismatch = config_dict.get("on_shape_mismatch", "error")
        config.unsafe = config_dict.get("unsafe", False)
        config.output_as = config_dict.get("output_as", "features")
        config.source_names = config_dict.get("source_names")

        # Disjoint sample branch merge options (Phase 2)
        config.n_columns = config_dict.get("n_columns")
        config.select_by = config_dict.get("select_by", "mse")

        if not config.collect_features and not config.collect_predictions:
            raise ValueError(
                "Merge config must specify at least one of 'features' or 'predictions'. "
                f"Got keys: {list(config_dict.keys())}"
            )

        return config

    @staticmethod
    def _coerce_branch_indices(indices: List[Any], label: str) -> List[int]:
        """Validate a list of branch indices and return them as plain ints.

        NumPy integer scalars are accepted and converted to ``int``. ``bool``
        is rejected explicitly: it subclasses ``int``, so ``[True]`` would
        otherwise silently select branch 1.

        Args:
            indices: Candidate branch indices.
            label: Subject used in the error message.

        Raises:
            ValueError: If any entry is not an integer.
        """
        if not all(
            isinstance(i, (int, np.integer)) and not isinstance(i, bool)
            for i in indices
        ):
            raise ValueError(f"{label} must be integers, got: {indices}")
        return [int(i) for i in indices]

    @classmethod
    def _parse_branch_spec(
        cls,
        spec: Union[str, List[int], Dict[str, Any]]
    ) -> Union[str, List[int]]:
        """Parse branch specification for features.

        Args:
            spec: Branch specification:
                - "all" or True: All branches
                - [0, 1, 2]: Specific branch indices
                - {"branches": [0, 1]}: Dict with branches key (a dict
                  without that key means all branches)

        Returns:
            "all" or list of branch indices.

        Raises:
            ValueError: If the specification or any index is invalid.
        """
        # isinstance guard keeps the "all" check safe for array-like specs.
        if spec is True or (isinstance(spec, str) and spec == "all"):
            return "all"
        if isinstance(spec, list):
            return cls._coerce_branch_indices(spec, "Branch indices")
        if isinstance(spec, dict):
            if "branches" in spec:
                return cls._parse_branch_spec(spec["branches"])
            return "all"
        raise ValueError(
            f"Invalid branch specification: {spec}. "
            f"Expected 'all', list of indices, or dict with 'branches' key."
        )

    @classmethod
    def _parse_predictions_spec(
        cls,
        config: MergeConfig,
        pred_spec: Union[str, List, Dict]
    ) -> MergeConfig:
        """Parse predictions specification.

        Handles:
            - "all" or True: All predictions from all branches
            - [0, 1, 2]: Simple branch indices (legacy)
            - [{"branch": 0, ...}]: Per-branch configuration (advanced)
            - {"branches": [...], "models": [...], "proba": ...}: Dict format

        Args:
            config: MergeConfig to update
            pred_spec: Predictions specification

        Returns:
            Updated MergeConfig.

        Raises:
            ValueError: If the list is empty, mixes dicts with other entries,
                contains non-integer indices, or the spec type is unsupported.
        """
        if pred_spec is True or (isinstance(pred_spec, str) and pred_spec == "all"):
            config.prediction_branches = "all"
            return config

        if isinstance(pred_spec, list):
            if len(pred_spec) == 0:
                raise ValueError("Predictions branch list cannot be empty")

            n_dicts = sum(isinstance(item, dict) for item in pred_spec)
            if n_dicts == len(pred_spec):
                # Per-branch configuration: every item is a dict with 'branch'.
                config.prediction_configs = [
                    cls._parse_branch_prediction_config(item)
                    for item in pred_spec
                ]
            elif n_dicts:
                # Previously surfaced as an opaque TypeError from `in` on an int.
                raise ValueError(
                    "Predictions list mixes per-branch config dicts with other "
                    "entries; use either all dicts or all integer branch "
                    f"indices, got: {pred_spec}"
                )
            else:
                # Legacy: list of branch indices
                config.prediction_branches = cls._coerce_branch_indices(
                    pred_spec, "Prediction branch indices"
                )
            return config

        if isinstance(pred_spec, dict):
            if "branches" in pred_spec:
                config.prediction_branches = cls._parse_branch_spec(pred_spec["branches"])
            if "models" in pred_spec:
                config.model_filter = pred_spec["models"]
            if "proba" in pred_spec:
                config.use_proba = pred_spec["proba"]
            return config

        raise ValueError(
            f"Invalid predictions specification: {pred_spec}. "
            f"Expected 'all', list of indices, list of configs, or dict."
        )

    @classmethod
    def _parse_branch_prediction_config(
        cls,
        item: Dict[str, Any]
    ) -> BranchPredictionConfig:
        """Parse a single per-branch prediction configuration.

        Args:
            item: Dict with 'branch' key and optional select, metric,
                aggregate, weight_metric, proba and sources keys.

        Returns:
            BranchPredictionConfig instance (missing keys take the defaults
            select="all", aggregate="separate", proba=False, sources="all").

        Raises:
            ValueError: If 'branch' key is missing.
        """
        if "branch" not in item:
            raise ValueError(
                f"Per-branch prediction config must have 'branch' key, "
                f"got: {list(item.keys())}"
            )

        return BranchPredictionConfig(
            branch=item["branch"],
            select=item.get("select", "all"),
            metric=item.get("metric"),
            aggregate=item.get("aggregate", "separate"),
            weight_metric=item.get("weight_metric"),
            proba=item.get("proba", False),
            sources=item.get("sources", "all"),
        )
[docs]
@register_controller
class MergeController(OperatorController):
"""Controller for merging branch outputs and exiting branch mode.
This controller is the CORE PRIMITIVE for branch combination. It:
1. Collects features and/or predictions from specified branches
2. Performs horizontal concatenation of features
3. Performs OOF reconstruction for predictions (mandatory unless unsafe=True)
4. Creates a unified "merged" processing in the dataset
5. ALWAYS clears branch contexts and exits branch mode
Supported Keywords:
- "merge": Branch merging (features/predictions/both)
- "merge_sources": Source merging (multi-source datasets) [Phase 9]
- "merge_predictions": Prediction-only late fusion [Phase 9]
OOF Safety:
When predictions are merged, OOF reconstruction is MANDATORY by default.
This prevents data leakage when the merged output is used for training.
Set `unsafe=True` to disable OOF (generates prominent warnings).
Relationship to MetaModel:
MetaModel internally uses MergeController for data preparation, then
trains the meta-learner. Users can achieve the same result with:
{"merge": "predictions"}, {"model": Ridge()}
which is equivalent to:
{"model": MetaModel(Ridge())}
Attributes:
priority: Controller priority (5 = same as BranchController).
SUPPORTED_KEYWORDS: Set of keywords this controller handles.
"""
priority = 5
SUPPORTED_KEYWORDS = {"merge", "merge_sources", "merge_predictions"}
[docs]
@classmethod
def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
    """Decide whether a pipeline step should be routed to this controller.

    Only the step keyword is inspected; ``step`` and ``operator`` are part
    of the signature but are not used here.

    Args:
        step: Original step configuration (unused).
        operator: Deserialized operator (unused).
        keyword: Step keyword, e.g. ``"merge"``.

    Returns:
        True when ``keyword`` is in ``SUPPORTED_KEYWORDS``, else False.
    """
    handled_keywords = cls.SUPPORTED_KEYWORDS
    is_merge_keyword = keyword in handled_keywords
    return is_merge_keyword
[docs]
@classmethod
def use_multi_source(cls) -> bool:
    """Report whether merge steps can run on multi-source datasets.

    Returns:
        Always True.
    """
    multi_source_supported = True
    return multi_source_supported
[docs]
@classmethod
def supports_prediction_mode(cls) -> bool:
    """Report whether merge steps are executed in prediction mode.

    Returns:
        Always True.
    """
    runs_in_prediction_mode = True
    return runs_in_prediction_mode
[docs]
def execute(
    self,
    step_info: "ParsedStep",
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    runtime_context: "RuntimeContext",
    source: int = -1,
    mode: str = "train",
    loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
    prediction_store: Optional[Any] = None
) -> Tuple["ExecutionContext", StepOutput]:
    """Execute a merge step by dispatching on its keyword.

    Routing (all arguments are forwarded unchanged to the handler):
        - ``"merge"`` -> ``_execute_branch_merge``: branch merging of
          features and/or predictions.
        - ``"merge_sources"`` -> ``_execute_source_merge``: source merging
          for multi-source datasets.
        - ``"merge_predictions"`` -> ``_execute_prediction_merge``:
          prediction late fusion.

    Args:
        step_info: Parsed step containing merge configuration; its
            ``keyword`` attribute selects the handler.
        dataset: Dataset to operate on
        context: Pipeline execution context
        runtime_context: Runtime infrastructure context
        source: Data source index
        mode: Execution mode ("train" or "predict")
        loaded_binaries: Pre-loaded binary objects for prediction mode
        prediction_store: External prediction store for model predictions

    Returns:
        Tuple of (updated_context, StepOutput) as produced by the handler.

    Raises:
        ValueError: If ``step_info.keyword`` is not one of
            ``SUPPORTED_KEYWORDS``. Handlers may raise their own errors
            (e.g. invalid merge configuration).
    """
    keyword = step_info.keyword

    if keyword == "merge":
        return self._execute_branch_merge(
            step_info, dataset, context, runtime_context,
            source, mode, loaded_binaries, prediction_store
        )
    elif keyword == "merge_sources":
        return self._execute_source_merge(
            step_info, dataset, context, runtime_context,
            source, mode, loaded_binaries, prediction_store
        )
    elif keyword == "merge_predictions":
        return self._execute_prediction_merge(
            step_info, dataset, context, runtime_context,
            source, mode, loaded_binaries, prediction_store
        )
    else:
        raise ValueError(
            f"Unknown merge keyword: '{keyword}'. "
            f"Supported: {self.SUPPORTED_KEYWORDS}"
        )
def _execute_branch_merge(
    self,
    step_info: "ParsedStep",
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    runtime_context: "RuntimeContext",
    source: int = -1,
    mode: str = "train",
    loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
    prediction_store: Optional[Any] = None
) -> Tuple["ExecutionContext", StepOutput]:
    """Execute branch merge operation.

    Combines outputs from multiple branches and exits branch mode.

    Before the regular (horizontal) merge, the call is routed elsewhere when:
        - ``context.custom["in_source_branch_mode"]`` is set ->
          ``_execute_source_branch_merge``.
        - mode is "predict"/"explain", there are no branch contexts and the
          context is not in branch mode -> ``_execute_branch_merge_predict_mode``.
        - ``detect_disjoint_branches`` reports disjoint sample branches ->
          ``_execute_disjoint_branch_merge`` (train) or
          ``_execute_disjoint_branch_merge_predict_mode`` (predict/explain).

    Phase 8 Enhancement:
        In prediction mode, if branch_contexts are not available (because
        branches were already processed), the merge is reconstructed from
        loaded metadata. The merge step doesn't persist binary artifacts - it
        combines features/predictions that were already transformed by
        upstream branch steps.

    Args:
        step_info: Parsed step containing merge configuration
        dataset: Dataset to operate on
        context: Pipeline execution context
        runtime_context: Runtime infrastructure context
        source: Data source index
        mode: Execution mode ("train" or "predict")
        loaded_binaries: Pre-loaded binary objects for prediction mode
        prediction_store: External prediction store for model predictions

    Returns:
        Tuple of (updated_context, StepOutput). On the regular merge path the
        returned context has branch mode and source-branch mode cleared and its
        processing list re-synchronised with the dataset. When
        ``config.output_as == "dict"`` the merged parts are also exposed as
        ``updated_context.custom["merged_sources_dict"]``.

    Raises:
        ValueError: If merge is used without an active branch (MERGE-E020).
            Errors from ``MergeConfigParser.parse`` and ``_validate_branches``
            propagate unchanged.
    """
    # Parse configuration
    raw_config = step_info.original_step.get("merge")
    config = MergeConfigParser.parse(raw_config)

    # Check for source_branch mode (different from regular branch mode)
    in_source_branch_mode = context.custom.get("in_source_branch_mode", False)
    source_branch_contexts = context.custom.get("source_branch_contexts", [])
    if in_source_branch_mode:
        return self._execute_source_branch_merge(
            step_info=step_info,
            dataset=dataset,
            context=context,
            runtime_context=runtime_context,
            source=source,
            mode=mode,
            config=config,
            source_contexts=source_branch_contexts,
            loaded_binaries=loaded_binaries,
            prediction_store=prediction_store,
        )

    # Validate branch mode
    branch_contexts = context.custom.get("branch_contexts", [])
    in_branch_mode = context.custom.get("in_branch_mode", False)

    # Phase 8: Handle prediction mode without branch_contexts.
    # In predict mode the executor may already have iterated through the
    # branches, leaving no branch contexts behind.
    if mode in ("predict", "explain") and not branch_contexts and not in_branch_mode:
        return self._execute_branch_merge_predict_mode(
            step_info=step_info,
            dataset=dataset,
            context=context,
            runtime_context=runtime_context,
            source=source,
            config=config,
            loaded_binaries=loaded_binaries,
            prediction_store=prediction_store,
        )

    if not branch_contexts and not in_branch_mode:
        raise ValueError(
            "merge requires active branch contexts. "
            "Use merge only after a branch step. "
            "[Error: MERGE-E020]"
        )

    n_branches = len(branch_contexts)
    logger.info(f"Merge step: mode={config.get_merge_mode().value}, branches={n_branches}")

    # Phase 2: Detect disjoint sample branches.
    # Disjoint branches (from metadata_partitioner or sample_partitioner) require
    # row concatenation instead of horizontal concatenation.
    disjoint_analysis = detect_disjoint_branches(branch_contexts)
    if disjoint_analysis.is_disjoint:
        logger.info(
            f"  Disjoint sample branching detected: {disjoint_analysis.branch_type.value}, "
            f"{disjoint_analysis.total_samples} total samples across {n_branches} branches"
        )
        # Phase 4: samples were already routed to branches in prediction mode;
        # results only need to be collected back in sample order.
        if mode in ("predict", "explain"):
            return self._execute_disjoint_branch_merge_predict_mode(
                step_info=step_info,
                dataset=dataset,
                context=context,
                runtime_context=runtime_context,
                source=source,
                config=config,
                branch_contexts=branch_contexts,
                disjoint_analysis=disjoint_analysis,
                loaded_binaries=loaded_binaries,
                prediction_store=prediction_store,
            )
        return self._execute_disjoint_branch_merge(
            step_info=step_info,
            dataset=dataset,
            context=context,
            runtime_context=runtime_context,
            source=source,
            mode=mode,
            config=config,
            branch_contexts=branch_contexts,
            disjoint_analysis=disjoint_analysis,
            loaded_binaries=loaded_binaries,
            prediction_store=prediction_store,
        )

    # Validate branch indices
    self._validate_branches(config, branch_contexts)

    # Log configuration (Phase 6: enhanced with asymmetric analysis)
    self._log_config(
        config=config,
        n_branches=n_branches,
        branch_contexts=branch_contexts,
        prediction_store=prediction_store,
        context=context,
    )

    # Collect merged data
    merged_parts = []
    merge_info = {}
    # Populated only for output_as="dict". It is attached to the FINAL result
    # context below; storing it on an intermediate context copy would be lost
    # when that copy is replaced.
    merged_dict: Optional[Dict[str, Any]] = None

    # Phase 3: Feature merging
    if config.collect_features:
        feature_branches = config.get_feature_branches(n_branches)
        # When output_as="sources", preserve the preprocessing dimension (3D layout)
        # Otherwise flatten to 2D for horizontal concatenation
        preserve_preprocessing = config.output_as == "sources"
        features_list, feature_info = self._collect_features(
            dataset=dataset,
            branch_contexts=branch_contexts,
            branch_indices=feature_branches,
            on_missing=config.on_missing,
            on_shape_mismatch=config.on_shape_mismatch,
            preserve_preprocessing=preserve_preprocessing,
        )
        if features_list:
            merged_parts.extend(features_list)
            merge_info["feature_shapes"] = feature_info.get("shapes", [])
            merge_info["feature_branches_used"] = feature_info.get("branches_used", [])
            logger.info(
                f"  Collected features from {len(features_list)} branches: "
                f"shapes={feature_info.get('shapes', [])}"
            )

    # Phase 4: Prediction merging
    if config.collect_predictions:
        predictions_array, pred_info = self._collect_predictions(
            dataset=dataset,
            context=context,
            branch_contexts=branch_contexts,
            config=config,
            prediction_store=prediction_store,
            mode=mode,
        )
        if predictions_array is not None and predictions_array.size > 0:
            merged_parts.append(predictions_array)
            merge_info["prediction_shape"] = predictions_array.shape
            merge_info["prediction_models_used"] = pred_info.get("models_used", [])
            merge_info["prediction_branches_used"] = pred_info.get("branches_used", [])
            merge_info["oof_reconstruction"] = pred_info.get("oof_reconstruction", True)
            logger.info(
                f"  Collected predictions: shape={predictions_array.shape}, "
                f"models={pred_info.get('models_used', [])}"
            )

    # Include original pre-branch features if requested (always first column block)
    if config.include_original:
        original_features = self._get_original_features(dataset, context)
        if original_features is not None:
            merged_parts.insert(0, original_features)
            merge_info["include_original"] = True
            merge_info["original_shape"] = original_features.shape
            logger.info(
                f"  Prepended original features: shape={original_features.shape}"
            )

    # Source-branch naming fallback. Normally falsy here, because source_branch
    # mode already returned early above.
    is_source_branch_merge = context.custom.get("in_source_branch_mode", False)

    # Handle output_as strategy
    if merged_parts:
        if config.output_as == "sources":
            # Each merged part becomes a separate dataset source.
            merge_info["merged_shapes"] = [p.shape for p in merged_parts]
            for idx, part in enumerate(merged_parts):
                source_name = f"merged_{idx}"
                if config.source_names and idx < len(config.source_names):
                    source_name = config.source_names[idx]
                elif is_source_branch_merge and idx < len(branch_contexts):
                    source_name = branch_contexts[idx].get("name", f"source_{idx}")
                dataset.add_merged_features(
                    features=part,
                    processing_name=source_name,
                    source=idx
                )
                logger.info(f"  Source {idx} ({source_name}): shape={part.shape}")
            logger.info(f"  Merged {len(merged_parts)} branches → {len(merged_parts)} sources")
        elif config.output_as == "dict":
            # Structured dictionary for multi-input models; dataset untouched.
            merged_dict = {}
            for idx, part in enumerate(merged_parts):
                source_name = f"branch_{idx}"
                if config.source_names and idx < len(config.source_names):
                    source_name = config.source_names[idx]
                elif is_source_branch_merge and idx < len(branch_contexts):
                    source_name = branch_contexts[idx].get("name", f"source_{idx}")
                merged_dict[source_name] = part
            merge_info["merged_dict_keys"] = list(merged_dict.keys())
            logger.info(f"  Merged as dict with keys: {list(merged_dict.keys())}")
        else:  # output_as == "features" - legacy behavior
            # Concatenate all parts horizontally into single feature matrix
            merged_features = np.concatenate(merged_parts, axis=1)
            merge_info["merged_shape"] = merged_features.shape
            logger.info(f"  Final merged shape: {merged_features.shape}")
            processing_name = "merged"
            if config.source_names and len(config.source_names) > 0:
                processing_name = config.source_names[0]
            dataset.add_merged_features(
                features=merged_features,
                processing_name=processing_name,
                source=0  # Primary source for merged features
            )
            # output_as="features" consolidates everything into a single source
            if dataset.features_sources() > 1:
                dataset.keep_sources(0)
                logger.info(f"  Consolidated to single source with shape {merged_features.shape}")
    else:
        logger.warning(
            "No features collected during merge. "
            "Dataset features unchanged."
        )

    # ALWAYS exit branch mode (both regular and source_branch)
    result_context = context.copy()
    result_context.custom["branch_contexts"] = []
    result_context.custom["in_branch_mode"] = False
    result_context.custom["source_branch_contexts"] = []
    result_context.custom["in_source_branch_mode"] = False
    if merged_dict is not None:
        # Expose the dict output to downstream multi-input models.
        result_context.custom["merged_sources_dict"] = merged_dict

    # Re-sync context processing with the dataset's processing names so that
    # subsequent transformers operate on the post-merge processings.
    new_processing = [
        list(dataset.features_processings(sd_idx))
        for sd_idx in range(dataset.features_sources())
    ]
    result_context = result_context.with_processing(new_processing)

    # Build metadata with serialized config for prediction mode reproducibility
    metadata = {
        "merge_mode": config.get_merge_mode().value,
        "feature_branches": (
            config.get_feature_branches(n_branches)
            if config.collect_features else []
        ),
        "prediction_branches": (
            [pc.branch for pc in config.get_prediction_configs(n_branches)]
            if config.collect_predictions else []
        ),
        "include_original": config.include_original,
        "output_as": config.output_as,
        # Phase 8: Store serialized config for prediction mode
        "merge_config": config.to_dict(),
        **merge_info,  # Include merge details
    }

    # Add unsafe warning to metadata if applicable
    if config.unsafe:
        metadata["unsafe_merge"] = True
        logger.warning(
            "⚠️ UNSAFE MERGE: OOF reconstruction disabled for predictions. "
            "Training predictions are used directly, causing DATA LEAKAGE. "
            "Do NOT use for final model evaluation. "
            "Set unsafe=False (default) for production pipelines. "
            "[Error: MERGE-E025]"
        )

    logger.success(
        f"Merge step completed: exited branch mode. "
        f"Features={config.collect_features}, Predictions={config.collect_predictions}"
        f"{' [UNSAFE]' if config.unsafe else ''}"
    )
    return result_context, StepOutput(metadata=metadata)
def _execute_source_branch_merge(
    self,
    step_info: "ParsedStep",
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    runtime_context: "RuntimeContext",
    source: int,
    mode: str,
    config: MergeConfig,
    source_contexts: List[Dict[str, Any]],
    loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
    prediction_store: Optional[Any] = None
) -> Tuple["ExecutionContext", StepOutput]:
    """Execute merge for source_branch mode.

    This handles merging after source_branch, where each source was processed
    independently. Unlike regular branch merge, features are read from the
    dataset's sources directly (not from branch snapshots).

    Output strategies (``config.output_as``):
        - "sources": each collected part is written back to the dataset
          source it was read from.
        - "dict": parts are exposed as
          ``updated_context.custom["merged_sources_dict"]``; the dataset is
          not modified.
        - otherwise ("features"): parts are concatenated horizontally into
          source 0, and any other sources are dropped so the dataset ends with
          a single source.

    Args:
        step_info: Parsed step containing merge configuration
        dataset: Dataset with processed sources
        context: Pipeline execution context
        runtime_context: Runtime infrastructure context
        source: Data source index
        mode: Execution mode ("train" or "predict")
        config: Parsed merge configuration
        source_contexts: List of source context dictionaries
        loaded_binaries: Pre-loaded binary objects for prediction mode
        prediction_store: External prediction store

    Returns:
        Tuple of (updated_context, StepOutput). The returned context has
        source-branch mode cleared.
    """
    n_sources = dataset.n_sources
    logger.info(f"Source branch merge: {n_sources} sources, output_as={config.output_as}")

    # When output_as="sources", preserve the preprocessing dimension (3D layout)
    # Otherwise flatten to 2D for horizontal concatenation
    preserve_preprocessing = config.output_as == "sources"
    layout = "3d" if preserve_preprocessing else "2d"

    # dataset.x(concat_source=False) returns all sources in one call, so it is
    # fetched once rather than once per source.
    per_source_features: Optional[List[Any]]
    try:
        X = dataset.x(
            selector=context.selector,
            layout=layout,
            concat_source=False,
            include_augmented=True,
            include_excluded=False
        )
    except Exception as e:
        # Best-effort: fall through to the "no features" exit below.
        logger.warning(f"Failed to collect features from sources: {e}")
        per_source_features = None
    else:
        # A single (non-list) array is treated as source 0.
        per_source_features = X if isinstance(X, list) else [X]

    merged_parts = []
    source_indices = []  # dataset source index each merged part was read from
    source_shapes = []
    source_names = []
    if per_source_features is not None:
        for src_idx in range(n_sources):
            if src_idx >= len(per_source_features):
                logger.warning(f"Source {src_idx} not found in dataset output")
                continue
            features = per_source_features[src_idx]

            # Get source name from contexts or generate default
            if src_idx < len(source_contexts):
                name = source_contexts[src_idx].get("source_name", f"source_{src_idx}")
            else:
                name = f"source_{src_idx}"

            merged_parts.append(features)
            source_indices.append(src_idx)
            source_shapes.append(features.shape)
            source_names.append(name)
            logger.info(f"  Source {src_idx} ({name}): shape={features.shape}")

    if not merged_parts:
        logger.warning("No source features collected during merge")
        result_context = context.copy()
        result_context.custom["in_source_branch_mode"] = False
        result_context.custom["source_branch_contexts"] = []
        return result_context, StepOutput(metadata={"error": "no_features"})

    merge_info = {
        "source_shapes": source_shapes,
        "source_names": source_names,
        "n_sources": len(merged_parts),
    }

    # Populated only for output_as="dict"; attached to the final context below
    # so it is not lost when the result context is rebuilt.
    merged_dict: Optional[Dict[str, Any]] = None

    if config.output_as == "sources":
        # Write each part back to the source it came from, even if an earlier
        # source was skipped.
        for src_idx, part, name in zip(source_indices, merged_parts, source_names):
            processing_name = f"merged_{name}"
            if config.source_names and src_idx < len(config.source_names):
                processing_name = config.source_names[src_idx]
            dataset.add_merged_features(
                features=part,
                processing_name=processing_name,
                source=src_idx
            )
        logger.info(f"  Kept {len(merged_parts)} sources with shapes {source_shapes}")
    elif config.output_as == "dict":
        # Store as dictionary for multi-input models
        merged_dict = dict(zip(source_names, merged_parts))
        merge_info["merged_dict_keys"] = source_names
        logger.info(f"  Stored as dict with keys: {source_names}")
    else:  # output_as == "features"
        # Concatenate all sources into single feature matrix
        merged_features = np.concatenate(merged_parts, axis=1)
        merge_info["merged_shape"] = merged_features.shape
        processing_name = "merged"
        if config.source_names and len(config.source_names) > 0:
            processing_name = config.source_names[0]
        dataset.add_merged_features(
            features=merged_features,
            processing_name=processing_name,
            source=0
        )
        # Source 0 now holds every source's features. Drop the originals so
        # they are not duplicated downstream, matching the other
        # output_as="features" merge paths.
        if dataset.features_sources() > 1:
            dataset.keep_sources(0)
        logger.info(f"  Concatenated to shape {merged_features.shape}")

    # Exit source branch mode
    result_context = context.copy()
    result_context.custom["in_source_branch_mode"] = False
    result_context.custom["source_branch_contexts"] = []
    if merged_dict is not None:
        result_context.custom["merged_sources_dict"] = merged_dict

    # Re-sync context processing with the dataset's processing names so that
    # subsequent transformers operate on the post-merge processings.
    new_processing = [
        list(dataset.features_processings(sd_idx))
        for sd_idx in range(dataset.features_sources())
    ]
    result_context = result_context.with_processing(new_processing)

    metadata = {
        "source_branch_merge": True,
        "output_as": config.output_as,
        "merge_config": config.to_dict(),
        **merge_info,
    }

    logger.success(f"Source branch merge completed: {len(merged_parts)} sources → output_as={config.output_as}")
    return result_context, StepOutput(metadata=metadata)
# =========================================================================
# Phase 2: Disjoint Sample Branch Merging
# =========================================================================
def _execute_disjoint_branch_merge(
    self,
    step_info: "ParsedStep",
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    runtime_context: "RuntimeContext",
    source: int,
    mode: str,
    config: MergeConfig,
    branch_contexts: List[Dict[str, Any]],
    disjoint_analysis: DisjointBranchAnalysis,
    loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
    prediction_store: Optional[Any] = None
) -> Tuple["ExecutionContext", StepOutput]:
    """Execute merge for disjoint sample branches.

    Disjoint branches partition samples such that each sample exists in
    exactly ONE branch. This requires different merge semantics:

    Feature merge: Validate equal feature dimensions, then reassemble rows
    by sample index to reconstruct the full dataset.

    Prediction merge: Select top-N models per branch (where N is the minimum
    model count or explicitly specified), then reconstruct OOF predictions
    by sample index.

    The final matrix is laid out as [original?, features?, predictions?]
    (horizontal concatenation). It is validated with
    ``_validate_merged_trainability`` and written to source 0; any other
    sources are dropped.

    Args:
        step_info: Parsed step containing merge configuration
        dataset: Dataset to operate on
        context: Pipeline execution context
        runtime_context: Runtime infrastructure context
        source: Data source index
        mode: Execution mode ("train" or "predict")
        config: Parsed merge configuration with n_columns and select_by
        branch_contexts: List of branch context dictionaries
        disjoint_analysis: Analysis result from detect_disjoint_branches()
        loaded_binaries: Pre-loaded binary objects for prediction mode
        prediction_store: External prediction store

    Returns:
        Tuple of (updated_context, StepOutput)

    Raises:
        ValueError: If nothing was merged (MERGE-E040). Dimension and model
            count errors from the feature/prediction collectors propagate.
    """
    n_branches = len(branch_contexts)
    n_total_samples = disjoint_analysis.total_samples

    logger.info(
        f"Disjoint branch merge: {n_branches} branches, "
        f"{n_total_samples} total samples, "
        f"type={disjoint_analysis.branch_type.value}"
    )

    merge_info: Dict[str, Any] = {
        "disjoint_merge": True,
        "branch_type": disjoint_analysis.branch_type.value,
        "n_branches": n_branches,
        "n_total_samples": n_total_samples,
    }

    merged_features = None
    merged_predictions = None

    # ===== FEATURE MERGE =====
    if config.collect_features:
        merged_features, feature_info = self._collect_disjoint_features(
            dataset=dataset,
            branch_contexts=branch_contexts,
            disjoint_analysis=disjoint_analysis,
            config=config,
        )
        merge_info.update(feature_info)

    # ===== PREDICTION MERGE =====
    if config.collect_predictions:
        merged_predictions, pred_info = self._collect_disjoint_predictions(
            dataset=dataset,
            context=context,
            branch_contexts=branch_contexts,
            disjoint_analysis=disjoint_analysis,
            config=config,
            prediction_store=prediction_store,
            mode=mode,
        )
        merge_info.update(pred_info)

    # Combine merged parts
    merged_parts = [part for part in (merged_features, merged_predictions) if part is not None]

    # Include original features if requested (always the first column block)
    if config.include_original:
        original = self._get_original_features(dataset, context)
        if original is not None:
            merged_parts.insert(0, original)
            merge_info["include_original"] = True
            merge_info["original_shape"] = original.shape

    if not merged_parts:
        raise ValueError(
            "Disjoint branch merge resulted in empty output. "
            "Check that branches have features or predictions. "
            "[Error: MERGE-E040]"
        )

    # Concatenate horizontally if multiple parts
    if len(merged_parts) == 1:
        final_merged = merged_parts[0]
    else:
        final_merged = np.concatenate(merged_parts, axis=1)

    # Validate the full matrix downstream steps will train on. The first part
    # alone may be the original features or the features without predictions.
    self._validate_merged_trainability(final_merged, merge_info)

    merge_info["merged_shape"] = final_merged.shape
    logger.info(f"  Final merged shape: {final_merged.shape}")

    # Store merged features in dataset
    processing_name = "merged"
    if config.source_names and len(config.source_names) > 0:
        processing_name = config.source_names[0]

    dataset.add_merged_features(
        features=final_merged,
        processing_name=processing_name,
        source=0
    )

    # Remove other sources - disjoint merge consolidates to single source
    if dataset.features_sources() > 1:
        dataset.keep_sources(0)

    # Exit branch mode and clear partitioner flags
    result_context = context.copy()
    result_context.custom["branch_contexts"] = []
    result_context.custom["in_branch_mode"] = False
    result_context.custom["metadata_partitioner_active"] = False
    result_context.custom["sample_partitioner_active"] = False

    # Re-sync context processing with the dataset's processing names
    new_processing = [
        list(dataset.features_processings(sd_idx))
        for sd_idx in range(dataset.features_sources())
    ]
    result_context = result_context.with_processing(new_processing)

    # Build metadata
    metadata = {
        "merge_mode": config.get_merge_mode().value,
        "disjoint_merge": True,
        "branch_type": disjoint_analysis.branch_type.value,
        "partition_column": disjoint_analysis.partition_column,
        "merge_config": config.to_dict(),
        **merge_info,
    }

    logger.success(
        f"Disjoint branch merge completed: {n_branches} branches → "
        f"shape={final_merged.shape}"
    )

    return result_context, StepOutput(metadata=metadata)
def _execute_disjoint_branch_merge_predict_mode(
    self,
    step_info: "ParsedStep",
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    runtime_context: "RuntimeContext",
    source: int,
    config: MergeConfig,
    branch_contexts: List[Dict[str, Any]],
    disjoint_analysis: DisjointBranchAnalysis,
    loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
    prediction_store: Optional[Any] = None
) -> Tuple["ExecutionContext", StepOutput]:
    """Execute disjoint merge in prediction mode.

    In prediction mode, samples have already been routed to their respective
    branches and processed. This method reconstructs the merged output by
    collecting features/predictions from each branch in sample order.

    The output mirrors the training-mode disjoint merge:
        - The column layout is [original?, features?, predictions?].
        - It is stored under the same processing name ("merged" or
          ``config.source_names[0]``) in source 0.
        - Any other sources are dropped.

    Feature and prediction collection are best-effort. A failure in either is
    logged as a warning and that part is omitted.

    Args:
        step_info: Parsed step info
        dataset: Dataset with branch-processed samples
        context: Execution context
        runtime_context: Runtime context
        source: Source index
        config: Merge configuration
        branch_contexts: List of branch context dicts
        disjoint_analysis: Disjoint branch analysis
        loaded_binaries: Not used (merge has no artifacts)
        prediction_store: Prediction storage; predictions are only collected
            when it is not None.

    Returns:
        Tuple of (updated_context, StepOutput)
    """
    n_branches = len(branch_contexts)
    n_total_samples = disjoint_analysis.total_samples

    logger.info(
        f"Disjoint branch merge (predict mode): {n_branches} branches, "
        f"{n_total_samples} samples"
    )

    merge_info: Dict[str, Any] = {
        "disjoint_merge": True,
        "prediction_mode": True,
        "branch_type": disjoint_analysis.branch_type.value,
        "n_branches": n_branches,
        "n_total_samples": n_total_samples,
    }

    # Feature merge: reconstruct features from branch snapshots.
    # Stays None if collection fails, so later code never sees an unbound name.
    merged_features: Optional[np.ndarray] = None
    if config.collect_features:
        try:
            merged_features, feature_info = self._collect_disjoint_features(
                dataset=dataset,
                branch_contexts=branch_contexts,
                disjoint_analysis=disjoint_analysis,
                config=config,
            )
        except Exception as e:
            logger.warning(f"Could not merge disjoint features in predict mode: {e}")
        else:
            merge_info.update(feature_info)
            logger.info(f"  Merged features (predict): shape={merged_features.shape}")

    # Prediction merge: collect test-partition predictions from the store
    predictions_array: Optional[np.ndarray] = None
    if config.collect_predictions and prediction_store is not None:
        try:
            predictions_array = self._collect_disjoint_predictions_predict_mode(
                dataset=dataset,
                context=context,
                branch_contexts=branch_contexts,
                disjoint_analysis=disjoint_analysis,
                config=config,
                prediction_store=prediction_store,
            )
        except Exception as e:
            logger.warning(f"Could not collect predictions in predict mode: {e}")
        if predictions_array is not None:
            merge_info["prediction_shape"] = predictions_array.shape
            logger.info(f"  Merged predictions (predict): shape={predictions_array.shape}")

    # Assemble parts in the same order as the training-mode merge
    merged_parts = [part for part in (merged_features, predictions_array) if part is not None]
    if config.include_original:
        original = self._get_original_features(dataset, context)
        if original is not None:
            merged_parts.insert(0, original)
            merge_info["include_original"] = True
            merge_info["original_shape"] = original.shape

    if merged_parts:
        if len(merged_parts) == 1:
            final_merged = merged_parts[0]
        else:
            final_merged = np.concatenate(merged_parts, axis=1)
        merge_info["merged_shape"] = final_merged.shape

        processing_name = "merged"
        if config.source_names and len(config.source_names) > 0:
            processing_name = config.source_names[0]
        dataset.add_merged_features(
            features=final_merged,
            processing_name=processing_name,
            source=0
        )
        # Match the training layout: disjoint merge consolidates to one source
        if dataset.features_sources() > 1:
            dataset.keep_sources(0)
    else:
        logger.warning(
            "Disjoint branch merge (predict mode): nothing was merged; "
            "dataset features unchanged."
        )

    # Exit branch mode and clear partitioner flags
    result_context = context.copy()
    result_context.custom["branch_contexts"] = []
    result_context.custom["in_branch_mode"] = False
    result_context.custom["metadata_partitioner_active"] = False
    result_context.custom["sample_partitioner_active"] = False

    # Re-sync context processing with the dataset's processing names
    new_processing = [
        list(dataset.features_processings(sd_idx))
        for sd_idx in range(dataset.features_sources())
    ]
    result_context = result_context.with_processing(new_processing)

    # Build metadata
    metadata = {
        "merge_mode": config.get_merge_mode().value,
        "disjoint_merge": True,
        "prediction_mode": True,
        "branch_type": disjoint_analysis.branch_type.value,
        "partition_column": disjoint_analysis.partition_column,
        "merge_config": config.to_dict(),
        **merge_info,
    }

    logger.success(
        f"Disjoint branch merge (predict mode) completed: {n_branches} branches"
    )

    return result_context, StepOutput(metadata=metadata)
def _collect_disjoint_predictions_predict_mode(
    self,
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    branch_contexts: List[Dict[str, Any]],
    disjoint_analysis: DisjointBranchAnalysis,
    config: MergeConfig,
    prediction_store: Any,
) -> Optional[np.ndarray]:
    """Collect predictions from disjoint branches in predict mode.

    In predict mode, models have already generated predictions for their
    respective sample subsets. This method collects those predictions and
    reconstructs them in original sample order.

    Details:
        - Predictions are queried with ``partition='test'`` and grouped by
          their ``branch_id`` (missing -> 0).
        - The column count is the number of distinct non-empty
          ``model_name`` values in the first branch group.
        - Within a branch, each model fills one column, in order of first
          appearance. When a model has several prediction vectors, their
          element-wise arithmetic mean is used.
        - Element ``i`` of a branch's prediction vector is written to row
          ``sample_indices[i]``. Rows never written stay NaN.

    Args:
        dataset: Dataset for sample info
        context: Execution context
        branch_contexts: List of branch contexts
        disjoint_analysis: Disjoint branch analysis
        config: Merge configuration
        prediction_store: Prediction storage

    Returns:
        Array of shape (total_samples, n_columns), or None when the store
        returned no test predictions.
    """
    n_total_samples = disjoint_analysis.total_samples
    branch_sample_indices = disjoint_analysis.branch_sample_indices

    # Query prediction store for test partition predictions
    filter_kwargs = {
        'partition': 'test',
        'load_arrays': True,
    }
    predictions = prediction_store.filter_predictions(**filter_kwargs)

    if not predictions:
        logger.debug("No test predictions found in prediction store")
        return None

    # Group predictions by branch
    branch_predictions: Dict[int, List[Dict[str, Any]]] = {}
    for pred in predictions:
        branch_predictions.setdefault(pred.get('branch_id', 0), []).append(pred)

    # Number of output columns: distinct models in the first branch that has any
    n_columns = 1
    for preds in branch_predictions.values():
        model_names = {p.get('model_name') for p in preds if p.get('model_name')}
        if model_names:
            n_columns = len(model_names)
            break

    merged = np.full((n_total_samples, n_columns), np.nan)

    for branch_id, sample_indices in branch_sample_indices.items():
        preds = branch_predictions.get(branch_id)
        if preds is None:
            logger.debug(f"Branch {branch_id} has no predictions")
            continue

        # Gather every prediction vector per model (e.g. one per fold),
        # keeping models in order of first appearance.
        model_vectors: Dict[str, List[np.ndarray]] = {}
        for pred in preds:
            y_pred = pred.get('y_pred')
            if y_pred is None:
                continue
            model_name = pred.get('model_name', 'model')
            model_vectors.setdefault(model_name, []).append(np.asarray(y_pred).flatten())

        global_rows = np.asarray(list(sample_indices), dtype=np.intp)

        for col_idx, vectors in enumerate(model_vectors.values()):
            if col_idx >= n_columns:
                break
            # True arithmetic mean over all vectors. A running pairwise mean
            # would over-weight later vectors.
            if len(vectors) == 1:
                y_pred = vectors[0]
            else:
                y_pred = np.mean(np.stack(vectors), axis=0)

            # Element i maps to row sample_indices[i]; skip rows beyond the
            # output and elements beyond the prediction length.
            n_rows = min(len(global_rows), len(y_pred))
            rows = global_rows[:n_rows]
            in_range = rows < n_total_samples
            merged[rows[in_range], col_idx] = y_pred[:n_rows][in_range]

    # Check for unfilled samples
    nan_count = np.sum(np.isnan(merged))
    if nan_count > 0:
        logger.warning(
            f"Disjoint prediction merge (predict): {nan_count} values are NaN"
        )

    return merged
def _collect_disjoint_features(
    self,
    dataset: "SpectroDataset",
    branch_contexts: List[Dict[str, Any]],
    disjoint_analysis: DisjointBranchAnalysis,
    config: MergeConfig,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Collect features from disjoint sample branches.

    For disjoint branches, features are reassembled VERTICALLY (row-wise) by
    sample index, not horizontally. All branches must produce the same
    feature dimension or an error is raised.

    Row ``i`` of a branch's snapshot features is written to row
    ``sample_indices[i]`` of the output. Rows not covered by any branch
    remain NaN, and a warning is logged. Coverage gaps are tracked separately
    from NaN values that already exist in a branch's own features.

    Args:
        dataset: Dataset for sample information
        branch_contexts: List of branch context dictionaries
        disjoint_analysis: Analysis of disjoint branches
        config: Merge configuration

    Returns:
        Tuple of (merged_features, info_dict) where merged_features has
        shape (n_total_samples, n_features).

    Raises:
        ValueError: If a branch has no feature snapshot or extraction fails
            (MERGE-E041), if feature dimensions differ across branches
            (MERGE-E042), or if no branch contributed features (MERGE-E041).
    """
    n_total_samples = disjoint_analysis.total_samples
    feature_dims: Dict[str, int] = {}
    # branch_id -> (branch_name, features, sample_indices)
    branch_features: Dict[int, Tuple[str, np.ndarray, List[int]]] = {}

    # Collect features from each branch
    for bc in branch_contexts:
        branch_id = bc["branch_id"]
        branch_name = bc.get("name", f"branch_{branch_id}")
        sample_indices = disjoint_analysis.branch_sample_indices.get(branch_id, [])

        if not sample_indices:
            logger.warning(f"Branch {branch_name} has no sample indices, skipping")
            continue

        # Extract features from branch snapshot
        snapshot = bc.get("features_snapshot")
        if snapshot is None:
            raise ValueError(
                f"Branch '{branch_name}' has no feature snapshot for disjoint merge. "
                f"[Error: MERGE-E041]"
            )

        try:
            features = self._extract_features_from_snapshot(
                snapshot=snapshot,
                expected_samples=len(sample_indices),
                branch_idx=branch_id,
                layout="2d",
            )
        except ValueError as e:
            raise ValueError(
                f"Failed to extract features from branch '{branch_name}': {e}. "
                f"[Error: MERGE-E041]"
            ) from e

        feature_dim = features.shape[1]
        feature_dims[branch_name] = feature_dim
        branch_features[branch_id] = (branch_name, features, sample_indices)

        logger.debug(
            f"  Branch '{branch_name}': {len(sample_indices)} samples, "
            f"{feature_dim} features"
        )

    # Validate feature dimensions are equal
    unique_dims = set(feature_dims.values())
    if len(unique_dims) > 1:
        dims_str = ", ".join(f"'{k}': {v}" for k, v in feature_dims.items())
        raise ValueError(
            f"Cannot merge features from disjoint branches with different "
            f"feature dimensions: {{{dims_str}}}. "
            f"Ensure all branches apply identical transformations. "
            f"[Error: MERGE-E042]"
        )

    if not unique_dims:
        raise ValueError(
            "No features collected from any disjoint branch. "
            "[Error: MERGE-E041]"
        )

    n_features = unique_dims.pop()

    # Reconstruct full feature matrix by sample index, tracking which rows a
    # branch actually wrote so coverage gaps are not confused with NaN data.
    merged = np.full((n_total_samples, n_features), np.nan)
    covered = np.zeros(n_total_samples, dtype=bool)
    for _, features, sample_indices in branch_features.values():
        rows = np.asarray(list(sample_indices), dtype=np.intp)
        in_range = rows < n_total_samples
        merged[rows[in_range]] = features[:len(rows)][in_range]
        covered[rows[in_range]] = True

    n_unfilled = int(np.count_nonzero(~covered))
    if n_unfilled > 0:
        logger.warning(
            f"Disjoint feature merge: {n_unfilled} samples are not covered by any "
            f"branch and were filled with NaN. This indicates sample coverage gaps."
        )
    n_nan_rows = int(np.count_nonzero(np.isnan(merged[covered]).any(axis=1)))
    if n_nan_rows > 0:
        logger.warning(
            f"Disjoint feature merge: {n_nan_rows} samples contain NaN values "
            f"produced by their branch."
        )

    # Phase 3: per-branch info (feature merge, so no model selection)
    branches_info: Dict[str, DisjointBranchInfo] = {}
    for branch_name, _, sample_indices in branch_features.values():
        branches_info[branch_name] = DisjointBranchInfo(
            n_samples=len(sample_indices),
            sample_ids=sample_indices,
            n_models_original=0,  # Feature merge, no models
            n_models_selected=0,
            selected_models=[],
            dropped_models=[],
        )

    # Build feature merge metadata
    disjoint_metadata = DisjointMergeMetadata(
        merge_type="disjoint_samples",
        n_columns=0,  # 0 for feature merge (not prediction columns)
        select_by="",  # Not applicable for feature merge
        branches=branches_info,
        column_mapping={},  # Not applicable for feature merge
        is_heterogeneous=False,
        feature_dim=n_features,
    )

    # Phase 3: Use structured logging from metadata
    disjoint_metadata.log_summary(logger.info)

    info = {
        "feature_dims": feature_dims,
        "feature_dim": n_features,
        "feature_branches_used": list(branch_features.keys()),
        "feature_merged_shape": merged.shape,
        # Phase 3: Add structured metadata
        "disjoint_metadata": disjoint_metadata.to_dict(),
    }

    return merged, info
def _collect_disjoint_predictions(
    self,
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    branch_contexts: List[Dict[str, Any]],
    disjoint_analysis: DisjointBranchAnalysis,
    config: MergeConfig,
    prediction_store: Optional[Any],
    mode: str,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Collect predictions from disjoint sample branches.

    For disjoint branches, predictions are collected per-branch and then
    reconstructed by sample_id. When branches have different model counts,
    the top-N models of each branch are kept according to ``select_by``.

    Algorithm:
        1. Discover the models of each branch and their validation scores.
        2. Determine N (output column count) from ``config.n_columns`` or
           ``min(model_counts)``.
        3. Select the top-N models per branch (``_select_disjoint_models``).
        4. Reconstruct OOF predictions by sample_id; column ``i`` holds the
           i-th selected model of whichever branch owns the sample.

    Args:
        dataset: Dataset for sample information.
        context: Execution context.
        branch_contexts: List of branch context dictionaries.
        disjoint_analysis: Analysis of disjoint branches (total sample count
            and per-branch sample indices).
        config: Merge configuration with n_columns, select_by and
            model_filter.
        prediction_store: Prediction storage.
        mode: Execution mode, forwarded to OOF retrieval.

    Returns:
        Tuple of (merged_predictions, info_dict) where merged_predictions
        has shape (n_total_samples, N). Entries without a prediction are NaN.

    Raises:
        ValueError: If prediction_store is None (MERGE-E043).
        ValueError: If no predictions are found in any branch (MERGE-E043).
        ValueError: If n_columns is smaller than 1 or exceeds the minimum
            model count across branches (MERGE-E044).
    """
    if prediction_store is None:
        raise ValueError(
            "prediction_store is required for disjoint prediction merge. "
            "[Error: MERGE-E043]"
        )

    n_total_samples = disjoint_analysis.total_samples
    select_by = config.select_by
    branch_sample_indices = disjoint_analysis.branch_sample_indices

    # Step 1: Discover models in each branch and their scores
    branch_models: Dict[int, List[Dict[str, Any]]] = {}
    for bc in branch_contexts:
        branch_id = bc["branch_id"]
        branch_name = bc.get("name", f"branch_{branch_id}")
        model_names = self._discover_branch_models(
            prediction_store=prediction_store,
            branch_id=branch_id,
            context=context,
            model_filter=config.model_filter,
        )
        if not model_names:
            logger.warning(f"Branch '{branch_name}' has no models, skipping")
            continue
        # Scores are fetched for every model (also with select_by="order")
        # so they can be reported in the merge metadata.
        model_infos = []
        for model_name in model_names:
            score = self._get_model_score(
                prediction_store=prediction_store,
                model_name=model_name,
                branch_id=branch_id,
                metric=select_by,
                context=context,
            )
            model_infos.append({
                "name": model_name,
                "score": score,
                "branch_id": branch_id,
                "branch_name": branch_name,
            })
        branch_models[branch_id] = model_infos
        logger.debug(
            f" Branch '{branch_name}': {len(model_infos)} models"
        )

    if not branch_models:
        raise ValueError(
            "No model predictions found in any disjoint branch. "
            "[Error: MERGE-E043]"
        )

    # Step 2: Determine N (output column count)
    model_counts = {bid: len(models) for bid, models in branch_models.items()}
    min_model_count = min(model_counts.values())
    max_model_count = max(model_counts.values())
    if config.n_columns is not None:
        n_columns = config.n_columns
        # A zero/negative N would silently produce an empty matrix or break
        # the per-column slicing below.
        if n_columns < 1:
            raise ValueError(
                f"n_columns must be a positive integer, got {n_columns}. "
                f"[Error: MERGE-E044]"
            )
        if n_columns > min_model_count:
            raise ValueError(
                f"n_columns={n_columns} exceeds minimum model count "
                f"({min_model_count}) across branches. "
                f"Model counts: {model_counts}. "
                f"[Error: MERGE-E044]"
            )
    else:
        n_columns = min_model_count
    logger.info(
        f" Disjoint prediction merge: N={n_columns} columns "
        f"(model counts: {model_counts}, select_by='{select_by}')"
    )

    # Step 3: Select top-N models per branch
    selected_per_branch: Dict[int, List[Dict[str, Any]]] = {}
    dropped_per_branch: Dict[int, List[Dict[str, Any]]] = {}
    column_mapping: Dict[int, Dict[str, str]] = {i: {} for i in range(n_columns)}
    for branch_id, model_infos in branch_models.items():
        branch_name = model_infos[0]["branch_name"] if model_infos else f"branch_{branch_id}"
        selected, dropped = self._select_disjoint_models(
            model_infos, n_columns, select_by
        )
        selected_per_branch[branch_id] = selected
        dropped_per_branch[branch_id] = dropped
        for col_idx, model_info in enumerate(selected):
            column_mapping[col_idx][branch_name] = model_info["name"]

    # Step 4: Collect OOF predictions for selected models
    merged = np.full((n_total_samples, n_columns), np.nan)
    for branch_id, selected_models in selected_per_branch.items():
        sample_indices = branch_sample_indices.get(branch_id, [])
        for col_idx, model_info in enumerate(selected_models):
            oof_predictions = self._get_branch_oof_predictions(
                dataset=dataset,
                context=context,
                prediction_store=prediction_store,
                model_name=model_info["name"],
                branch_id=branch_id,
                sample_indices=sample_indices,
                mode=mode,
            )
            if oof_predictions is None:
                continue
            for local_idx, global_idx in enumerate(sample_indices):
                if global_idx < n_total_samples and local_idx < len(oof_predictions):
                    merged[global_idx, col_idx] = oof_predictions[local_idx]

    # Check for unfilled predictions
    nan_count = np.sum(np.isnan(merged))
    if nan_count > 0:
        logger.warning(
            f"Disjoint prediction merge: {nan_count} values are NaN "
            f"({100 * nan_count / merged.size:.1f}% of total). "
            f"This may indicate incomplete OOF coverage."
        )

    # Phase 3: Build comprehensive metadata using DisjointMergeMetadata
    branches_info: Dict[str, DisjointBranchInfo] = {}
    for branch_id, selected_models in selected_per_branch.items():
        branch_name = selected_models[0]["branch_name"] if selected_models else f"branch_{branch_id}"
        sample_indices = branch_sample_indices.get(branch_id, [])
        selected_model_details = [
            {"name": m["name"], "score": m["score"], "column": col_idx}
            for col_idx, m in enumerate(selected_models)
        ]
        dropped_model_details = [
            {"name": m["name"], "score": m["score"]}
            for m in dropped_per_branch.get(branch_id, [])
        ]
        branches_info[branch_name] = DisjointBranchInfo(
            n_samples=len(sample_indices),
            sample_ids=sample_indices,
            n_models_original=model_counts.get(branch_id, 0),
            n_models_selected=len(selected_models),
            selected_models=selected_model_details,
            dropped_models=dropped_model_details,
        )

    # Heterogeneous: some column holds different models for different branches
    is_heterogeneous = any(
        len(set(mapping.values())) > 1 for mapping in column_mapping.values()
    )

    disjoint_metadata = DisjointMergeMetadata(
        merge_type="disjoint_samples",
        n_columns=n_columns,
        select_by=select_by,
        branches=branches_info,
        column_mapping=column_mapping,
        is_heterogeneous=is_heterogeneous,
    )
    # Phase 3: Use structured logging from metadata
    disjoint_metadata.log_summary(logger.info)
    if is_heterogeneous or max_model_count > min_model_count:
        disjoint_metadata.log_warnings(logger.warning)

    # Build info dict with both legacy fields and new metadata
    info = {
        "prediction_n_columns": n_columns,
        "prediction_select_by": select_by,
        "prediction_model_counts": model_counts,
        "prediction_branches_used": list(selected_per_branch.keys()),
        "prediction_column_mapping": column_mapping,
        "prediction_merged_shape": merged.shape,
        "selected_models": {
            bid: [m["name"] for m in models]
            for bid, models in selected_per_branch.items()
        },
        "dropped_models": {
            bid: [m["name"] for m in models]
            for bid, models in dropped_per_branch.items()
            if models
        },
        # Phase 3: Add structured metadata
        "disjoint_metadata": disjoint_metadata.to_dict(),
    }
    logger.info(
        f" Collected predictions: {len(selected_per_branch)} branches → "
        f"shape={merged.shape}"
    )
    return merged, info

@staticmethod
def _select_disjoint_models(
    model_infos: List[Dict[str, Any]],
    n_columns: int,
    select_by: Optional[str],
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Split one branch's models into (selected, dropped) lists.

    - With exactly ``n_columns`` models, all are kept in definition order.
    - ``select_by="order"`` keeps the first ``n_columns`` models.
    - Metrics where a higher value is better (r2, accuracy, f1, ...) are
      ranked descending; every other metric (mse, rmse, mae, ...) ascending.

    Metric names are compared case-insensitively. Models whose score is
    None or NaN are always ranked last. The sort is stable, so tied models
    keep their definition order.
    """
    if len(model_infos) == n_columns:
        return list(model_infos), []

    criterion = (select_by or "").lower()
    higher_is_better = {
        "r2", "accuracy", "balanced_accuracy", "f1",
        "precision", "recall", "auc", "roc_auc",
    }

    def score_or(info: Dict[str, Any], fallback: float) -> float:
        score = info["score"]
        # None and NaN cannot be ranked meaningfully; push them to the end.
        if score is None or score != score:
            return fallback
        return score

    if criterion == "order":
        ranked = list(model_infos)
    elif criterion in higher_is_better:
        ranked = sorted(
            model_infos,
            key=lambda m: score_or(m, float("-inf")),
            reverse=True,
        )
    else:
        ranked = sorted(model_infos, key=lambda m: score_or(m, float("inf")))
    return ranked[:n_columns], ranked[n_columns:]
def _get_model_score(
    self,
    prediction_store: Any,
    model_name: str,
    branch_id: int,
    metric: str,
    context: "ExecutionContext",
) -> Optional[float]:
    """Get the validation score of a model, averaged across folds.

    Only validation-partition predictions whose ``step_idx`` is strictly
    below the current step are considered. The metric key is looked up
    exactly first, then case-insensitively. Folds whose value is missing,
    non-numeric or non-finite (None, NaN, Inf) are skipped, so one bad fold
    no longer invalidates the whole score.

    Args:
        prediction_store: Prediction storage exposing ``filter_predictions``.
        model_name: Name of the model.
        branch_id: Branch ID.
        metric: Metric name (mse, rmse, mae, r2, ...).
        context: Execution context (``context.state.step_number`` bounds the
            steps considered).

    Returns:
        Mean score over the usable folds, or None if no usable value exists
        or the lookup fails.
    """
    if not metric:
        return None
    try:
        current_step = getattr(context.state, 'step_number', float('inf'))
        predictions = prediction_store.filter_predictions(
            model_name=model_name,
            branch_id=branch_id,
            partition='val',
            load_arrays=False,
        )
        wanted = metric.lower()
        scores: List[float] = []
        for pred in predictions:
            # Only predictions produced by earlier steps are eligible.
            if pred.get('step_idx', 0) >= current_step:
                continue
            metrics = pred.get('metrics') or {}
            value = metrics.get(metric)
            if value is None:
                value = next(
                    (
                        v for k, v in metrics.items()
                        if isinstance(k, str) and k.lower() == wanted
                    ),
                    None,
                )
            try:
                value = float(value)
            except (TypeError, ValueError):
                continue
            if np.isfinite(value):
                scores.append(value)
        if scores:
            return float(np.mean(scores))
        return None
    except Exception as e:
        # Scoring is best-effort: a missing score only affects model ranking.
        logger.debug(f"Failed to get score for {model_name}: {e}")
        return None
def _get_branch_oof_predictions(
    self,
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    prediction_store: Any,
    model_name: str,
    branch_id: int,
    sample_indices: List[int],
    mode: str,
) -> Optional[np.ndarray]:
    """Fetch out-of-fold predictions of one model for a disjoint branch.

    Validation-partition predictions recorded by steps before the current
    one are grouped by sample index. Values for a sample that appears in
    several prediction entries are averaged. The result is then laid out
    in the order of ``sample_indices``, with NaN for samples that received
    no prediction.

    Args:
        dataset: Dataset for sample info (unused here).
        context: Execution context (``context.state.step_number`` bounds
            the steps considered).
        prediction_store: Prediction storage exposing ``filter_predictions``.
        model_name: Model name.
        branch_id: Branch ID.
        sample_indices: Global sample indices belonging to this branch.
        mode: Execution mode (unused here).

    Returns:
        1D float array of length ``len(sample_indices)``, or None when no
        eligible prediction exists or retrieval fails.
    """
    try:
        step_limit = getattr(context.state, 'step_number', float('inf'))
        n_out = len(sample_indices)

        candidates = prediction_store.filter_predictions(
            model_name=model_name,
            branch_id=branch_id,
            partition='val',
            load_arrays=True,
        )
        eligible = [
            entry for entry in candidates
            if entry.get('step_idx', 0) < step_limit
        ]
        if not eligible:
            logger.debug(f"No OOF predictions for {model_name} in branch {branch_id}")
            return None

        per_sample: Dict[int, List[float]] = {}
        for entry in eligible:
            raw_pred = entry.get('y_pred')
            ids = entry.get('sample_indices')
            if raw_pred is None:
                continue
            flat_pred = np.asarray(raw_pred).flatten()
            if ids is None:
                continue
            if hasattr(ids, 'tolist'):
                ids = ids.tolist()
            # zip stops at the shorter sequence, ignoring surplus ids/values.
            for sid, value in zip(ids, flat_pred):
                per_sample.setdefault(sid, []).append(value)

        aligned = np.full(n_out, np.nan)
        for position, sid in enumerate(sample_indices):
            hits = per_sample.get(sid)
            if hits is not None:
                aligned[position] = np.mean(hits)
        return aligned
    except Exception as e:
        logger.warning(f"Failed to get OOF predictions for {model_name}: {e}")
        return None
def _validate_merged_trainability(
    self,
    merged: np.ndarray,
    merge_info: Dict[str, Any],
) -> None:
    """Validate, and repair in place, a merged matrix before meta-model training.

    Checks for:
    1. Non-finite values (NaN, +/-Inf). If more than 50% of the entries are
       non-finite, validation fails. Otherwise each non-finite entry is
       replaced in place by the mean of the finite values in its column,
       or 0.0 when the column has no finite value.
    2. Minimum sample count (at least 10 rows).

    Args:
        merged: Merged prediction/feature array, 2D (samples, columns).
            Modified in place when imputation happens.
        merge_info: Merge info dict (kept for error context; unused here).

    Raises:
        ValueError: If more than 50% of values are non-finite (MERGE-E045)
            or there are fewer than 10 samples (MERGE-E046).
    """
    MIN_SAMPLES = 10
    MAX_NON_FINITE_PCT = 50

    # Check for non-finite values
    non_finite_mask = ~np.isfinite(merged)
    non_finite_count = np.sum(non_finite_mask)
    if non_finite_count > 0:
        non_finite_pct = 100 * non_finite_count / merged.size
        if non_finite_pct > MAX_NON_FINITE_PCT:
            raise ValueError(
                f"Merged predictions contain {non_finite_count} non-finite values "
                f"({non_finite_pct:.1f}% of total). Cannot train meta-model on invalid data. "
                f"[Error: MERGE-E045]"
            )
        logger.warning(
            f"Merged predictions contain {non_finite_count} non-finite values "
            f"({non_finite_pct:.1f}%). These will be imputed with column means."
        )
        # Use the mean of the *finite* values only: np.nanmean propagates
        # +/-Inf and would yield a non-finite mean for such columns.
        for col in range(merged.shape[1]):
            col_bad = non_finite_mask[:, col]
            if not col_bad.any():
                continue
            finite_values = merged[~col_bad, col]
            fill_value = float(np.mean(finite_values)) if finite_values.size else 0.0
            if not np.isfinite(fill_value):
                # e.g. finite values whose sum overflows
                fill_value = 0.0
            merged[col_bad, col] = fill_value

    # Check minimum samples
    n_samples = merged.shape[0]
    if n_samples < MIN_SAMPLES:
        raise ValueError(
            f"Merged predictions have only {n_samples} samples. "
            f"Minimum {MIN_SAMPLES} required for meta-model training. "
            f"[Error: MERGE-E046]"
        )
def _validate_branches(
    self,
    config: MergeConfig,
    branch_contexts: List[Dict[str, Any]]
) -> None:
    """Check that every branch referenced by the merge config exists.

    Integer references are checked against the positions
    ``0 .. len(branch_contexts) - 1``. String references are checked
    against branch names (``bc["name"]``, defaulting to
    ``"branch_<branch_id>"``). Feature branches are only checked when
    features are collected from an explicit list. Prediction branches are
    checked through the per-branch prediction configs when present, and
    otherwise through an explicit ``prediction_branches`` list.

    Args:
        config: Merge configuration.
        branch_contexts: Available branch contexts.

    Raises:
        ValueError: If any referenced branch index or name is invalid.
    """
    valid_positions = set(range(len(branch_contexts)))
    names_to_ids: Dict[str, int] = {}
    for ctx in branch_contexts:
        label = ctx.get("name", f"branch_{ctx['branch_id']}")
        names_to_ids[label] = ctx["branch_id"]

    if config.collect_features and config.feature_branches != "all":
        requested_features = config.feature_branches
        assert isinstance(requested_features, list)  # Type narrowing
        self._validate_branch_indices(
            requested_features,
            valid_positions,
            names_to_ids,
            "feature_branches",
        )

    if not config.collect_predictions:
        return

    if config.has_per_branch_config():
        assert config.prediction_configs is not None  # Type narrowing
        for branch_cfg in config.prediction_configs:
            self._validate_branch_reference(
                branch_cfg.branch,
                valid_positions,
                names_to_ids,
                f"prediction_config[branch={branch_cfg.branch}]",
            )
    elif config.prediction_branches != "all":
        requested_predictions = config.prediction_branches
        assert isinstance(requested_predictions, list)  # Type narrowing
        self._validate_branch_indices(
            requested_predictions,
            valid_positions,
            names_to_ids,
            "prediction_branches",
        )
def _validate_branch_indices(
    self,
    indices: List[int],
    available_indices: set,
    available_names: Dict[str, int],
    context_name: str
) -> None:
    """Validate each entry of a list of branch references in order.

    Every element is handed to ``_validate_branch_reference``. The first
    invalid entry raises, and later entries are not checked.

    Args:
        indices: Branch references (indices or names) to validate.
        available_indices: Set of valid branch indices.
        available_names: Map of branch names to indices.
        context_name: Label used in error messages.

    Raises:
        ValueError: If any reference is invalid.
    """
    for branch_ref in indices:
        self._validate_branch_reference(
            branch_ref,
            available_indices,
            available_names,
            context_name,
        )
def _validate_branch_reference(
    self,
    ref: Union[int, str],
    available_indices: set,
    available_names: Dict[str, int],
    context_name: str
) -> None:
    """Validate a single branch reference (index or name).

    Integer references must be members of ``available_indices``. String
    references must be keys of ``available_names``. ``bool`` is rejected
    explicitly: it subclasses ``int``, so ``True``/``False`` would otherwise
    silently validate as branch 1/0.

    Args:
        ref: Branch index (int) or name (str).
        available_indices: Set of valid branch indices.
        available_names: Map of branch names to indices.
        context_name: Name for error context.

    Raises:
        ValueError: If the reference is unknown (MERGE-E021) or is not an
            int/str.
    """
    if isinstance(ref, bool) or not isinstance(ref, (int, str)):
        raise ValueError(
            f"Branch reference must be int or str, got {type(ref).__name__}: {ref}"
        )
    if isinstance(ref, int):
        if ref not in available_indices:
            raise ValueError(
                f"Invalid branch index in {context_name}: {ref}. "
                f"Available indices: {sorted(available_indices)}. "
                f"[Error: MERGE-E021]"
            )
    elif ref not in available_names:
        raise ValueError(
            f"Invalid branch name in {context_name}: '{ref}'. "
            f"Available names: {list(available_names.keys())}. "
            f"[Error: MERGE-E021]"
        )
def _log_config(
    self,
    config: MergeConfig,
    n_branches: int,
    branch_contexts: Optional[List[Dict[str, Any]]] = None,
    prediction_store: Optional[Any] = None,
    context: Optional["ExecutionContext"] = None,
) -> None:
    """Log merge configuration for debugging.

    Phase 6: Enhanced logging for mixed merge and asymmetric scenarios.
    The asymmetric-branch analysis runs only when branch_contexts,
    prediction_store and context are all truthy. Any failure inside it is
    logged at debug level instead of interrupting the pipeline.

    Args:
        config: Merge configuration
        n_branches: Number of available branches
        branch_contexts: Optional branch contexts for asymmetric analysis
        prediction_store: Optional prediction store for model discovery
        context: Optional execution context
    """
    # Resolved merge mode; reported at debug level only.
    mode = config.get_merge_mode()
    logger.debug(f" Merge mode: {mode}")
    is_mixed = config.collect_features and config.collect_predictions

    # Phase 6: Log mixed merge detection
    if is_mixed:
        logger.info(" Mixed merge detected: collecting both features and predictions")

    if config.collect_features:
        feat_branches = config.get_feature_branches(n_branches)
        logger.info(f" Features: collecting from branches {feat_branches}")

    if config.collect_predictions:
        if config.has_per_branch_config():
            # Type narrowing: has_per_branch_config() guarantees prediction_configs is not None
            assert config.prediction_configs is not None
            for pc in config.prediction_configs:
                logger.info(
                    f" Predictions: branch={pc.branch}, "
                    f"select={pc.select}, aggregate={pc.aggregate}"
                )
        else:
            pred_branches = config.prediction_branches
            logger.info(
                f" Predictions: collecting from branches {pred_branches}, "
                f"models={config.model_filter or 'all'}"
            )

    if config.include_original:
        logger.info(" Including original pre-branch features")
    if config.output_as != "features":
        logger.info(f" Output target: {config.output_as}")

    # Phase 6: Log asymmetric branch analysis if context available
    if not (branch_contexts and prediction_store and context):
        return
    try:
        analyzer = AsymmetricBranchAnalyzer(
            branch_contexts=branch_contexts,
            prediction_store=prediction_store,
            context=context,
        )
        report = analyzer.analyze_all()
        if report.is_asymmetric:
            logger.info(f" Asymmetric branches: {report.summary}")
            if report.has_model_asymmetry and not is_mixed:
                # User is not using mixed merge but branches are asymmetric
                suggestion = analyzer.suggest_mixed_merge()
                if suggestion:
                    logger.warning(
                        f"⚠️ Asymmetric branches detected but not using mixed merge. "
                        f"Some branches may not contribute to the result. "
                        f"{suggestion}"
                    )
    except Exception as e:
        # Don't fail the pipeline due to analysis errors
        logger.debug(f"Asymmetric analysis failed: {e}")
# =========================================================================
# Phase 3: Feature Extraction and Collection
# =========================================================================
def _collect_features(
    self,
    dataset: "SpectroDataset",
    branch_contexts: List[Dict[str, Any]],
    branch_indices: List[int],
    on_missing: str = "error",
    on_shape_mismatch: str = "error",
    preserve_preprocessing: bool = False,
) -> Tuple[List[np.ndarray], Dict[str, Any]]:
    """Collect features from specified branches.

    Extracts features from each branch's feature snapshot. By default,
    features are extracted in 2D layout (samples, features) and are
    horizontally concatenated during merge. When preserve_preprocessing=True,
    features are extracted in 3D layout (samples, processings, features) to
    preserve the preprocessing dimension.

    Args:
        dataset: Dataset (used to get sample count for validation).
        branch_contexts: List of branch context dictionaries.
        branch_indices: List of branch indices to collect features from.
        on_missing: How to handle a branch that is unknown, has no snapshot,
            or whose features cannot be extracted:
            - "error": Raise ValueError
            - "warn": Log warning and skip
            - "skip" (or any other value): Silent skip
        on_shape_mismatch: Reserved for future 3D layout support; currently
            has no effect. In 2D layout features are simply concatenated.
            Intended values: "error", "allow", "pad", "truncate".
        preserve_preprocessing: If True, preserve the preprocessing dimension
            by extracting in 3D layout. Used when output_as="sources".

    Returns:
        Tuple of (features_list, info_dict) where:
        - features_list: List of numpy arrays (2D or 3D), one per used branch
        - info_dict: {"shapes": [...], "branches_used": [...]}

    Raises:
        ValueError: If a branch is unusable and on_missing="error", or
            if sample counts don't match (MERGE-E003).
    """
    features_list: List[np.ndarray] = []
    shapes = []
    branches_used: List[int] = []
    expected_samples = dataset.num_samples
    layout = "3d" if preserve_preprocessing else "2d"

    for branch_idx in branch_indices:
        branch_ctx = self._get_branch_context(branch_contexts, branch_idx)
        if branch_ctx is None:
            self._report_missing_branch(
                f"Branch {branch_idx} not found in branch contexts. [Error: MERGE-E021]",
                on_missing,
            )
            continue

        snapshot = branch_ctx.get("features_snapshot")
        if snapshot is None:
            self._report_missing_branch(
                f"Branch {branch_idx} has no feature snapshot. [Error: MERGE-E001]",
                on_missing,
            )
            continue

        # Extract features from snapshot (2D or 3D based on preserve_preprocessing)
        try:
            features = self._extract_features_from_snapshot(
                snapshot=snapshot,
                expected_samples=expected_samples,
                branch_idx=branch_idx,
                layout=layout,
            )
        except ValueError as e:
            self._report_missing_branch(
                f"Failed to extract features from branch {branch_idx}: {e}",
                on_missing,
                cause=e,
            )
            continue

        features_list.append(features)
        shapes.append(features.shape)
        branches_used.append(branch_idx)
        logger.debug(
            f"Extracted features from branch {branch_idx}: "
            f"shape={features.shape}"
        )

    # Validate sample counts match
    if features_list:
        n_samples_list = [f.shape[0] for f in features_list]
        if len(set(n_samples_list)) > 1:
            raise ValueError(
                f"Sample count mismatch across branches: {n_samples_list}. "
                f"All branches must have the same number of samples. "
                f"[Error: MERGE-E003]"
            )

    # Note: Shape mismatch checking is NOT performed for 2D feature collection.
    # In 2D layout, features from different branches are simply concatenated
    # horizontally, so different feature dimensions across branches is expected
    # and normal behavior. Shape mismatch handling (pad/truncate) only applies
    # when using 3D layout where the number of processings must align.
    info = {
        "shapes": shapes,
        "branches_used": branches_used,
    }
    return features_list, info

@staticmethod
def _report_missing_branch(
    message: str,
    on_missing: str,
    cause: Optional[BaseException] = None,
) -> None:
    """Apply the ``on_missing`` policy to an unusable branch.

    Raises ValueError(message) for "error" (chained to ``cause`` when one is
    given), logs ``message + " Skipping."`` as a warning for "warn", and does
    nothing for any other value (treated as "skip").
    """
    if on_missing == "error":
        if cause is not None:
            raise ValueError(message) from cause
        raise ValueError(message)
    if on_missing == "warn":
        logger.warning(message + " Skipping.")
def _extract_features_from_snapshot(
    self,
    snapshot: List["FeatureSource"],
    expected_samples: int,
    branch_idx: int,
    layout: str = "2d",
) -> np.ndarray:
    """Turn a branch's feature snapshot into a single numpy array.

    Each element of ``snapshot`` is a feature source exposing
    ``num_samples`` and ``x(indices=..., layout=...)``. Every source must
    hold exactly ``expected_samples`` samples. Sources that yield an empty
    array are skipped with a warning. The remaining arrays are joined along
    axis 1 for ``layout="2d"`` and along axis 2 for any other layout
    (e.g. ``"3d"``). A single remaining array is returned as-is.

    Args:
        snapshot: Feature sources from the branch context.
        expected_samples: Required number of samples per source.
        branch_idx: Branch index (used in messages; -1 for original features).
        layout: "2d" for (n_samples, total_features) or "3d" for
            (n_samples, processings, features).

    Returns:
        The combined feature array.

    Raises:
        ValueError: If the snapshot is empty, a source has the wrong sample
            count, extraction fails, or every source is empty.
    """
    if not snapshot:
        raise ValueError(
            f"Branch {branch_idx} snapshot is empty (no feature sources)"
        )

    collected: List[np.ndarray] = []
    for src_idx, source in enumerate(snapshot):
        n_samples = source.num_samples
        if n_samples != expected_samples:
            raise ValueError(
                f"Branch {branch_idx} source {src_idx} has {n_samples} samples, "
                f"expected {expected_samples}. [Error: MERGE-E003]"
            )
        every_row = list(range(n_samples))
        try:
            values = source.x(indices=every_row, layout=layout)
        except Exception as e:
            raise ValueError(
                f"Failed to extract features from branch {branch_idx} "
                f"source {src_idx}: {e}"
            ) from e
        if values.size != 0:
            collected.append(values)
            continue
        logger.warning(
            f"Branch {branch_idx} source {src_idx} has empty features "
            f"(shape: {values.shape})"
        )

    if not collected:
        raise ValueError(
            f"Branch {branch_idx} has no extractable features "
            f"(all sources empty)"
        )
    if len(collected) == 1:
        return collected[0]
    join_axis = 1 if layout == "2d" else 2
    return np.concatenate(collected, axis=join_axis)
def _get_branch_context(
    self,
    branch_contexts: List[Dict[str, Any]],
    branch_ref: Union[int, str]
) -> Optional[Dict[str, Any]]:
    """Find the first branch context matching an index or a name.

    An ``int`` reference is compared with ``bc["branch_id"]`` and a ``str``
    reference with ``bc.get("name")``. References of any other type never
    match.

    Args:
        branch_contexts: List of branch context dictionaries.
        branch_ref: Branch index (int) or name (str).

    Returns:
        The first matching context dictionary, or None if none matches.
    """
    if isinstance(branch_ref, int):
        matches = (bc for bc in branch_contexts if bc["branch_id"] == branch_ref)
    elif isinstance(branch_ref, str):
        matches = (bc for bc in branch_contexts if bc.get("name") == branch_ref)
    else:
        return None
    return next(matches, None)
def _handle_shape_mismatch(
    self,
    features_list: List[np.ndarray],
    strategy: str,
    branches_used: List[int]
) -> List[np.ndarray]:
    """Reconcile differing feature widths (axis 1) across branch arrays.

    Args:
        features_list: 2D feature arrays, one per branch.
        strategy: Mismatch policy:
            - "error": raise ValueError
            - "pad": zero-pad narrower arrays on the right up to the widest
            - "truncate": cut wider arrays down to the narrowest
            - anything else (e.g. "allow"): leave the arrays unchanged
        branches_used: Branch indices aligned with ``features_list``, used
            in log and error messages.

    Returns:
        ``features_list`` itself when it holds at most one array, the widths
        already agree, or the strategy is not "pad"/"truncate". Otherwise a
        new list of arrays with a common width.

    Raises:
        ValueError: If strategy is "error" and widths differ (MERGE-E002).
    """
    if len(features_list) < 2:
        return features_list

    widths = [arr.shape[1] for arr in features_list]
    if len(set(widths)) == 1:
        return features_list

    if strategy == "error":
        raise ValueError(
            f"Feature dimension mismatch across branches. "
            f"Branches {branches_used} have dimensions {widths}. "
            f"Set on_shape_mismatch='allow' to concatenate anyway, "
            f"'pad' to zero-pad, or 'truncate' to truncate. "
            f"[Error: MERGE-E002]"
        )

    if strategy == "pad":
        target = max(widths)
        adjusted = []
        for position, arr in enumerate(features_list):
            shortfall = target - arr.shape[1]
            if shortfall <= 0:
                adjusted.append(arr)
                continue
            widened = np.pad(
                arr,
                ((0, 0), (0, shortfall)),
                mode='constant',
                constant_values=0,
            )
            logger.info(
                f" Padded branch {branches_used[position]} from "
                f"{arr.shape[1]} to {target} features"
            )
            adjusted.append(widened)
        return adjusted

    if strategy == "truncate":
        target = min(widths)
        adjusted = []
        for position, arr in enumerate(features_list):
            if arr.shape[1] <= target:
                adjusted.append(arr)
                continue
            logger.warning(
                f" Truncating branch {branches_used[position]} from "
                f"{arr.shape[1]} to {target} features"
            )
            adjusted.append(arr[:, :target])
        return adjusted

    # "allow" or any other value: keep the arrays as they are.
    return features_list
def _get_original_features(
    self,
    dataset: "SpectroDataset",
    context: "ExecutionContext"
) -> Optional[np.ndarray]:
    """Return the feature matrix that existed before branching, if possible.

    Prefers ``context.custom["pre_branch_features_snapshot"]`` (extracted in
    2D with branch index -1). If that is absent or extraction fails (a
    warning is logged), it falls back to the dataset's current features.
    These may already reflect a branch's preprocessing.

    Args:
        dataset: The dataset.
        context: Execution context (may contain a pre-branch snapshot).

    Returns:
        2D numpy array of original features, or None if unavailable.
    """
    snapshot = context.custom.get("pre_branch_features_snapshot")
    if snapshot is not None:
        try:
            return self._extract_features_from_snapshot(
                snapshot=snapshot,
                expected_samples=dataset.num_samples,
                branch_idx=-1,  # -1 marks the original (pre-branch) features
            )
        except Exception as e:
            logger.warning(
                f"Failed to extract pre-branch features: {e}. "
                f"Falling back to current dataset features."
            )

    # Fallback: current dataset features (may come from a specific branch).
    try:
        current = dataset.x(selector={}, layout="2d", concat_source=True)
        if not isinstance(current, list):
            return current
        # A list of per-source arrays may be returned; join them column-wise.
        if not current:
            return None
        return current[0] if len(current) == 1 else np.concatenate(current, axis=1)
    except Exception as e:
        logger.warning(f"Failed to get original features: {e}")
        return None
# =========================================================================
# Phase 4 & 5: Prediction Collection with Per-Branch Configuration
# =========================================================================
def _collect_predictions(
    self,
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    branch_contexts: List[Dict[str, Any]],
    config: MergeConfig,
    prediction_store: Optional["Predictions"],
    mode: str = "train",
) -> Tuple[Optional[np.ndarray], Dict[str, Any]]:
    """Collect predictions from specified branches with per-branch control.

    Orchestrates model discovery, selection, aggregation, and OOF
    reconstruction for every branch listed in the merge configuration.

    For each configured branch:
      1. Resolve the branch reference (index or name) to a branch context.
      2. Discover models with validation predictions in that branch.
      3. Select models (all, best, top_k, explicit list) via ModelSelector.
      4. Collect OOF (or, in unsafe mode, raw) predictions for the selected
         models.
      5. Aggregate them (separate, mean, weighted_mean, proba_mean) via
         PredictionAggregator.

    Only models whose predictions were actually collected are reported in
    ``models_used`` and used for weighted-mean score lookup. A selected model
    whose collection failed is logged and excluded rather than counted.

    Args:
        dataset: Dataset for sample information.
        context: Execution context with branch/fold info.
        branch_contexts: List of branch context dictionaries.
        config: Merge configuration.
        prediction_store: Prediction storage containing model predictions.
        mode: Execution mode ("train" or "predict"), forwarded to the
            per-branch collector.

    Returns:
        Tuple of (predictions_array, info_dict) where:
        - predictions_array: 2D array built by horizontally concatenating
          the per-branch outputs, or None when nothing was collected and
          ``config.on_missing`` is not "error".
        - info_dict: ``models_used`` and ``branches_used``. When predictions
          were collected it also holds ``oof_reconstruction``,
          ``n_features`` and ``selection_info``.

    Raises:
        ValueError: If ``prediction_store`` is None, or if a branch reference
            has an invalid type. When ``config.on_missing == "error"``, also
            if a branch is missing or has no model predictions.
    """
    if prediction_store is None:
        raise ValueError(
            "prediction_store is required for prediction merge. "
            "Ensure models were trained in the specified branches. "
            "[Error: MERGE-E010]"
        )
    n_branches = len(branch_contexts)
    # Get prediction configs (Phase 5 or legacy)
    prediction_configs = config.get_prediction_configs(n_branches)
    # Initialize model selector for Phase 5 features
    model_selector = ModelSelector(
        prediction_store=prediction_store,
        context=context,
    )
    # Collect predictions per branch with selection and aggregation
    all_branch_predictions: List[np.ndarray] = []
    all_models_used: List[str] = []
    branches_used: List[int] = []
    selection_info: List[Dict[str, Any]] = []
    for branch_config in prediction_configs:
        branch_ref = branch_config.branch
        # An unknown branch *name* makes _resolve_branch_index raise. Route
        # it through the same on_missing policy as an unknown index. A
        # reference of the wrong type is a configuration error: re-raise.
        try:
            actual_idx = self._resolve_branch_index(branch_contexts, branch_ref)
        except ValueError:
            if not isinstance(branch_ref, str):
                raise
            actual_idx = None
        branch_ctx = (
            None if actual_idx is None
            else self._get_branch_context(branch_contexts, actual_idx)
        )
        if branch_ctx is None:
            if config.on_missing == "error":
                raise ValueError(
                    f"Branch {branch_ref} not found for prediction collection. "
                    f"[Error: MERGE-E021]"
                )
            logger.warning(f"Branch {branch_ref} not found, skipping predictions.")
            continue
        branch_id = branch_ctx["branch_id"]
        # Discover all models in this branch
        available_models = self._discover_branch_models(
            prediction_store=prediction_store,
            branch_id=branch_id,
            context=context,
            model_filter=None,  # Don't pre-filter; selection does that
        )
        if not available_models:
            if config.on_missing == "error":
                # Phase 6: Provide detailed error with asymmetric branch analysis
                analyzer = AsymmetricBranchAnalyzer(
                    branch_contexts=branch_contexts,
                    prediction_store=prediction_store,
                    context=context,
                )
                report = analyzer.analyze_all()
                if report.has_model_asymmetry and branch_id in report.branches_without_models:
                    suggestion = analyzer.suggest_mixed_merge()
                    raise ValueError(
                        f"No model predictions found in branch {branch_ref}. "
                        f"This branch has only features (no trained models). "
                        f"{report.summary}. "
                        f"\n\n{suggestion}\n"
                        f"[Error: MERGE-E011]"
                    )
                else:
                    raise ValueError(
                        f"No model predictions found in branch {branch_ref}. "
                        f"Ensure models were trained in this branch before merge. "
                        f"[Error: MERGE-E010]"
                    )
            logger.warning(f"No models found in branch {branch_ref}, skipping.")
            continue
        # Phase 5: Apply model selection
        selected_models = model_selector.select_models(
            available_models=available_models,
            config=branch_config,
            branch_id=branch_id,
        )
        if not selected_models:
            logger.warning(
                f"No models selected from branch {branch_ref} after applying "
                f"selection strategy: {branch_config.select}. Skipping."
            )
            continue
        logger.info(
            f" Branch {branch_id}: selected {len(selected_models)}/{len(available_models)} models "
            f"using strategy '{branch_config.get_selection_strategy().value}'"
        )
        # Collect OOF predictions for selected models
        branch_predictions = self._collect_branch_predictions(
            dataset=dataset,
            context=context,
            prediction_store=prediction_store,
            model_names=selected_models,
            branch_id=branch_id,
            config=config,
            mode=mode,
        )
        if not branch_predictions:
            logger.warning(f"No predictions collected from branch {branch_id}")
            continue
        # The collector drops models whose predictions could not be gathered.
        # Use only the collected models (dict order follows selected_models)
        # so scores, counts and models_used match the actual columns.
        collected_models = list(branch_predictions)
        if len(collected_models) < len(selected_models):
            missing_models = [m for m in selected_models if m not in branch_predictions]
            logger.warning(
                f"Branch {branch_id}: no predictions collected for models "
                f"{missing_models}; continuing with {collected_models}"
            )
        # Phase 5: Apply aggregation strategy
        aggregation_strategy = branch_config.get_aggregation_strategy()
        if aggregation_strategy != AggregationStrategy.SEPARATE:
            # Get model scores for weighted aggregation
            model_scores = None
            if aggregation_strategy == AggregationStrategy.WEIGHTED_MEAN:
                metric = branch_config.weight_metric or branch_config.metric or "rmse"
                model_scores = model_selector.get_model_scores(
                    model_names=collected_models,
                    metric=metric,
                    branch_id=branch_id,
                )
            # Aggregate predictions
            aggregated = PredictionAggregator.aggregate(
                predictions=branch_predictions,
                strategy=aggregation_strategy,
                model_scores=model_scores,
                proba=branch_config.proba,
                metric=branch_config.weight_metric or branch_config.metric,
            )
            logger.info(
                f" Branch {branch_id}: aggregated {len(collected_models)} models "
                f"using '{aggregation_strategy.value}' → shape {aggregated.shape}"
            )
            all_branch_predictions.append(aggregated)
        else:
            # Keep predictions separate (each model = 1 feature)
            separate = PredictionAggregator.aggregate(
                predictions=branch_predictions,
                strategy=AggregationStrategy.SEPARATE,
            )
            all_branch_predictions.append(separate)
        all_models_used.extend(collected_models)
        branches_used.append(actual_idx)
        selection_info.append({
            "branch": actual_idx,
            "available_models": len(available_models),
            "selected_models": selected_models,
            "collected_models": collected_models,
            "selection_strategy": branch_config.get_selection_strategy().value,
            "aggregation_strategy": aggregation_strategy.value,
        })
    if not all_branch_predictions:
        if config.on_missing == "error":
            # Phase 6: Use asymmetric analyzer for better error messages
            analyzer = AsymmetricBranchAnalyzer(
                branch_contexts=branch_contexts,
                prediction_store=prediction_store,
                context=context,
            )
            report = analyzer.analyze_all()
            if report.has_model_asymmetry:
                # Provide resolution suggestion for asymmetric branches
                suggestion = analyzer.suggest_mixed_merge()
                raise ValueError(
                    f"No model predictions found in any specified branch. "
                    f"Asymmetric branches detected: {report.summary}. "
                    f"\n\n{suggestion}\n"
                    f"[Error: MERGE-E011]"
                )
            else:
                raise ValueError(
                    f"No model predictions found in any specified branch. "
                    f"Ensure models were trained in the specified branches before merge. "
                    f"[Error: MERGE-E010]"
                )
        logger.warning("No predictions collected from any branch.")
        return None, {"models_used": [], "branches_used": []}
    # Concatenate all branch predictions horizontally
    predictions = np.concatenate(all_branch_predictions, axis=1)
    info = {
        "models_used": all_models_used,
        "branches_used": branches_used,
        "oof_reconstruction": not config.unsafe,
        "n_features": predictions.shape[1],
        "selection_info": selection_info,
    }
    return predictions, info
def _collect_branch_predictions(
self,
dataset: "SpectroDataset",
context: "ExecutionContext",
prediction_store: "Predictions",
model_names: List[str],
branch_id: int,
config: MergeConfig,
mode: str = "train",
) -> Optional[Dict[str, np.ndarray]]:
"""Collect predictions for specified models from a single branch.
Returns a dictionary mapping model names to their prediction arrays,
suitable for per-branch aggregation.
Args:
dataset: Dataset for sample information.
context: Execution context.
prediction_store: Prediction storage.
model_names: List of model names to collect.
branch_id: Branch identifier.
config: Merge configuration.
mode: Execution mode.
Returns:
Dictionary mapping model names to prediction arrays (n_samples,),
or None if no predictions found.
"""
if config.unsafe:
return self._collect_branch_predictions_unsafe(
dataset=dataset,
context=context,
prediction_store=prediction_store,
model_names=model_names,
branch_id=branch_id,
)
else:
return self._collect_branch_predictions_oof(
dataset=dataset,
context=context,
prediction_store=prediction_store,
model_names=model_names,
branch_id=branch_id,
config=config,
)
def _collect_branch_predictions_oof(
    self,
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    prediction_store: "Predictions",
    model_names: List[str],
    branch_id: int,
    config: MergeConfig,
) -> Optional[Dict[str, np.ndarray]]:
    """Collect OOF predictions for models in a single branch.

    Each model is reconstructed independently with TrainingSetReconstructor
    (IMPUTE_MEAN coverage, fold-alignment validation disabled).

    Its train (OOF) and test outputs (first meta-feature column when 2D) are
    scattered into a vector of length ``dataset.num_samples``. Rows without
    a prediction stay NaN. Predictions of base samples are then copied onto
    their augmented copies.

    The train/test sample ids and the augmented->origin pairs do not depend
    on the model, so they are resolved once before the per-model loop.

    Args:
        dataset: Dataset for sample information.
        context: Execution context.
        prediction_store: Prediction storage.
        model_names: List of model names.
        branch_id: Branch identifier (used in log messages).
        config: Merge configuration (``use_proba`` is forwarded).

    Returns:
        Dictionary mapping model names to prediction arrays. Returns None if
        no model could be collected or the sample layout could not be
        resolved.
    """
    from nirs4all.controllers.models.stacking import (
        TrainingSetReconstructor,
        ReconstructorConfig,
    )
    from nirs4all.operators.models.meta import StackingConfig, CoverageStrategy

    # IMPUTE_MEAN is more lenient than STRICT: samples without predictions
    # from every fold do not abort the merge.
    stacking_config = StackingConfig(
        coverage_strategy=CoverageStrategy.IMPUTE_MEAN,
    )
    reconstructor_config = ReconstructorConfig(
        log_warnings=True,
        validate_fold_alignment=False,  # Allow fold mismatch for branch merge
    )

    # --- Model-independent sample layout (resolved once) -------------------
    try:
        n_total = dataset.num_samples
        # include_augmented=False: OOF predictions only exist for original
        # (non-augmented) samples.
        train_ids_raw = dataset._indexer.x_indices(
            context.with_partition('train').selector,
            include_augmented=False,
            include_excluded=False
        )
        test_ids_raw = dataset._indexer.x_indices(
            context.with_partition('test').selector,
            include_augmented=False,
            include_excluded=False
        )
        train_ids = np.asarray(train_ids_raw, dtype=int)
        test_ids = np.asarray(test_ids_raw, dtype=int)

        # Augmented samples inherit the prediction of their origin sample.
        augmented_pairs: List[Tuple[Any, Any]] = []
        base_sample_ids = list(train_ids_raw) + list(test_ids_raw)
        if base_sample_ids:
            tracker = dataset._indexer._augmentation_tracker
            for aug_id in tracker.get_augmented_for_origins(base_sample_ids):
                origin_id = tracker.get_origin_for_sample(aug_id)
                if origin_id is not None:
                    augmented_pairs.append((aug_id, origin_id))
    except Exception as e:
        logger.warning(
            f"Failed to resolve sample layout for OOF collection "
            f"in branch {branch_id}: {e}"
        )
        return None

    def _scatter(target, meta, ids, partition, model_name):
        # Write one prediction per sample id; skip (with a warning) when the
        # reconstructed rows do not line up with the partition's samples.
        if meta.size == 0:
            return
        values = meta[:, 0] if meta.ndim > 1 else meta
        if len(values) != len(ids):
            logger.warning(
                f"Branch {branch_id}, model '{model_name}': {len(values)} "
                f"{partition} predictions for {len(ids)} {partition} samples; "
                f"{partition} predictions skipped."
            )
            return
        target[ids] = values

    model_predictions: Dict[str, np.ndarray] = {}
    for model_name in model_names:
        try:
            reconstructor = TrainingSetReconstructor(
                prediction_store=prediction_store,
                source_model_names=[model_name],
                stacking_config=stacking_config,
                reconstructor_config=reconstructor_config,
            )
            result = reconstructor.reconstruct(
                dataset=dataset,
                context=context,
                use_proba=config.use_proba,
            )
            combined = np.full(n_total, np.nan)
            _scatter(combined, result.X_train_meta, train_ids, "train", model_name)
            _scatter(combined, result.X_test_meta, test_ids, "test", model_name)
            for aug_id, origin_id in augmented_pairs:
                if not np.isnan(combined[origin_id]):
                    combined[aug_id] = combined[origin_id]
            model_predictions[model_name] = combined
        except Exception as e:
            logger.warning(
                f"Failed to collect OOF predictions for model '{model_name}' "
                f"in branch {branch_id}: {e}"
            )
            continue
    return model_predictions or None
def _collect_branch_predictions_unsafe(
    self,
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    prediction_store: "Predictions",
    model_names: List[str],
    branch_id: int,
) -> Optional[Dict[str, np.ndarray]]:
    """Collect predictions WITHOUT OOF reconstruction (UNSAFE).

    ⚠️ WARNING: This causes DATA LEAKAGE when used for training.

    For each model, stored predictions from steps before the current one are
    read for the "train" and then the "test" partition.

    Within a partition, a sample that appears in several records (e.g. one
    per fold) gets the mean of its values. Train values are therefore
    aggregated the same way as test values instead of "last fold wins".
    Test values override train values for the same sample. Samples without
    any prediction are filled with 0.0.

    Args:
        dataset: Dataset for sample information.
        context: Execution context (``state.step_number`` bounds the steps).
        prediction_store: Prediction storage.
        model_names: List of model names.
        branch_id: Branch identifier. Not used for filtering here.
            NOTE(review): records are matched by model name only, so
            same-named models from other branches may be mixed in. Confirm
            whether branch filtering is required.

    Returns:
        Dictionary mapping model names to arrays of length
        ``dataset.num_samples``, or None if ``model_names`` is empty.
    """
    logger.warning(
        "⚠️ UNSAFE PREDICTION COLLECTION: Using training predictions directly. "
        "This causes DATA LEAKAGE - do NOT use for final model evaluation!"
    )
    current_step = getattr(context.state, 'step_number', float('inf'))
    n_total = dataset.num_samples

    def _mean_per_sample(records: List[Dict[str, Any]]) -> Dict[int, float]:
        # Average each in-range sample's predictions across all records.
        values_by_sample: Dict[int, List[float]] = {}
        for record in records:
            y_pred = record.get('y_pred')
            sample_indices = record.get('sample_indices')
            if y_pred is None or sample_indices is None:
                continue
            y_pred = np.asarray(y_pred).flatten()
            if hasattr(sample_indices, 'tolist'):
                sample_indices = sample_indices.tolist()
            # zip stops at the shorter sequence, ignoring surplus indices.
            for sid, value in zip(sample_indices, y_pred):
                idx = int(sid)
                # Check both ends: a negative index would silently wrap.
                if 0 <= idx < n_total:
                    values_by_sample.setdefault(idx, []).append(value)
        return {idx: np.mean(vals) for idx, vals in values_by_sample.items()}

    model_predictions: Dict[str, np.ndarray] = {}
    for model_name in model_names:
        combined = np.full(n_total, np.nan)
        # "test" is processed last so its values win for overlapping samples.
        for partition in ("train", "test"):
            records = self._get_unsafe_partition_predictions(
                prediction_store=prediction_store,
                model_name=model_name,
                partition=partition,
                current_step=current_step,
            )
            for idx, value in _mean_per_sample(records).items():
                combined[idx] = value
        # Replace remaining NaN with 0
        model_predictions[model_name] = np.nan_to_num(combined, nan=0.0)
    return model_predictions or None
def _get_unsafe_partition_predictions(
self,
prediction_store: "Predictions",
model_name: str,
partition: str,
current_step: Union[int, float],
) -> List[Dict[str, Any]]:
"""Get predictions for a model/partition without OOF.
Helper for unsafe prediction collection.
Args:
prediction_store: Prediction storage.
model_name: Model name.
partition: Partition name.
current_step: Current step for filtering.
Returns:
List of prediction dictionaries.
"""
filter_kwargs = {
'model_name': model_name,
'partition': partition,
'load_arrays': True,
}
predictions = prediction_store.filter_predictions(**filter_kwargs)
# Filter by step
return [
p for p in predictions
if p.get('step_idx', 0) < current_step
]
def _discover_branch_models(
self,
prediction_store: "Predictions",
branch_id: int,
context: "ExecutionContext",
model_filter: Optional[List[str]] = None,
) -> List[str]:
"""Discover models that have predictions in a branch.
Queries the prediction store for models that ran in the specified
branch and returns their names.
Args:
prediction_store: Prediction storage.
branch_id: Branch ID to search for.
context: Execution context for step filtering.
model_filter: Optional list of model names to include.
Returns:
List of model names with predictions in the branch.
"""
current_step = getattr(context.state, 'step_number', float('inf'))
# Query prediction store for validation predictions in this branch
filter_kwargs = {
'branch_id': branch_id,
'partition': 'val',
'load_arrays': False,
}
predictions = prediction_store.filter_predictions(**filter_kwargs)
# Filter by step (only include predictions from earlier steps)
predictions = [
p for p in predictions
if p.get('step_idx', 0) < current_step
]
# If no predictions with branch filter, try pre-branch models
# (models trained before branch was created have branch_id=None)
if not predictions:
filter_kwargs_no_branch = {
'partition': 'val',
'load_arrays': False,
}
predictions = prediction_store.filter_predictions(**filter_kwargs_no_branch)
predictions = [
p for p in predictions
if p.get('step_idx', 0) < current_step and p.get('branch_id') is None
]
# Extract unique model names
model_names = set()
for pred in predictions:
model_name = pred.get('model_name')
if model_name:
model_names.add(model_name)
# Apply model filter if specified
if model_filter:
model_names = model_names.intersection(set(model_filter))
return sorted(model_names)
def _resolve_branch_index(
self,
branch_contexts: List[Dict[str, Any]],
branch_ref: Union[int, str]
) -> int:
"""Resolve a branch reference to its numeric index.
Args:
branch_contexts: List of branch context dictionaries.
branch_ref: Branch index (int) or name (str).
Returns:
Numeric branch index.
Raises:
ValueError: If branch not found.
"""
if isinstance(branch_ref, int):
return branch_ref
elif isinstance(branch_ref, str):
for bc in branch_contexts:
if bc.get("name") == branch_ref:
return bc["branch_id"]
raise ValueError(f"Branch name '{branch_ref}' not found")
else:
raise ValueError(f"Invalid branch reference type: {type(branch_ref)}")
# =========================================================================
# Phase 8: Prediction Mode Support
# =========================================================================
def _execute_branch_merge_predict_mode(
    self,
    step_info: "ParsedStep",
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    runtime_context: "RuntimeContext",
    source: int,
    config: MergeConfig,
    loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
    prediction_store: Optional[Any] = None,
) -> Tuple["ExecutionContext", StepOutput]:
    """Run the merge step during prediction, without live branch contexts.

    In predict mode the executor has already walked each branch and applied
    its transformers, so this step only assembles the final matrix:

    - features (``config.collect_features``): the dataset's current 2D,
      source-concatenated features for ``context.selector``;
    - predictions (``config.collect_predictions`` with a store): gathered
      via ``self._collect_predictions_predict_mode``;
    - original features (``config.include_original``): prepended via
      ``self._get_original_features``.

    Failures while extracting features or predictions are logged and that
    part is skipped. If any part was gathered, the horizontal concatenation
    is stored with ``dataset.add_merged_features`` on source 0. The
    processing name is ``config.source_names[0]`` when given, otherwise
    "merged". Merge itself persists no artifacts.

    Args:
        step_info: Parsed step info (unused here).
        dataset: Dataset with branch-transformed features.
        context: Execution context.
        runtime_context: Runtime context (unused here).
        source: Source index (unused here).
        config: Parsed merge configuration.
        loaded_binaries: Not used (merge has no artifacts).
        prediction_store: For prediction collection.

    Returns:
        A context copy with branch mode cleared, plus a StepOutput whose
        metadata records the merge mode, prediction-mode flag, output_as and
        the shapes gathered.
    """
    logger.info(
        f"Merge step (predict mode): mode={config.get_merge_mode().value}"
    )
    parts = []
    details = {"prediction_mode": True}

    if config.collect_features:
        # The branch iteration already produced the transformed features.
        try:
            features = dataset.x(
                selector=context.selector,
                layout="2d",
                concat_source=True
            )
            if isinstance(features, list):
                features = np.concatenate(features, axis=1)
            if features is not None and features.size > 0:
                parts.append(features)
                details["feature_shape"] = features.shape
                logger.info(
                    f" Collected features for prediction: shape={features.shape}"
                )
        except Exception as e:
            logger.warning(f"Could not extract features in predict mode: {e}")

    if config.collect_predictions and prediction_store is not None:
        try:
            collected = self._collect_predictions_predict_mode(
                dataset=dataset,
                context=context,
                config=config,
                prediction_store=prediction_store,
            )
            if collected is not None and collected.size > 0:
                parts.append(collected)
                details["prediction_shape"] = collected.shape
                logger.info(
                    f" Collected predictions for prediction: shape={collected.shape}"
                )
        except Exception as e:
            logger.warning(f"Could not collect predictions in predict mode: {e}")

    if config.include_original:
        original = self._get_original_features(dataset, context)
        if original is not None:
            # Original features always come first in the merged matrix.
            parts.insert(0, original)
            details["original_shape"] = original.shape

    if parts:
        merged = np.concatenate(parts, axis=1)
        details["merged_shape"] = merged.shape
        logger.info(f" Final merged shape (predict): {merged.shape}")
        has_names = config.source_names and len(config.source_names) > 0
        processing_name = config.source_names[0] if has_names else "merged"
        dataset.add_merged_features(
            features=merged,
            processing_name=processing_name,
            source=0
        )

    # Leave branch mode (clears any residual branch state).
    exit_context = context.copy()
    exit_context.custom["branch_contexts"] = []
    exit_context.custom["in_branch_mode"] = False

    metadata = {
        "merge_mode": config.get_merge_mode().value,
        "prediction_mode": True,
        "output_as": config.output_as,
        **details,
    }
    logger.success(
        f"Merge step (predict mode) completed. "
        f"Features={config.collect_features}, Predictions={config.collect_predictions}"
    )
    return exit_context, StepOutput(metadata=metadata)
def _collect_predictions_predict_mode(
self,
dataset: "SpectroDataset",
context: "ExecutionContext",
config: MergeConfig,
prediction_store: "Predictions",
) -> Optional[np.ndarray]:
"""Collect predictions in prediction mode.
In predict mode, models have already generated predictions which are
stored in the prediction store. We collect and aggregate them according
to the merge configuration.
Args:
dataset: Dataset for sample info
context: Execution context
config: Merge configuration
prediction_store: Prediction storage
Returns:
Aggregated predictions array or None
"""
# Get model names from config
model_filter = config.model_filter
# Query prediction store for test partition predictions
filter_kwargs = {
'partition': 'test',
'load_arrays': True,
}
predictions = prediction_store.filter_predictions(**filter_kwargs)
if not predictions:
logger.debug("No test predictions found in prediction store")
return None
# Group by model
model_predictions: Dict[str, List[np.ndarray]] = {}
for pred in predictions:
model_name = pred.get('model_name')
if model_name is None:
continue
# Apply model filter if specified
if model_filter and model_name not in model_filter:
continue
y_pred = pred.get('y_pred')
if y_pred is not None:
y_pred = np.asarray(y_pred)
if model_name not in model_predictions:
model_predictions[model_name] = []
model_predictions[model_name].append(y_pred)
if not model_predictions:
logger.debug("No matching predictions after filtering")
return None
# Aggregate predictions per model (average across folds)
aggregated = []
for model_name, pred_list in model_predictions.items():
if len(pred_list) == 1:
model_pred = pred_list[0]
else:
# Average across folds
try:
stacked = np.stack([p.flatten() for p in pred_list], axis=0)
model_pred = np.mean(stacked, axis=0)
except Exception:
model_pred = pred_list[0]
# Ensure 1D
model_pred = model_pred.flatten()
aggregated.append(model_pred.reshape(-1, 1))
if not aggregated:
return None
return np.hstack(aggregated)
def _execute_source_merge(
    self,
    step_info: "ParsedStep",
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    runtime_context: "RuntimeContext",
    source: int = -1,
    mode: str = "train",
    loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
    prediction_store: Optional[Any] = None
) -> Tuple["ExecutionContext", StepOutput]:
    """Execute source merge operation (Phase 9).

    Combines features from multiple data sources in a multi-source dataset.
    This is distinct from branch merging: it operates on the data provenance
    dimension (different sensors/instruments) rather than the pipeline
    execution dimension (parallel processing paths).

    Supports three merge strategies:
    - concat: Horizontal concatenation (2D result)
    - stack: Stack along new axis (3D result, requires uniform shapes)
    - dict: Keep as structured dictionary (for multi-input models)

    Array results (concat/stack) are stored with
    ``dataset.add_merged_features`` on source 0. A dict result is placed in
    ``context.custom["merged_sources_dict"]``.

    Args:
        step_info: Parsed step containing merge configuration
        dataset: Dataset to operate on
        context: Pipeline execution context
        runtime_context: Runtime infrastructure context (unused here)
        source: Data source index (unused here)
        mode: Execution mode ("train" or "predict") (unused here)
        loaded_binaries: Pre-loaded binary objects for prediction mode (unused here)
        prediction_store: External prediction store (unused here)

    Returns:
        Tuple of (updated_context, StepOutput). The metadata includes the
        collection info (``source_shapes``, ``sources_collected``) alongside
        the strategy-specific merge info.

    Raises:
        ValueError: In any of these cases:
            - the dataset has no feature sources (MERGE-E024);
            - the configured sources cannot be resolved;
            - no features could be collected (MERGE-E030);
            - the strategy is unknown;
            - the strategy rejects the collected arrays.
            A single-source dataset is not an error: the step logs a warning
            and returns a no-op result.
    """
    # Parse configuration
    raw_config = step_info.original_step.get("merge_sources")
    config = self._parse_source_merge_config(raw_config)
    # Validate multi-source dataset
    n_sources = dataset.n_sources
    if n_sources == 0:
        raise ValueError(
            "merge_sources requires a dataset with feature sources. "
            "No sources found in dataset. "
            "[Error: MERGE-E024]"
        )
    if n_sources == 1:
        # Single source - warn but don't fail
        logger.warning(
            "merge_sources called on single-source dataset. "
            "This is a no-op - the dataset already has unified features. "
            "Consider removing this step. [Warning: MERGE-E024]"
        )
        return context.copy(), StepOutput(metadata={
            "source_merge": "no-op",
            "n_sources": 1,
            "reason": "single_source_dataset",
        })
    # Get source names for logging and selection
    source_names = self._get_source_names(dataset, n_sources)
    logger.info(
        f"Source merge: strategy={config.strategy}, "
        f"sources={config.sources}, n_sources={n_sources}"
    )
    # Resolve source indices. get_source_indices raises ValueError for
    # unknown sources; let it propagate with its original traceback.
    source_indices = config.get_source_indices(source_names)
    if len(source_indices) < 2:
        logger.warning(
            f"Only {len(source_indices)} source(s) selected for merge. "
            "Merge requires at least 2 sources to be meaningful."
        )
    # Collect features from each source
    source_features, source_info = self._collect_source_features(
        dataset=dataset,
        context=context,
        source_indices=source_indices,
        source_names=source_names,
    )
    if not source_features:
        raise ValueError(
            "No features collected from any source. "
            "[Error: MERGE-E030]"
        )
    # Apply merge strategy
    strategy = config.get_strategy()
    if strategy == SourceMergeStrategy.CONCAT:
        merged_features, merge_info = self._merge_sources_concat(
            source_features=source_features,
            source_indices=source_indices,
            source_names=source_names,
        )
    elif strategy == SourceMergeStrategy.STACK:
        merged_features, merge_info = self._merge_sources_stack(
            source_features=source_features,
            source_indices=source_indices,
            source_names=source_names,
            on_incompatible=config.get_incompatible_strategy(),
        )
    elif strategy == SourceMergeStrategy.DICT:
        merged_features, merge_info = self._merge_sources_dict(
            source_features=source_features,
            source_indices=source_indices,
            source_names=source_names,
        )
    else:
        raise ValueError(f"Unknown merge strategy: {strategy}")
    if strategy == SourceMergeStrategy.DICT:
        # Dict strategy - store reference in context for downstream use
        result_context = context.copy()
        result_context.custom["merged_sources_dict"] = merged_features
        result_context.custom["source_merge_applied"] = True
        logger.info(
            f"Source merge (dict) completed: {len(merged_features)} sources preserved"
        )
    else:
        # Array strategies (concat/stack) - update dataset
        processing_name = config.output_name
        if isinstance(merged_features, np.ndarray):
            dataset.add_merged_features(
                features=merged_features,
                processing_name=processing_name,
                source=0  # Primary source for merged features
            )
        else:
            logger.warning(
                f"Source merge ({config.strategy}) produced "
                f"{type(merged_features).__name__}, not an array; "
                f"merged features were not stored in the dataset."
            )
        result_context = context.copy()
        result_context.custom["source_merge_applied"] = True
        if merged_features is not None:
            shape_str = str(merged_features.shape) if hasattr(merged_features, 'shape') else 'dict'
            logger.info(
                f"Source merge ({config.strategy}) completed: shape={shape_str}"
            )
    # Build metadata (merge_info last so strategy-specific keys take precedence)
    metadata = {
        "merge_sources_strategy": config.strategy,
        "sources_used": [source_names[i] for i in source_indices],
        "source_indices": source_indices,
        "n_sources_merged": len(source_indices),
        "output_name": config.output_name,
        # Store config for prediction mode
        "source_merge_config": config.to_dict(),
        **source_info,
        **merge_info,
    }
    return result_context, StepOutput(metadata=metadata)
def _parse_source_merge_config(
    self,
    raw_config: Any
) -> SourceMergeConfig:
    """Normalize a ``merge_sources`` step value into a SourceMergeConfig.

    Accepted forms:
    - a strategy name string, e.g. "concat", "stack", "dict";
    - a dict of options (strategy, sources, on_incompatible, output_name,
      preserve_source_info); absent keys take the defaults listed below;
    - an existing SourceMergeConfig, returned unchanged.

    Args:
        raw_config: Value found under the step's ``merge_sources`` key.

    Returns:
        SourceMergeConfig instance.

    Raises:
        ValueError: For any other input type.
    """
    if isinstance(raw_config, str):
        return SourceMergeConfig(strategy=raw_config)

    if isinstance(raw_config, dict):
        # (option name, default used when the key is absent)
        option_defaults = (
            ("strategy", "concat"),
            ("sources", "all"),
            ("on_incompatible", "error"),
            ("output_name", "merged"),
            ("preserve_source_info", True),
        )
        options = {
            name: raw_config.get(name, fallback)
            for name, fallback in option_defaults
        }
        return SourceMergeConfig(**options)

    if isinstance(raw_config, SourceMergeConfig):
        return raw_config

    type_name = type(raw_config).__name__
    raise ValueError(
        f"Invalid merge_sources config type: {type_name}. "
        f"Expected string, dict, or SourceMergeConfig."
    )
def _get_source_names(
self,
dataset: "SpectroDataset",
n_sources: int
) -> List[str]:
"""Get source names from dataset.
Args:
dataset: The dataset
n_sources: Number of sources
Returns:
List of source names (generates default names if not available)
"""
# Try to get source names from feature accessor
try:
source_names = []
for i in range(n_sources):
# Check if there's a name stored
processings = dataset.features_processings(i)
# Use first processing name or generate default
if processings:
source_names.append(f"source_{i}")
else:
source_names.append(f"source_{i}")
return source_names
except Exception:
return [f"source_{i}" for i in range(n_sources)]
def _collect_source_features(
    self,
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    source_indices: List[int],
    source_names: List[str],
) -> Tuple[Dict[int, np.ndarray], Dict[str, Any]]:
    """Collect 2D feature arrays for the requested sources.

    The dataset is queried once with ``concat_source=False`` (augmented
    samples included, excluded samples removed). The requested sources are
    then picked from that result; previously the identical query was
    repeated for every source.

    A list result is indexed by source. A single array is only valid for
    source 0. Unresolvable sources are skipped with a warning, and the
    caller decides whether an empty result is fatal.

    Args:
        dataset: The dataset
        context: Execution context (``selector`` is used)
        source_indices: Which sources to collect
        source_names: Names for logging

    Returns:
        Tuple of (source_features, info):
        - source_features: maps source index to features;
        - info: ``source_shapes`` and ``sources_collected``.
    """
    source_features: Dict[int, np.ndarray] = {}
    shapes: Dict[int, Tuple[int, ...]] = {}

    X = None
    if source_indices:
        try:
            X = dataset.x(
                selector=context.selector,
                layout="2d",
                concat_source=False,
                include_augmented=True,
                include_excluded=False
            )
        except Exception as e:
            logger.warning(
                f"Failed to collect features from sources {list(source_indices)}: {e}"
            )

    if X is not None:
        for src_idx in source_indices:
            if isinstance(X, list):
                # Explicit lower bound: a negative index would silently wrap.
                if not 0 <= src_idx < len(X):
                    logger.warning(
                        f"Source index {src_idx} out of range "
                        f"(got {len(X)} sources). Skipping."
                    )
                    continue
                features = X[src_idx]
            elif src_idx == 0:
                # Single array - only valid for index 0
                features = X
            else:
                logger.warning(
                    f"Source index {src_idx} requested but dataset "
                    f"returned single array. Skipping."
                )
                continue
            features = np.asarray(features)
            source_features[src_idx] = features
            shapes[src_idx] = features.shape
            # Never let a missing display name abort collection.
            name = (
                source_names[src_idx]
                if 0 <= src_idx < len(source_names)
                else f"source_{src_idx}"
            )
            logger.debug(
                f"Collected source {src_idx} ({name}): "
                f"shape={features.shape}"
            )

    info = {
        "source_shapes": shapes,
        "sources_collected": len(source_features),
    }
    return source_features, info
def _merge_sources_concat(
self,
source_features: Dict[int, np.ndarray],
source_indices: List[int],
source_names: List[str],
) -> Tuple[np.ndarray, Dict[str, Any]]:
"""Merge sources by horizontal concatenation.
Args:
source_features: Dict mapping source index to feature array
source_indices: Source indices in order
source_names: Source names for logging
Returns:
Tuple of (merged 2D array, info dict)
"""
arrays = []
feature_counts = []
for src_idx in source_indices:
if src_idx in source_features:
arr = source_features[src_idx]
# Ensure 2D
if arr.ndim == 1:
arr = arr.reshape(-1, 1)
elif arr.ndim > 2:
# Flatten to 2D
arr = arr.reshape(arr.shape[0], -1)
arrays.append(arr)
feature_counts.append(arr.shape[1])
if not arrays:
raise ValueError(
"No arrays to concatenate. All sources failed to collect. "
"[Error: MERGE-E030]"
)
# Validate sample counts match
sample_counts = [arr.shape[0] for arr in arrays]
if len(set(sample_counts)) > 1:
raise ValueError(
f"Sample count mismatch across sources: {sample_counts}. "
f"All sources must have the same number of samples. "
f"[Error: MERGE-E030]"
)
# Concatenate horizontally
merged = np.concatenate(arrays, axis=1)
info = {
"merged_shape": merged.shape,
"feature_counts_per_source": feature_counts,
"total_features": merged.shape[1],
}
return merged, info
def _merge_sources_stack(
    self,
    source_features: Dict[int, np.ndarray],
    source_indices: List[int],
    source_names: List[str],
    on_incompatible: SourceIncompatibleStrategy,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Merge sources by stacking them along a new source axis.

    Each requested source present in ``source_features`` is first coerced to
    2D: a 1D array becomes a single column, and arrays with more than two
    dimensions have their trailing axes flattened. The arrays are then stacked
    into an array of shape ``(n_samples, n_sources, n_features)``.

    Args:
        source_features: Mapping of source index to feature array.
        source_indices: Order in which sources are stacked. Indices missing
            from ``source_features`` are skipped.
        source_names: Source names; only forwarded to the concat fallback.
        on_incompatible: Behaviour when sources have different feature counts:
            raise (ERROR), fall back to 2D concatenation (FLATTEN), zero-pad
            to the widest source (PAD), or cut to the narrowest (TRUNCATE).

    Returns:
        Tuple ``(merged, info)``. ``merged`` is 3D, except for the FLATTEN
        fallback, which returns the 2D concatenation result and its info dict.

    Raises:
        ValueError: All raised with MERGE-E030, when:
            - no source is available;
            - feature counts differ and ``on_incompatible`` is ERROR;
            - sources disagree on the number of samples.
    """
    arrays: List[np.ndarray] = []
    for src_idx in source_indices:
        if src_idx not in source_features:
            continue
        arr = source_features[src_idx]
        # Ensure 2D first
        if arr.ndim == 1:
            arr = arr.reshape(-1, 1)
        elif arr.ndim > 2:
            arr = arr.reshape(arr.shape[0], -1)
        arrays.append(arr)

    if not arrays:
        raise ValueError(
            "No arrays to stack. All sources failed to collect. "
            "[Error: MERGE-E030]"
        )

    feature_dims = [arr.shape[1] for arr in arrays]
    # Stacking requires every source to have the same feature count.
    shapes_compatible = len(set(feature_dims)) == 1
    if not shapes_compatible:
        logger.warning(
            f"Source feature dimensions differ: {feature_dims}. "
            f"Cannot stack directly (requires uniform dimensions)."
        )
        if on_incompatible == SourceIncompatibleStrategy.ERROR:
            raise ValueError(
                f"Cannot stack sources with different feature dimensions: {feature_dims}. "
                f"Use on_incompatible='flatten' to fall back to 2D concat, "
                f"or 'pad'/'truncate' to align dimensions. "
                f"[Error: MERGE-E030]"
            )
        elif on_incompatible == SourceIncompatibleStrategy.FLATTEN:
            logger.info("Falling back to 2D concatenation due to shape mismatch")
            return self._merge_sources_concat(
                source_features, source_indices, source_names
            )
        elif on_incompatible == SourceIncompatibleStrategy.PAD:
            max_features = max(feature_dims)
            # np.pad keeps each source's dtype; hstack-ing float64 zeros
            # would silently upcast e.g. float32 or integer sources.
            arrays = [
                np.pad(arr, ((0, 0), (0, max_features - arr.shape[1])))
                for arr in arrays
            ]
            logger.info(f"Padded all sources to {max_features} features")
        elif on_incompatible == SourceIncompatibleStrategy.TRUNCATE:
            min_features = min(feature_dims)
            arrays = [arr[:, :min_features] for arr in arrays]
            logger.info(f"Truncated all sources to {min_features} features")

    # Validate sample counts (as the concat path does) so a mismatch yields
    # a MERGE-E030 error rather than a generic numpy stacking error.
    sample_counts = [arr.shape[0] for arr in arrays]
    if len(set(sample_counts)) > 1:
        raise ValueError(
            f"Sample count mismatch across sources: {sample_counts}. "
            f"All sources must have the same number of samples. "
            f"[Error: MERGE-E030]"
        )

    # Stack along axis 1 to create (samples, sources, features)
    merged = np.stack(arrays, axis=1)
    info = {
        "merged_shape": merged.shape,
        "n_sources_stacked": len(arrays),
        "features_per_source": arrays[0].shape[1] if arrays else 0,
        "shape_adjustment": on_incompatible.value if not shapes_compatible else None,
    }
    return merged, info
def _merge_sources_dict(
self,
source_features: Dict[int, np.ndarray],
source_indices: List[int],
source_names: List[str],
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
"""Keep sources as structured dictionary.
Args:
source_features: Dict mapping source index to feature array
source_indices: Source indices in order
source_names: Source names for keys
Returns:
Tuple of (dict mapping source names to arrays, info dict)
"""
result = {}
shapes = {}
for src_idx in source_indices:
if src_idx in source_features:
name = source_names[src_idx]
arr = source_features[src_idx]
# Ensure 2D
if arr.ndim == 1:
arr = arr.reshape(-1, 1)
elif arr.ndim > 2:
arr = arr.reshape(arr.shape[0], -1)
result[name] = arr
shapes[name] = arr.shape
info = {
"source_shapes": shapes,
"n_sources": len(result),
"output_format": "dict",
}
return result, info
def _execute_prediction_merge(
    self,
    step_info: "ParsedStep",
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    runtime_context: "RuntimeContext",
    source: int = -1,
    mode: str = "train",
    loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
    prediction_store: Optional[Any] = None
) -> Tuple["ExecutionContext", StepOutput]:
    """Execute prediction-only merge operation (late fusion).

    Late fusion of predictions without branch context requirements.
    Useful for combining predictions from multiple models without
    requiring branch mode. Unlike `{"merge": "predictions"}` which
    requires active branch contexts, this operates on the prediction
    store directly.

    Use cases:
    - Combine predictions from sequential models (not in branches)
    - Late fusion after separate model training phases
    - Ensemble of predictions for final output

    Configuration (``step_info.original_step["merge_predictions"]``):
    - ``"all"``: every model with validation predictions recorded at a
      step before the current one.
    - ``"<model_name>"``: only that model.
    - ``{"models": [...], "aggregate": ...}``, where ``aggregate`` is one of
      ``"separate"``, ``"mean"`` or ``"weighted_mean"``. ``models`` may also
      be a single name. Unknown ``aggregate`` values fall back to
      ``"separate"`` with a warning.

    For each selected model, out-of-fold train predictions and test
    predictions are reconstructed with ``TrainingSetReconstructor``. They are
    written into a per-sample vector that is NaN where no prediction is
    available. Augmented samples inherit their origin's prediction. The
    aggregated result is added to the dataset through
    ``add_merged_features`` as ``"merged_predictions"`` on source 0.

    Args:
        step_info: Parsed step containing merge configuration
        dataset: Dataset to operate on
        context: Pipeline execution context
        runtime_context: Runtime infrastructure context
        source: Data source index
        mode: Execution mode ("train" or "predict")
        loaded_binaries: Pre-loaded binary objects for prediction mode
        prediction_store: External prediction store for model predictions

    Returns:
        Tuple of (updated_context, StepOutput). ``metadata["models_used"]``
        lists the models whose predictions were actually collected, and
        ``metadata["models_skipped"]`` lists the selected models that failed.

    Raises:
        ValueError: Raised with MERGE-E010 when:
            - ``prediction_store`` is missing;
            - no eligible model predictions exist;
            - the model filter matches nothing;
            - reconstruction setup fails;
            - no model's predictions could be collected.
    """
    if prediction_store is None:
        raise ValueError(
            "merge_predictions requires prediction_store. "
            "Ensure models were trained before this step. "
            "[Error: MERGE-E010]"
        )

    # --- Parse configuration ---------------------------------------------
    raw_config = step_info.original_step.get("merge_predictions", {})
    model_filter = None
    aggregation = "separate"
    if isinstance(raw_config, str):
        if raw_config != "all":
            model_filter = [raw_config]
    elif isinstance(raw_config, dict):
        model_filter = raw_config.get("models")
        aggregation = raw_config.get("aggregate", "separate")
    if isinstance(model_filter, str):
        # A bare name would otherwise be matched by substring (`m in "name"`).
        model_filter = [model_filter]

    logger.info(
        f"Prediction merge: models={model_filter or 'all'}, "
        f"aggregate={aggregation}"
    )

    # --- Discover models with validation predictions from earlier steps ---
    current_step = getattr(context.state, 'step_number', float('inf'))
    predictions = prediction_store.filter_predictions(
        partition='val',
        load_arrays=False,
    )
    available_models = sorted({
        p.get('model_name')
        for p in predictions
        if p.get('step_idx', 0) < current_step and p.get('model_name')
    })
    if not available_models:
        raise ValueError(
            "No model predictions found in prediction store. "
            "Ensure models were trained before merge_predictions. "
            "[Error: MERGE-E010]"
        )

    if model_filter:
        selected_models = [m for m in available_models if m in model_filter]
        if not selected_models:
            raise ValueError(
                f"No models matched filter {model_filter}. "
                f"Available: {available_models}. "
                "[Error: MERGE-E010]"
            )
    else:
        selected_models = available_models
    logger.info(f"  Selected {len(selected_models)} models for prediction merge")

    # --- Model-independent setup (computed once, not per model) -----------
    n_samples = dataset.num_samples
    try:
        from nirs4all.controllers.models.stacking import (
            TrainingSetReconstructor,
            ReconstructorConfig,
        )
        from nirs4all.operators.models.meta import StackingConfig, CoverageStrategy

        stacking_config = StackingConfig(
            coverage_strategy=CoverageStrategy.IMPUTE_MEAN,
        )
        reconstructor_config = ReconstructorConfig(
            log_warnings=True,
            validate_fold_alignment=False,
        )

        # IMPORTANT: include_augmented=False because OOF predictions are
        # only available for original (non-augmented) samples.
        train_ids = np.asarray(
            dataset._indexer.x_indices(
                context.with_partition('train').selector,
                include_augmented=False,
                include_excluded=False,
            ),
            dtype=np.intp,
        )
        test_ids = np.asarray(
            dataset._indexer.x_indices(
                context.with_partition('test').selector,
                include_augmented=False,
                include_excluded=False,
            ),
            dtype=np.intp,
        )

        # (augmented sample id, origin sample id) pairs used to copy each
        # origin's prediction onto its augmented versions.
        aug_pairs: List[Tuple[int, int]] = []
        base_sample_ids = list(train_ids) + list(test_ids)
        if base_sample_ids:
            tracker = dataset._indexer._augmentation_tracker
            for aug_id in tracker.get_augmented_for_origins(base_sample_ids):
                origin_id = tracker.get_origin_for_sample(aug_id)
                if origin_id is not None:
                    aug_pairs.append((aug_id, origin_id))
        aug_ids = np.asarray([a for a, _ in aug_pairs], dtype=np.intp)
        aug_origin_ids = np.asarray([o for _, o in aug_pairs], dtype=np.intp)
    except Exception as e:
        raise ValueError(
            f"Failed to prepare prediction collection: {e}. "
            "[Error: MERGE-E010]"
        ) from e

    # --- Collect per-model predictions (best effort per model) -------------
    model_predictions: Dict[str, np.ndarray] = {}
    for model_name in selected_models:
        try:
            reconstructor = TrainingSetReconstructor(
                prediction_store=prediction_store,
                source_model_names=[model_name],
                stacking_config=stacking_config,
                reconstructor_config=reconstructor_config,
            )
            result = reconstructor.reconstruct(
                dataset=dataset,
                context=context,
                use_proba=False,
            )

            # Per-sample vector: train (OOF) + test predictions, NaN elsewhere.
            combined = np.full(n_samples, np.nan)
            partitions = (
                ("train", result.X_train_meta, train_ids),
                ("test", result.X_test_meta, test_ids),
            )
            for partition, meta, ids in partitions:
                if meta.size == 0:
                    continue
                preds = meta[:, 0] if meta.ndim > 1 else meta
                if len(preds) != len(ids):
                    logger.warning(
                        f"  Model '{model_name}': {len(preds)} {partition} "
                        f"predictions for {len(ids)} {partition} samples; "
                        f"leaving them as NaN."
                    )
                    continue
                combined[ids] = preds

            # Augmented samples get the same prediction as their origin,
            # when the origin has one.
            if aug_ids.size:
                origin_preds = combined[aug_origin_ids]
                known = ~np.isnan(origin_preds)
                combined[aug_ids[known]] = origin_preds[known]

            model_predictions[model_name] = combined
            logger.debug(f"  Collected predictions from model '{model_name}'")
        except Exception as e:
            logger.warning(
                f"Failed to collect predictions from model '{model_name}': {e}"
            )

    if not model_predictions:
        raise ValueError(
            "Failed to collect predictions from any model. "
            "[Error: MERGE-E010]"
        )
    collected_models = list(model_predictions)
    skipped_models = [m for m in selected_models if m not in model_predictions]

    # --- Aggregate ----------------------------------------------------------
    if aggregation == "weighted_mean":
        # Weights come from stored scores (metric "rmse"), no branch context.
        model_selector = ModelSelector(
            prediction_store=prediction_store,
            context=context,
        )
        model_scores = model_selector.get_model_scores(
            model_names=collected_models,
            metric="rmse",
            branch_id=-1,  # No branch context
        )
        merged = PredictionAggregator.aggregate(
            predictions=model_predictions,
            strategy=AggregationStrategy.WEIGHTED_MEAN,
            model_scores=model_scores,
            metric="rmse",
        )
    else:
        if aggregation == "mean":
            strategy = AggregationStrategy.MEAN
        else:
            if aggregation != "separate":
                logger.warning(
                    f"Unknown aggregation '{aggregation}' for merge_predictions; "
                    f"falling back to 'separate'."
                )
            strategy = AggregationStrategy.SEPARATE
        merged = PredictionAggregator.aggregate(
            predictions=model_predictions,
            strategy=strategy,
        )

    # Store merged predictions as features
    dataset.add_merged_features(
        features=merged,
        processing_name="merged_predictions",
        source=0
    )

    metadata = {
        "merge_predictions": True,
        "models_used": collected_models,
        "models_skipped": skipped_models,
        "aggregation": aggregation,
        "n_features": merged.shape[1],
        "merged_shape": merged.shape,
    }
    logger.success(
        f"Prediction merge completed: {len(collected_models)} models, "
        f"shape={merged.shape}"
    )
    return context.copy(), StepOutput(metadata=metadata)
# =============================================================================
# Phase 7: Static merge_branches method for MetaModel integration
# =============================================================================
@classmethod
def merge_branches(
    cls,
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    config: MergeConfig,
    prediction_store: Optional[Any] = None,
    mode: str = "train",
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Static method for programmatic merge (used by MetaModel).

    This class method allows MetaModelController to delegate to merge logic
    without going through the full step execution machinery. It provides
    the core branch merging functionality without modifying the context
    or requiring a step_info object.
    This is the key integration point for Phase 7: MetaModel Refactoring.

    Args:
        dataset: SpectroDataset with sample data.
        context: Execution context with branch_contexts and state.
        config: MergeConfig specifying what to merge.
        prediction_store: Prediction storage for model predictions.
            Required if config.collect_predictions is True.
        mode: Execution mode ("train" or "predict").

    Returns:
        Tuple of (merged_features, info_dict) where:
        - merged_features: 2D numpy array (n_samples, n_features). 1D parts
          are treated as a single column.
        - info_dict: Dictionary with merge metadata including:
            - "merged_shape": Shape of merged features
            - "feature_branches_used": List of branch indices for features
            - "prediction_branches_used": List of branch indices for predictions
            - "models_used": List of model names (if predictions)
            - "oof_reconstruction": Whether OOF was used (if predictions)
            - "unsafe_merge": True if unsafe mode was used

    Raises:
        ValueError: If not in branch mode or config is invalid.
        ValueError: If prediction_store is None but predictions requested.
        ValueError: If nothing was merged, or the merged parts have
            different numbers of samples.

    Example:
        >>> from nirs4all.controllers.data.merge import MergeController
        >>> from nirs4all.operators.data.merge import MergeConfig
        >>>
        >>> # Called from MetaModelController
        >>> config = MergeConfig(
        ...     collect_predictions=True,
        ...     prediction_branches="all",
        ... )
        >>> merged_X, info = MergeController.merge_branches(
        ...     dataset=dataset,
        ...     context=context,
        ...     config=config,
        ...     prediction_store=prediction_store,
        ... )
        >>> meta_model.fit(merged_X, y)

    Note:
        Unlike execute(), this method does NOT:
        - Exit branch mode (caller must handle this if needed)
        - Modify the context
        - Add merged features to the dataset
        - Return a StepOutput
        It simply performs the merge computation and returns the result.
    """
    # Create a controller instance for internal methods
    controller = cls()

    # Validate branch mode
    branch_contexts = context.custom.get("branch_contexts", [])
    in_branch_mode = context.custom.get("in_branch_mode", False)
    if not branch_contexts and not in_branch_mode:
        raise ValueError(
            "merge_branches requires active branch contexts. "
            "Use only after a branch step. "
            "[Error: MERGE-E020]"
        )

    # Enforce the documented contract up front instead of failing later
    # inside prediction collection.
    if config.collect_predictions and prediction_store is None:
        raise ValueError(
            "merge_branches requires prediction_store when "
            "config.collect_predictions is True. "
            "[Error: MERGE-E010]"
        )

    n_branches = len(branch_contexts)

    # Validate branch indices in config
    controller._validate_branches(config, branch_contexts)

    # Log configuration
    controller._log_config(
        config=config,
        n_branches=n_branches,
        branch_contexts=branch_contexts,
        prediction_store=prediction_store,
        context=context,
    )

    merged_parts = []
    info: Dict[str, Any] = {}

    # Collect features if requested
    if config.collect_features:
        feature_branches = config.get_feature_branches(n_branches)
        features_list, feature_info = controller._collect_features(
            dataset=dataset,
            branch_contexts=branch_contexts,
            branch_indices=feature_branches,
            on_missing=config.on_missing,
            on_shape_mismatch=config.on_shape_mismatch,
        )
        if features_list:
            merged_parts.extend(features_list)
            info["feature_shapes"] = feature_info.get("shapes", [])
            info["feature_branches_used"] = feature_info.get("branches_used", [])
            logger.debug(
                f"merge_branches: Collected features from {len(features_list)} branches"
            )

    # Collect predictions if requested
    if config.collect_predictions:
        predictions_array, pred_info = controller._collect_predictions(
            dataset=dataset,
            context=context,
            branch_contexts=branch_contexts,
            config=config,
            prediction_store=prediction_store,
            mode=mode,
        )
        if predictions_array is not None and predictions_array.size > 0:
            merged_parts.append(predictions_array)
            info["prediction_shape"] = predictions_array.shape
            info["prediction_models_used"] = pred_info.get("models_used", [])
            info["prediction_branches_used"] = pred_info.get("branches_used", [])
            info["oof_reconstruction"] = pred_info.get("oof_reconstruction", True)
            info["models_used"] = pred_info.get("models_used", [])
            logger.debug(
                f"merge_branches: Collected predictions: shape={predictions_array.shape}"
            )

    # Include original pre-branch features if requested
    if config.include_original:
        original_features = controller._get_original_features(dataset, context)
        if original_features is not None:
            merged_parts.insert(0, original_features)
            info["include_original"] = True
            info["original_shape"] = original_features.shape

    # Concatenate all parts
    if not merged_parts:
        raise ValueError(
            "merge_branches resulted in empty output - check configuration. "
            "[Error: MERGE-E012]"
        )

    # Treat 1D parts as a single column and check row counts, so a mismatch
    # produces a readable error instead of a bare numpy concatenation error.
    parts_2d = [
        part.reshape(-1, 1) if part.ndim == 1 else part
        for part in (np.asarray(p) for p in merged_parts)
    ]
    row_counts = [part.shape[0] for part in parts_2d]
    if len(set(row_counts)) > 1:
        raise ValueError(
            "merge_branches cannot concatenate parts with different sample "
            f"counts: {[part.shape for part in parts_2d]}. "
            "Features, predictions and original features must cover the "
            "same samples."
        )

    merged_features = np.concatenate(parts_2d, axis=1)
    info["merged_shape"] = merged_features.shape

    # Add unsafe warning if applicable
    if config.unsafe:
        info["unsafe_merge"] = True
        logger.warning(
            "⚠️ UNSAFE MERGE: OOF reconstruction disabled. "
            "Training predictions used directly, causing DATA LEAKAGE."
        )

    logger.info(
        f"merge_branches completed: shape={merged_features.shape}"
        f"{' [UNSAFE]' if config.unsafe else ''}"
    )
    return merged_features, info
# Public API of this module. The config parser and analysis helpers are
# exported too so tests can import them directly.
__all__ = [
    "MergeController",
    "MergeConfigParser",
    "ModelSelector",
    "PredictionAggregator",
    "AsymmetricBranchAnalyzer",
    "BranchAnalysisResult",
    "AsymmetryReport",
    "SourceMergeConfig",
]
# Phase 2: merging of branches that hold disjoint sample sets
__all__ += [
    "DisjointBranchAnalysis",
    "DisjointMergeResult",
    "is_disjoint_branch",
    "detect_disjoint_branches",
]
# Phase 3: metadata describing disjoint merges
__all__ += [
    "DisjointBranchInfo",
    "DisjointMergeMetadata",
]