nirs4all.controllers.data.merge module
Merge Controller for branch combination and exit.
This controller is the CORE PRIMITIVE for all branch combination operations. It handles:
1. Exiting branch mode (always, unconditionally)
2. Collecting features and/or predictions from branches
3. Enforcing OOF (out-of-fold) safety when predictions are involved
4. Creating a unified dataset for subsequent steps
Phase 1 Implementation:
- Controller registration and matching
- Configuration parsing for all syntax variants
- Branch validation utilities
Phase 3 Implementation:
- Feature collection and concatenation
- Shape mismatch handling
Phase 4 Implementation:
- Model discovery from prediction store
- OOF prediction reconstruction via TrainingSetReconstructor
- Unsafe mode with prominent warnings
- Simple prediction merge syntax
Phase 5 Implementation:
- Per-branch model selection strategies (all, best, top_k, explicit)
- Per-branch aggregation strategies (separate, mean, weighted_mean, proba_mean)
- Model ranking by validation metrics
- Advanced per-branch prediction configuration
Phase 6 Implementation:
- Mixed merging (features from some branches, predictions from others)
- Asymmetric branch detection and handling (models in some, not others)
- Different feature dimensions per branch handling
- Different model counts per branch handling
- Improved error messages with resolution suggestions (MERGE-E010, MERGE-E011)
Phase 8 Implementation:
- Prediction mode support for merge steps
- Bundle export support
- Full train/predict cycle
Phase 9 Implementation:
- Source merge (merge_sources keyword) for multi-source datasets
- Source merge strategies: concat, stack, dict
- Source incompatibility handling: error, flatten, pad, truncate
- Prediction merge (merge_predictions keyword) for late fusion
- Error codes: MERGE-E024, MERGE-E030, MERGE-E031
Example
>>> # Simple feature merge
>>> pipeline = [
... {"branch": [[SNV()], [MSC()]]},
... {"merge": "features"},
... PLSRegression(n_components=10)
... ]
>>>
>>> # Prediction stacking
>>> pipeline = [
... {"branch": [[SNV(), PLS()], [MSC(), RF()]]},
... {"merge": "predictions"},
... {"model": Ridge()}
... ]
>>>
>>> # Source merge for multi-source datasets
>>> pipeline = [
... SNV(), # Applied to all sources
... {"merge_sources": "concat"}, # Combine NIR + markers
... {"model": PLS()}
... ]
>>>
>>> # Late fusion without branches
>>> pipeline = [
... SNV(),
... {"model": PLS()},
... {"model": RF()},
... {"merge_predictions": "all"}, # Combine predictions
... {"model": Ridge()}
... ]
Keywords: "merge", "merge_sources", "merge_predictions"
Priority: 5 (same as BranchController)
- class nirs4all.controllers.data.merge.AsymmetricBranchAnalyzer(branch_contexts: List[Dict[str, Any]], prediction_store: Any | None, context: ExecutionContext)[source]
Bases: object
Utility class for analyzing branch asymmetry.
Detects and reports on asymmetry across branches, providing detailed information for error messages and resolution suggestions.
Phase 6 Features:
- Detect model presence asymmetry (some have models, some don't)
- Detect model count asymmetry (different numbers of models)
- Detect feature dimension asymmetry
- Generate resolution suggestions for mixed merge
- analyze_all() AsymmetryReport[source]
Analyze all branches for asymmetry.
- Returns:
AsymmetryReport with comprehensive asymmetry analysis.
- analyze_branch(branch_idx: int) BranchAnalysisResult[source]
Analyze a single branch for its characteristics.
- Parameters:
branch_idx – Branch index to analyze.
- Returns:
BranchAnalysisResult with branch characteristics.
- class nirs4all.controllers.data.merge.AsymmetryReport(is_asymmetric: bool, has_model_asymmetry: bool, has_model_count_asymmetry: bool, has_feature_dim_asymmetry: bool, branches_with_models: List[int], branches_without_models: List[int], model_counts: Dict[int, int], feature_dims: Dict[int, int | None], summary: str)[source]
Bases: object
Report on asymmetry across branches.
Provides detailed analysis of how branches differ, helping users understand and resolve merge configuration issues.
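The flags carried by AsymmetryReport can be illustrated with a small sketch. The helper below is hypothetical (`summarize_asymmetry` is not part of nirs4all); it assumes the per-branch `model_counts` and `feature_dims` mappings are already known and only shows one plausible way the boolean fields could be derived from them.

```python
from typing import Dict, Optional


def summarize_asymmetry(
    model_counts: Dict[int, int],
    feature_dims: Dict[int, Optional[int]],
) -> Dict[str, object]:
    """Hypothetical helper: derive asymmetry flags from per-branch statistics."""
    with_models = sorted(b for b, n in model_counts.items() if n > 0)
    without_models = sorted(b for b, n in model_counts.items() if n == 0)

    # Model presence asymmetry: some branches have models, others have none.
    has_model_asymmetry = bool(with_models) and bool(without_models)
    # Model count asymmetry: branches that have models disagree on how many.
    has_model_count_asymmetry = len({model_counts[b] for b in with_models}) > 1
    # Feature dimension asymmetry: the known dimensions are not all equal.
    known_dims = {d for d in feature_dims.values() if d is not None}
    has_feature_dim_asymmetry = len(known_dims) > 1

    return {
        "is_asymmetric": has_model_asymmetry
        or has_model_count_asymmetry
        or has_feature_dim_asymmetry,
        "has_model_asymmetry": has_model_asymmetry,
        "has_model_count_asymmetry": has_model_count_asymmetry,
        "has_feature_dim_asymmetry": has_feature_dim_asymmetry,
        "branches_with_models": with_models,
        "branches_without_models": without_models,
    }


report = summarize_asymmetry(
    model_counts={0: 2, 1: 0, 2: 1},
    feature_dims={0: 100, 1: 100, 2: 50},
)
```

The keys of the returned dictionary mirror the AsymmetryReport field names, but the derivation rules are assumptions made for illustration.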
- class nirs4all.controllers.data.merge.BranchAnalysisResult(branch_id: int, branch_name: str | None, has_models: bool, model_names: List[str], model_count: int, feature_dim: int | None, has_features: bool)[source]
Bases: object
Result of analyzing branch asymmetry.
- class nirs4all.controllers.data.merge.DisjointBranchAnalysis(is_disjoint: bool, branch_type: BranchType | None, branch_sample_counts: Dict[int, int], branch_sample_indices: Dict[int, List[int]], total_samples: int, partition_column: str | None = None)[source]
Bases: object
Analysis result for disjoint sample branches.
- branch_type
Type of disjoint branching (metadata_partitioner, sample_partitioner).
- Type:
BranchType | None
- branch_type: BranchType | None
- class nirs4all.controllers.data.merge.DisjointBranchInfo(n_samples: int, sample_ids: List[int], n_models_original: int = 0, n_models_selected: int = 0, selected_models: List[Dict[str, Any]] = <factory>, dropped_models: List[Dict[str, Any]] = <factory>)[source]
Bases: object
Information about a single branch in a disjoint merge.
Captures per-branch statistics and model selection details for comprehensive merge metadata.
- class nirs4all.controllers.data.merge.DisjointMergeMetadata(merge_type: str = 'disjoint_samples', n_columns: int = 0, select_by: str = 'mse', branches: Dict[str, DisjointBranchInfo] = <factory>, column_mapping: Dict[int, Dict[str, str]] = <factory>, is_heterogeneous: bool = False, feature_dim: int | None = None)[source]
Bases: object
Complete metadata for a disjoint sample branch merge.
This dataclass captures all information about a disjoint merge operation for logging, debugging, and downstream use. Matches the specification in docs/reports/disjoint_sample_branch_merging.md Section 6.
- branches
Per-branch information as Dict[branch_name, DisjointBranchInfo].
- Type:
Dict[str, DisjointBranchInfo]
- column_mapping
Maps output column index to per-branch model names. Example: {0: {"red": "RF", "blue": "PLS"}, 1: {"red": "PLS", "blue": "RF"}}
Example
>>> metadata = DisjointMergeMetadata(
...     merge_type="disjoint_samples",
...     n_columns=2,
...     select_by="mse",
...     branches={
...         "red": DisjointBranchInfo(n_samples=50, sample_ids=[...], ...),
...         "blue": DisjointBranchInfo(n_samples=100, sample_ids=[...], ...),
...     },
...     column_mapping={
...         0: {"red": "RF", "blue": "PLS"},
...         1: {"red": "PLS", "blue": "RF"},
...     },
... )
- branches: Dict[str, DisjointBranchInfo]
- classmethod from_dict(data: Dict[str, Any]) DisjointMergeMetadata[source]
Create from dictionary representation.
- Parameters:
data – Dictionary with metadata fields.
- Returns:
DisjointMergeMetadata instance.
- get_branch_summary() str[source]
Get a summary string for logging.
- Returns:
Human-readable summary of branch statistics.
- get_column_mapping_summary() List[str][source]
Get column mapping summary for logging.
- Returns:
List of strings describing each column’s model mapping.
- log_summary(logger_func) None[source]
Log merge summary using provided logger function.
- Parameters:
logger_func – Logger function (e.g., logger.info)
- class nirs4all.controllers.data.merge.DisjointMergeResult(merged_array: ndarray, n_columns: int, select_by: str, branch_info: Dict[str, Any], column_mapping: Dict[int, Dict[str, str]])[source]
Bases: object
Result of disjoint sample branch merge.
- merged_array
The merged prediction or feature array (n_total_samples, n_columns).
- Type:
ndarray
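To make the (n_total_samples, n_columns) layout of merged_array concrete, here is a minimal sketch of how per-branch rows could be scattered back to their original sample positions. The function name and its arguments are hypothetical, and plain Python lists stand in for NumPy arrays; this is not the library's implementation.

```python
from typing import Dict, List


def scatter_branch_rows(
    branch_sample_indices: Dict[str, List[int]],
    branch_rows: Dict[str, List[List[float]]],
    n_total_samples: int,
    n_columns: int,
) -> List[List[float]]:
    """Hypothetical helper: place each branch's rows at its own sample indices."""
    # Start with NaN so that samples not covered by any branch stay visible.
    merged = [[float("nan")] * n_columns for _ in range(n_total_samples)]
    for name, indices in branch_sample_indices.items():
        rows = branch_rows[name]
        if len(rows) != len(indices):
            raise ValueError(f"Branch {name!r}: row count does not match sample count")
        for sample_idx, row in zip(indices, rows):
            if len(row) != n_columns:
                raise ValueError(f"Branch {name!r}: expected {n_columns} columns")
            merged[sample_idx] = list(row)
    return merged


merged = scatter_branch_rows(
    branch_sample_indices={"red": [0, 2], "blue": [1, 3]},
    branch_rows={"red": [[0.1], [0.2]], "blue": [[0.3], [0.4]]},
    n_total_samples=4,
    n_columns=1,
)
```

Because the branches hold disjoint samples, each output row is written by exactly one branch, so the merge is a reordering rather than a concatenation of columns.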
- class nirs4all.controllers.data.merge.MergeConfigParser[source]
Bases: object
Parser for merge step configurations.
Handles all syntax variants and normalizes them to MergeConfig.
- Supported syntaxes:
Simple string: "features", "predictions", "all"
Dict with keys: {"features": …, "predictions": …, …}
Legacy format: {"predictions": [0, 1]}
Per-branch format: {"predictions": [{"branch": 0, …}]}
- classmethod parse(raw_config: Any) MergeConfig[source]
Parse raw merge configuration into MergeConfig.
- Parameters:
raw_config – The value from {"merge": raw_config}
- Returns:
Normalized MergeConfig instance.
- Raises:
ValueError – If configuration format is invalid.
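The idea of normalizing several syntaxes into one shape can be sketched as follows. This is an illustration only: `normalize_merge_config` is a hypothetical function, the plain dictionary it returns stands in for MergeConfig, and the reading of "all" as "features and predictions" is an assumption.

```python
from typing import Any, Dict


def normalize_merge_config(raw_config: Any) -> Dict[str, Any]:
    """Hypothetical normalizer; the returned dict stands in for MergeConfig."""
    if isinstance(raw_config, str):
        if raw_config == "features":
            return {"features": "all", "predictions": None}
        if raw_config == "predictions":
            return {"features": None, "predictions": "all"}
        if raw_config == "all":
            # Assumption: "all" merges both features and predictions.
            return {"features": "all", "predictions": "all"}
        raise ValueError(f"Unknown merge mode: {raw_config!r}")

    if isinstance(raw_config, dict):
        predictions = raw_config.get("predictions")
        if isinstance(predictions, list):
            # Legacy format lists bare branch indices; the per-branch format
            # lists dicts. Both are converted to a list of per-branch dicts.
            predictions = [
                p if isinstance(p, dict) else {"branch": p} for p in predictions
            ]
        return {"features": raw_config.get("features"), "predictions": predictions}

    raise ValueError(f"Unsupported merge configuration: {raw_config!r}")


legacy = normalize_merge_config({"predictions": [0, 1]})
per_branch = normalize_merge_config({"predictions": [{"branch": 0}, {"branch": 1}]})
```

In this sketch the legacy and per-branch inputs produce the same normalized value, which is the property a single downstream code path relies on.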
- class nirs4all.controllers.data.merge.MergeController[source]
Bases: OperatorController
Controller for merging branch outputs and exiting branch mode.
This controller is the CORE PRIMITIVE for branch combination. It:
1. Collects features and/or predictions from specified branches
2. Performs horizontal concatenation of features
3. Performs OOF reconstruction for predictions (mandatory unless unsafe=True)
4. Creates a unified "merged" processing in the dataset
5. ALWAYS clears branch contexts and exits branch mode
- Supported Keywords:
"merge": Branch merging (features/predictions/both)
"merge_sources": Source merging (multi-source datasets) [Phase 9]
"merge_predictions": Prediction-only late fusion [Phase 9]
- OOF Safety:
When predictions are merged, OOF reconstruction is MANDATORY by default. This prevents data leakage when the merged output is used for training. Set unsafe=True to disable OOF (generates prominent warnings).
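Why out-of-fold predictions avoid leakage can be shown with a short, library-independent sketch. Everything here is hypothetical (`oof_predictions`, `mean_model`, the fold layout); a trivial mean predictor stands in for a real model, and this does not describe how TrainingSetReconstructor works internally.

```python
from typing import Callable, List, Tuple

Fold = Tuple[List[int], List[int]]  # (train indices, validation indices)


def oof_predictions(
    n_samples: int,
    folds: List[Fold],
    fit_predict: Callable[[List[int], List[int]], List[float]],
) -> List[float]:
    """Assemble predictions where each sample is predicted by a model that never saw it."""
    oof = [float("nan")] * n_samples
    for train_idx, val_idx in folds:
        # The model is fitted on train_idx only and predicts the held-out val_idx.
        preds = fit_predict(train_idx, val_idx)
        for i, p in zip(val_idx, preds):
            oof[i] = p
    return oof


def mean_model(y: List[float]) -> Callable[[List[int], List[int]], List[float]]:
    """Toy model: predicts the mean of the training targets for every sample."""
    def fit_predict(train_idx: List[int], val_idx: List[int]) -> List[float]:
        mean = sum(y[i] for i in train_idx) / len(train_idx)
        return [mean] * len(val_idx)
    return fit_predict


y = [1.0, 2.0, 3.0, 4.0]
folds = [([2, 3], [0, 1]), ([0, 1], [2, 3])]
oof = oof_predictions(len(y), folds, mean_model(y))
```

Each entry of `oof` depends only on targets from the other fold, so a meta-learner trained on it does not see predictions that were fitted on its own training targets.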
- Relationship to MetaModel:
MetaModel internally uses MergeController for data preparation, then trains the meta-learner. Users can achieve the same result with:
{"merge": "predictions"}, {"model": Ridge()}
which is equivalent to:
{"model": MetaModel(Ridge())}
- SUPPORTED_KEYWORDS
Set of keywords this controller handles.
- SUPPORTED_KEYWORDS = {'merge', 'merge_predictions', 'merge_sources'}
- classmethod build_config_from_meta_model(meta_operator: Any, context: ExecutionContext, branch_contexts: List[Dict[str, Any]] | None = None) MergeConfig[source]
Build MergeConfig from MetaModel operator parameters.
Translates MetaModel configuration to an equivalent MergeConfig for use with merge_branches(). This enables MetaModel to delegate to the centralized merge logic.
This is a helper for Phase 7: MetaModel Refactoring.
- Parameters:
meta_operator – MetaModel operator instance with configuration.
context – Execution context with branch info.
branch_contexts – Optional branch contexts for branch resolution.
- Returns:
MergeConfig equivalent to the MetaModel’s configuration.
Example
>>> config = MergeController.build_config_from_meta_model(
...     meta_operator=meta_model,
...     context=context,
... )
>>> merged_X, info = MergeController.merge_branches(
...     dataset=dataset,
...     context=context,
...     config=config,
...     prediction_store=prediction_store,
... )
- execute(step_info: ParsedStep, dataset: SpectroDataset, context: ExecutionContext, runtime_context: RuntimeContext, source: int = -1, mode: str = 'train', loaded_binaries: List[Tuple[str, Any]] | None = None, prediction_store: Any | None = None) Tuple[ExecutionContext, StepOutput][source]
Execute the merge step with keyword dispatch.
Dispatches to the appropriate handler based on the step keyword:
- "merge": Branch merging (features/predictions/both)
- "merge_sources": Source merging (Phase 9, not yet implemented)
- "merge_predictions": Prediction-only late fusion (Phase 9, not yet implemented)
Phase 2 implementation provides:
- Configuration parsing
- Branch validation
- Branch mode exit
- Keyword dispatch framework
Subsequent phases will add:
- Feature collection (Phase 3)
- Prediction OOF reconstruction (Phase 4)
- Per-branch selection/aggregation (Phase 5)
- Source merge implementation (Phase 9)
- Parameters:
step_info – Parsed step containing merge configuration
dataset – Dataset to operate on
context – Pipeline execution context
runtime_context – Runtime infrastructure context
source – Data source index
mode – Execution mode (“train” or “predict”)
loaded_binaries – Pre-loaded binary objects for prediction mode
prediction_store – External prediction store for model predictions
- Returns:
Tuple of (updated_context, StepOutput)
- Raises:
ValueError – If not in branch mode or configuration is invalid.
NotImplementedError – If merge_sources or merge_predictions called (Phase 9).
- classmethod matches(step: Any, operator: Any, keyword: str) bool[source]
Check if the step matches the merge controller.
- Parameters:
step – Original step configuration
operator – Deserialized operator
keyword – Step keyword
- Returns:
True if keyword is one of the supported merge keywords.
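The matching rule described above reduces to a set membership test. The sketch below is a standalone illustration with a hypothetical function name; it copies the SUPPORTED_KEYWORDS value shown in this reference and ignores the `step` and `operator` arguments, which is an assumption about how they are used.

```python
# Value copied from the SUPPORTED_KEYWORDS attribute documented above.
SUPPORTED_KEYWORDS = {"merge", "merge_predictions", "merge_sources"}


def matches_merge_step(keyword: str) -> bool:
    """Hypothetical stand-in for MergeController.matches(): keyword lookup only."""
    return keyword in SUPPORTED_KEYWORDS


is_merge = matches_merge_step("merge_sources")
is_branch = matches_merge_step("branch")
```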
- classmethod merge_branches(dataset: SpectroDataset, context: ExecutionContext, config: MergeConfig, prediction_store: Any | None = None, mode: str = 'train') Tuple[ndarray, Dict[str, Any]][source]
Class method for programmatic merge (used by MetaModel).
This class method allows MetaModelController to delegate to merge logic without going through the full step execution machinery. It provides the core branch merging functionality without modifying the context or requiring a step_info object.
This is the key integration point for Phase 7: MetaModel Refactoring.
- Parameters:
dataset – SpectroDataset with sample data.
context – Execution context with branch_contexts and state.
config – MergeConfig specifying what to merge.
prediction_store – Prediction storage for model predictions. Required if config.collect_predictions is True.
mode – Execution mode (“train” or “predict”).
- Returns:
Tuple of (merged_features, info_dict) where:
- merged_features: 2D numpy array (n_samples, n_features)
- info_dict: Dictionary with merge metadata including:
  - "merged_shape": Shape of merged features
  - "feature_branches_used": List of branch indices for features
  - "prediction_branches_used": List of branch indices for predictions
  - "models_used": List of model names (if predictions)
  - "oof_reconstruction": Whether OOF was used (if predictions)
  - "unsafe_merge": True if unsafe mode was used
- Raises:
ValueError – If not in branch mode or config is invalid.
ValueError – If prediction_store is None but predictions requested.
Example
>>> from nirs4all.controllers.data.merge import MergeController
>>> from nirs4all.operators.data.merge import MergeConfig
>>>
>>> # Called from MetaModelController
>>> config = MergeConfig(
...     collect_predictions=True,
...     prediction_branches="all",
... )
>>> merged_X, info = MergeController.merge_branches(
...     dataset=dataset,
...     context=context,
...     config=config,
...     prediction_store=prediction_store,
... )
>>> meta_model.fit(merged_X, y)
Note
Unlike execute(), this method does NOT:
- Exit branch mode (caller must handle this if needed)
- Modify the context
- Add merged features to the dataset
- Return a StepOutput
It simply performs the merge computation and returns the result.
- class nirs4all.controllers.data.merge.ModelSelector(prediction_store: Predictions, context: ExecutionContext)[source]
Bases: object
Utility class for selecting models based on validation metrics.
Handles model ranking and selection strategies (all, best, top_k, explicit) for per-branch prediction collection and stacking operations.
This class is shared between MergeController and MetaModelController to avoid code duplication.
- prediction_store
Prediction storage instance.
- context
Execution context.
- LOWER_IS_BETTER_METRICS
Set of metrics where lower values are better.
- LOWER_IS_BETTER_METRICS = {'log_loss', 'mae', 'mape', 'mse', 'nmae', 'nmse', 'nrmse', 'rmse'}
- get_model_scores(model_names: List[str], metric: str, branch_id: int) Dict[str, float][source]
Get validation scores for multiple models.
Used for weighted aggregation.
- Parameters:
model_names – List of model names.
metric – Metric name.
branch_id – Branch identifier.
- Returns:
Dictionary mapping model name to score.
- select_models(available_models: List[str], config: BranchPredictionConfig, branch_id: int) List[str][source]
Select models from available models based on config.
- Parameters:
available_models – List of available model names in the branch.
config – Per-branch prediction configuration.
branch_id – Branch identifier.
- Returns:
List of selected model names.
- Raises:
ValueError – If explicit model selection references unknown models.
- select_models_global(available_models: List[str], selection: Any, metric: str | None = None) List[str][source]
Select models globally (without branch context).
This is used by MetaModelController for pipelines without branches.
- Parameters:
available_models – List of available model names.
selection – Selection configuration:
- "all": Use all models
- "best": Use best model
- {"top_k": N}: Use top N models
- ["model1", "model2"]: Explicit list
metric – Optional metric for ranking.
- Returns:
List of selected model names.
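The four selection forms can be sketched with a ranking step that respects metric direction. `select_by_score` and its `scores` argument are hypothetical; the real selector reads scores from the prediction store, and the default metric used here is an assumption.

```python
from typing import Any, Dict, List

# Value copied from the LOWER_IS_BETTER_METRICS attribute documented above.
LOWER_IS_BETTER_METRICS = {
    "log_loss", "mae", "mape", "mse", "nmae", "nmse", "nrmse", "rmse",
}


def select_by_score(
    available_models: List[str],
    scores: Dict[str, float],
    selection: Any,
    metric: str = "rmse",
) -> List[str]:
    """Hypothetical sketch of 'all' / 'best' / {'top_k': N} / explicit-list selection."""
    if selection == "all":
        return list(available_models)
    if isinstance(selection, list):
        unknown = [m for m in selection if m not in available_models]
        if unknown:
            raise ValueError(f"Unknown models requested: {unknown}")
        return list(selection)

    # Rank best-first: ascending for error metrics, descending otherwise.
    higher_is_better = metric not in LOWER_IS_BETTER_METRICS
    ranked = sorted(available_models, key=lambda m: scores[m], reverse=higher_is_better)
    if selection == "best":
        return ranked[:1]
    if isinstance(selection, dict) and "top_k" in selection:
        return ranked[: selection["top_k"]]
    raise ValueError(f"Unsupported selection: {selection!r}")


models = ["PLS", "RF", "Ridge"]
scores = {"PLS": 0.3, "RF": 0.5, "Ridge": 0.4}
best_rmse = select_by_score(models, scores, "best", metric="rmse")
top2_rmse = select_by_score(models, scores, {"top_k": 2}, metric="rmse")
best_r2 = select_by_score(models, scores, "best", metric="r2")
```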
- class nirs4all.controllers.data.merge.PredictionAggregator[source]
Bases: object
Utility class for aggregating predictions from multiple models.
Handles aggregation strategies (separate, mean, weighted_mean, proba_mean) for combining predictions within a branch or across models.
This class is shared between MergeController and MetaModelController to avoid code duplication.
All methods are static as no instance state is needed.
- LOWER_IS_BETTER_METRICS = {'log_loss', 'mae', 'mape', 'mse', 'nmae', 'nmse', 'nrmse', 'rmse'}
- static aggregate(predictions: Dict[str, ndarray], strategy: AggregationStrategy, model_scores: Dict[str, float] | None = None, proba: bool = False, metric: str | None = None) ndarray[source]
Aggregate predictions from multiple models.
- Parameters:
predictions – Dictionary mapping model names to prediction arrays. Each array has shape (n_samples,) for regression or (n_samples, n_classes) for classification probabilities.
strategy – Aggregation strategy to use.
model_scores – Optional dictionary of model scores for weighted averaging.
proba – Whether predictions are class probabilities.
metric – Metric name (for determining weight direction).
- Returns:
Aggregated predictions with shape:
- SEPARATE: (n_samples, n_models)
- MEAN/WEIGHTED_MEAN: (n_samples, 1)
- PROBA_MEAN: (n_samples, n_classes)
- Raises:
ValueError – If predictions dict is empty.
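The output shapes listed above can be illustrated for the regression case. This is a sketch under assumptions: `aggregate_regression` is a hypothetical function, nested lists stand in for NumPy arrays, strategies are passed as plain strings rather than AggregationStrategy members, and weights are supplied directly because the conversion from model scores to weights is not specified here.

```python
from typing import Dict, List, Optional


def aggregate_regression(
    predictions: Dict[str, List[float]],
    strategy: str,
    weights: Optional[Dict[str, float]] = None,
) -> List[List[float]]:
    """Hypothetical sketch of separate / mean / weighted_mean aggregation."""
    if not predictions:
        raise ValueError("predictions dict is empty")
    names = list(predictions)
    n_samples = len(predictions[names[0]])

    if strategy == "separate":
        # One column per model: shape (n_samples, n_models).
        return [[predictions[m][i] for m in names] for i in range(n_samples)]
    if strategy == "mean":
        # Single averaged column: shape (n_samples, 1).
        return [
            [sum(predictions[m][i] for m in names) / len(names)]
            for i in range(n_samples)
        ]
    if strategy == "weighted_mean":
        if weights is None:
            raise ValueError("weighted_mean requires weights")
        total = sum(weights[m] for m in names)
        return [
            [sum(weights[m] * predictions[m][i] for m in names) / total]
            for i in range(n_samples)
        ]
    raise ValueError(f"Unknown strategy: {strategy!r}")


preds = {"PLS": [1.0, 2.0], "RF": [3.0, 4.0]}
separate = aggregate_regression(preds, "separate")
mean = aggregate_regression(preds, "mean")
weighted = aggregate_regression(preds, "weighted_mean", weights={"PLS": 3.0, "RF": 1.0})
```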
- static aggregate_folds(fold_predictions: List[ndarray], fold_scores: List[float] | None = None, strategy: str = 'mean', metric: str | None = None) ndarray[source]
Aggregate predictions across CV folds.
Useful for combining test predictions from different folds.
- Parameters:
fold_predictions – List of prediction arrays, one per fold.
fold_scores – Optional list of validation scores per fold.
strategy – Aggregation strategy (“mean”, “weighted_mean”, “best”).
metric – Metric name for weighted aggregation.
- Returns:
Aggregated predictions.
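A minimal sketch of the "mean" and "best" fold strategies follows. The function name, the use of plain lists, and the default metric are assumptions made for illustration; "weighted_mean" is left out because the weighting scheme is not documented here.

```python
from typing import List, Optional

# Value copied from the LOWER_IS_BETTER_METRICS attribute documented above.
LOWER_IS_BETTER_METRICS = {
    "log_loss", "mae", "mape", "mse", "nmae", "nmse", "nrmse", "rmse",
}


def combine_fold_predictions(
    fold_predictions: List[List[float]],
    fold_scores: Optional[List[float]] = None,
    strategy: str = "mean",
    metric: str = "rmse",
) -> List[float]:
    """Hypothetical sketch: combine per-fold test predictions for the same samples."""
    if not fold_predictions:
        raise ValueError("fold_predictions is empty")
    if strategy == "mean":
        n_folds = len(fold_predictions)
        # zip(*...) iterates sample by sample across folds.
        return [sum(values) / n_folds for values in zip(*fold_predictions)]
    if strategy == "best":
        if fold_scores is None:
            raise ValueError("'best' requires fold_scores")
        pick = min if metric in LOWER_IS_BETTER_METRICS else max
        best_idx = pick(range(len(fold_scores)), key=lambda i: fold_scores[i])
        return list(fold_predictions[best_idx])
    raise ValueError(f"Unknown strategy: {strategy!r}")


folds = [[1.0, 2.0], [3.0, 4.0]]
fold_mean = combine_fold_predictions(folds)
best_by_rmse = combine_fold_predictions(folds, [0.5, 0.2], strategy="best", metric="rmse")
best_by_r2 = combine_fold_predictions(folds, [0.5, 0.2], strategy="best", metric="r2")
```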
- class nirs4all.controllers.data.merge.SourceMergeConfig(strategy: str = 'concat', sources: str | List[int | str] = 'all', on_incompatible: str = 'error', output_name: str = 'merged', preserve_source_info: bool = True)[source]
Bases: object
Configuration for merging multi-source dataset features.
This dataclass provides configuration for the merge_sources keyword, which combines features from multiple data sources (e.g., NIR, markers, Raman) into a unified feature space.
Unlike branch merging (merge), source merging operates on the data provenance dimension: it combines features that originated from different sensors, instruments, or data modalities.
- strategy
How to combine source features.
- "concat" (default): Horizontal concatenation (2D result)
- "stack": Stack along new axis (3D result, requires uniform shapes)
- "dict": Keep as structured dictionary (for multi-input models)
- Type:
str
- sources
Which sources to include.
- "all" (default): Include all available sources
- List of source indices: [0, 1] for specific sources
- List of source names: ["NIR", "markers"] for named sources
- on_incompatible
How to handle incompatible shapes (for stack strategy).
- "error" (default): Raise error if shapes don't match
- "flatten": Fall back to 2D concat
- "pad": Pad shorter with zeros
- "truncate": Truncate longer to match shortest
- Type:
str
Example
>>> # Simple concatenation (default)
>>> {"merge_sources": "concat"}
>>>
>>> # Stack for 3D models (requires same feature count per source)
>>> {"merge_sources": {"strategy": "stack"}}
>>>
>>> # Selective sources with fallback on shape mismatch
>>> {"merge_sources": {
...     "strategy": "stack",
...     "sources": ["NIR", "MIR"],
...     "on_incompatible": "flatten"
... }}
>>>
>>> # Dict output for multi-head models
>>> {"merge_sources": {"strategy": "dict"}}
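What the "concat" and "stack" strategies and the on_incompatible options mean for array shapes can be sketched without the library. The two functions below are hypothetical, nested lists stand in for arrays, and the stacked layout (n_samples, n_sources, n_features) is an assumption; the real axis order may differ.

```python
from typing import List

Matrix = List[List[float]]  # one row per sample


def concat_sources(sources: List[Matrix]) -> Matrix:
    """Horizontal concatenation: each sample's features from all sources in one row."""
    n_samples = len(sources[0])
    return [[v for src in sources for v in src[i]] for i in range(n_samples)]


def stack_sources(sources: List[Matrix], on_incompatible: str = "error"):
    """Stack sources per sample, handling unequal feature counts as configured."""
    widths = [len(src[0]) for src in sources]
    if len(set(widths)) > 1:
        if on_incompatible == "error":
            raise ValueError(f"Sources have different feature counts: {widths}")
        if on_incompatible == "flatten":
            return concat_sources(sources)  # fall back to a 2D result
        if on_incompatible == "pad":
            target = max(widths)
            sources = [[row + [0.0] * (target - len(row)) for row in src] for src in sources]
        elif on_incompatible == "truncate":
            target = min(widths)
            sources = [[row[:target] for row in src] for src in sources]
        else:
            raise ValueError(f"Unknown on_incompatible: {on_incompatible!r}")
    n_samples = len(sources[0])
    # Assumed layout: (n_samples, n_sources, n_features).
    return [[src[i] for src in sources] for i in range(n_samples)]


nir = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
markers = [[7.0], [8.0]]
concatenated = concat_sources([nir, markers])
padded = stack_sources([nir, markers], on_incompatible="pad")
truncated = stack_sources([nir, markers], on_incompatible="truncate")
```

Padding keeps every NIR feature at the cost of adding zeros to the shorter source, while truncation discards features from the longer one, so the choice depends on which loss the downstream model tolerates better.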
- classmethod from_dict(data: Dict[str, Any]) SourceMergeConfig[source]
Create config from dictionary.
- Parameters:
data – Dictionary representation.
- Returns:
SourceMergeConfig instance.
- get_incompatible_strategy() SourceIncompatibleStrategy[source]
Get the incompatible handling strategy as an enum.
- Returns:
SourceIncompatibleStrategy enum value.
- get_source_indices(available_sources: List[str]) List[int][source]
Resolve source specification to indices.
- Parameters:
available_sources – List of available source names.
- Returns:
List of source indices to include.
- Raises:
ValueError – If a specified source is not found.
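Resolving a source specification to indices can be sketched as below. `resolve_source_indices` is a hypothetical standalone function that takes the specification as an argument instead of reading it from a config instance; the range check on integer indices is an assumption.

```python
from typing import List, Union


def resolve_source_indices(
    sources: Union[str, List[Union[int, str]]],
    available_sources: List[str],
) -> List[int]:
    """Hypothetical sketch: map 'all', indices, or names to source indices."""
    if sources == "all":
        return list(range(len(available_sources)))
    indices: List[int] = []
    for spec in sources:
        if isinstance(spec, int):
            if not 0 <= spec < len(available_sources):
                raise ValueError(f"Source index out of range: {spec}")
            indices.append(spec)
        else:
            if spec not in available_sources:
                raise ValueError(f"Source not found: {spec!r}")
            indices.append(available_sources.index(spec))
    return indices


available = ["NIR", "MIR", "markers"]
all_indices = resolve_source_indices("all", available)
named = resolve_source_indices(["markers", "NIR"], available)
mixed = resolve_source_indices([1, "markers"], available)
```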
- get_strategy() SourceMergeStrategy[source]
Get the merge strategy as an enum.
- Returns:
SourceMergeStrategy enum value.
- nirs4all.controllers.data.merge.detect_disjoint_branches(branch_contexts: List[Dict[str, Any]]) DisjointBranchAnalysis[source]
Detect if branches represent disjoint sample partitions.
Examines branch contexts to determine if they were created by a partitioning controller (metadata_partitioner or sample_partitioner).
- Parameters:
branch_contexts – List of branch context dictionaries.
- Returns:
DisjointBranchAnalysis with detection results.
- nirs4all.controllers.data.merge.is_disjoint_branch(branch_context: Dict[str, Any]) bool[source]
Check if a branch context indicates disjoint sample branching.
A disjoint branch has a ‘sample_partition’ or ‘partition_info’ key that indicates samples were partitioned (not copied) across branches.
- Parameters:
branch_context – A single branch context dictionary.
- Returns:
True if this branch is part of a disjoint sample partition.
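The key check described above can be sketched in a few lines. Both functions are hypothetical stand-ins; they test only for the presence of the two keys named in the description and assume nothing about the values stored under them.

```python
from typing import Any, Dict, List


def looks_disjoint(branch_context: Dict[str, Any]) -> bool:
    """Hypothetical check: does the context carry a partition marker key?"""
    return "sample_partition" in branch_context or "partition_info" in branch_context


def all_branches_disjoint(branch_contexts: List[Dict[str, Any]]) -> bool:
    """Hypothetical check: every branch carries a partition marker key."""
    return bool(branch_contexts) and all(looks_disjoint(c) for c in branch_contexts)


partitioned = [
    {"name": "red", "sample_partition": {"sample_indices": [0, 2]}},
    {"name": "blue", "partition_info": {"sample_indices": [1, 3]}},
]
copied = [{"name": "snv"}, {"name": "msc"}]
```

The contents of the example contexts (such as `sample_indices`) are invented for illustration; only the two top-level key names come from the description above.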