nirs4all.operators.data.merge module
Merge operator configuration for branch and source merging.
This module provides configuration dataclasses and enums for the MergeController, which handles combining branch outputs (features and/or predictions) and exiting branch mode.
The merge operator is the core primitive for all branch combination operations. It provides:
- Feature merging from branches (horizontal concatenation)
- Prediction merging with OOF reconstruction (data leakage prevention)
- Per-branch model selection and aggregation strategies
- Mixed merging (features from some branches, predictions from others)
Example
>>> # Simple feature merge
>>> {"merge": "features"}
>>>
>>> # Prediction merge with OOF safety
>>> {"merge": "predictions"}
>>>
>>> # Mixed merge with per-branch control
>>> {"merge": {
...     "predictions": [
...         {"branch": 0, "select": "best", "metric": "rmse"},
...         {"branch": 1, "aggregate": "mean"}
...     ],
...     "features": [2]
... }}
- class nirs4all.operators.data.merge.AggregationStrategy(value)[source]
Bases: Enum
How to aggregate predictions from selected models within a branch.
After model selection, this controls how the selected predictions are combined into features for the merged output.
- SEPARATE
Keep each model’s predictions as separate features (default). Results in N features (one per selected model).
- MEAN
Simple average of all selected model predictions. Results in 1 feature.
- WEIGHTED_MEAN
Weighted average by validation score. Results in 1 feature.
- PROBA_MEAN
Average class probabilities (classification only). Results in K features (one per class).
- MEAN = 'mean'
- PROBA_MEAN = 'proba_mean'
- SEPARATE = 'separate'
- WEIGHTED_MEAN = 'weighted_mean'
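The strategies above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the `aggregate_predictions` helper and its `weights` argument are hypothetical, the weights are assumed to be non-negative scores supplied by the caller, and `proba_mean` is left out because it needs per-class probabilities.

```python
from statistics import fmean


def aggregate_predictions(preds, strategy="separate", weights=None):
    """Combine per-model predictions into feature columns.

    preds: dict mapping model name -> list of predictions (one per sample).
    Returns a list of columns; each column holds one value per sample.
    Hypothetical helper for illustration only.
    """
    columns = list(preds.values())
    if strategy == "separate":
        # N models -> N feature columns
        return columns
    if strategy == "mean":
        # N models -> 1 feature column (plain average per sample)
        return [[fmean(row) for row in zip(*columns)]]
    if strategy == "weighted_mean":
        # Assumption: one non-negative weight per model, in the same order
        total = sum(weights)
        return [[sum(w * p for w, p in zip(weights, row)) / total
                 for row in zip(*columns)]]
    raise ValueError(f"Unsupported strategy: {strategy!r}")


preds = {"PLS": [1.0, 2.0], "RF": [3.0, 4.0]}
separate = aggregate_predictions(preds, "separate")
mean = aggregate_predictions(preds, "mean")
weighted = aggregate_predictions(preds, "weighted_mean", weights=[1.0, 3.0])
```

With two models, `separate` yields two columns while `mean` and `weighted_mean` each yield one, matching the feature counts listed above.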
- class nirs4all.operators.data.merge.BranchPredictionConfig(branch: int | str, select: str | Dict[str, Any] | List[str] = 'all', metric: str | None = None, aggregate: str = 'separate', weight_metric: str | None = None, proba: bool = False, sources: str | List[int | str] = 'all')[source]
Bases: object
Configuration for prediction collection from a single branch.
This dataclass specifies how to collect and process predictions from a specific branch during merge operations.
- select
Model selection strategy.
  - "all" (default): All models in branch
  - "best": Single best model by metric
  - {"top_k": N}: Top N models by metric
  - ["ModelA", "ModelB"]: Explicit model names
- metric
Metric for selection (rmse, mae, r2, accuracy, f1). Default is task-appropriate (rmse for regression, accuracy for classification).
- Type:
str | None
- aggregate
How to combine predictions from selected models.
  - "separate" (default): Each model is a separate feature
  - "mean": Simple average of predictions
  - "weighted_mean": Weight by validation score
  - "proba_mean": Average class probabilities (classification)
- Type:
str
- sources
Source filter for multi-source datasets.
  - "all" (default): Include all sources
  - List of source indices or names
Example
>>> # Best model from branch 0 by RMSE
>>> BranchPredictionConfig(branch=0, select="best", metric="rmse")
>>>
>>> # Top 3 models from branch 1, averaged
>>> BranchPredictionConfig(
...     branch=1,
...     select={"top_k": 3},
...     metric="r2",
...     aggregate="mean"
... )
>>>
>>> # Explicit models with weighted average
>>> BranchPredictionConfig(
...     branch="spectral_path",
...     select=["PLS", "RF"],
...     aggregate="weighted_mean",
...     weight_metric="r2"
... )
- get_aggregation_strategy() AggregationStrategy[source]
Get the aggregation strategy enum for this configuration.
- Returns:
AggregationStrategy enum value based on aggregate field.
- get_selection_strategy() SelectionStrategy[source]
Get the selection strategy enum for this configuration.
- Returns:
SelectionStrategy enum value based on select field.
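As a rough sketch of how the accepted `select` forms could map onto selection strategies, the snippet below uses a local stand-in enum with the same member values as `SelectionStrategy`. The `resolve_selection` function is hypothetical and does not claim to reproduce what `get_selection_strategy()` does internally.

```python
from enum import Enum


class SelectionStrategy(Enum):
    """Local stand-in mirroring the documented enum values."""
    ALL = "all"
    BEST = "best"
    TOP_K = "top_k"
    EXPLICIT = "explicit"


def resolve_selection(select):
    """Map a `select` value to a strategy (hypothetical helper)."""
    if isinstance(select, str):
        if select in ("all", "best"):
            return SelectionStrategy(select)
        raise ValueError(f"Unknown select string: {select!r}")
    if isinstance(select, dict):
        # The documented dict form is {"top_k": N}
        if "top_k" in select:
            return SelectionStrategy.TOP_K
        raise ValueError(f"Unknown select dict: {select!r}")
    if isinstance(select, (list, tuple)):
        # A list is treated as explicit model names
        return SelectionStrategy.EXPLICIT
    raise TypeError(f"Unsupported select type: {type(select).__name__}")


print(resolve_selection({"top_k": 3}))
```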
- class nirs4all.operators.data.merge.BranchType(value)[source]
Bases: Enum
Type of branch based on sample handling.
- COPY
All branches see all samples (default branching behavior).
- METADATA_PARTITIONER
Branches partition samples by metadata column.
- SAMPLE_PARTITIONER
Branches partition samples by filter (e.g., outlier).
- COPY = 'copy'
- METADATA_PARTITIONER = 'metadata_partitioner'
- SAMPLE_PARTITIONER = 'sample_partitioner'
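The difference between copy branches and partitioning branches can be illustrated on sample IDs alone. This is a toy sketch under stated assumptions: both helper functions are hypothetical, and the metadata is assumed to be a simple mapping from sample ID to a column value.

```python
from collections import defaultdict


def copy_branches(sample_ids, n_branches):
    """COPY-style branching: every branch sees every sample."""
    return [list(sample_ids) for _ in range(n_branches)]


def partition_by_metadata(sample_ids, metadata):
    """METADATA_PARTITIONER-style branching: samples are split by a
    metadata value, so branches hold disjoint sample sets."""
    groups = defaultdict(list)
    for sample_id in sample_ids:
        groups[metadata[sample_id]].append(sample_id)
    return dict(groups)


sample_ids = [0, 1, 2, 3]
metadata = {0: "red", 1: "blue", 2: "red", 3: "blue"}  # made-up values

copied = copy_branches(sample_ids, n_branches=2)
partitioned = partition_by_metadata(sample_ids, metadata)
```

A merge after copy branches can concatenate columns sample by sample, whereas a merge after partitioning has to reassemble rows from disjoint subsets.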
- class nirs4all.operators.data.merge.DisjointBranchInfo(n_samples: int, sample_ids: ~typing.List[int], n_models_original: int = 0, n_models_selected: int = 0, selected_models: ~typing.List[~typing.Dict[str, ~typing.Any]] = <factory>, dropped_models: ~typing.List[~typing.Dict[str, ~typing.Any]] = <factory>)[source]
Bases: object
Information about a single branch in a disjoint merge.
Captures per-branch statistics and model selection details for comprehensive merge metadata.
- class nirs4all.operators.data.merge.DisjointMergeMetadata(merge_type: str = 'disjoint_samples', n_columns: int = 0, select_by: str = 'mse', branches: ~typing.Dict[str, ~nirs4all.operators.data.merge.DisjointBranchInfo] = <factory>, column_mapping: ~typing.Dict[int, ~typing.Dict[str, str]] = <factory>, is_heterogeneous: bool = False, feature_dim: int | None = None)[source]
Bases: object
Complete metadata for a disjoint sample branch merge.
This dataclass captures all information about a disjoint merge operation for logging, debugging, and downstream use. Matches the specification in docs/reports/disjoint_sample_branch_merging.md Section 6.
- branches
Per-branch information as Dict[branch_name, DisjointBranchInfo].
- Type:
Dict[str, nirs4all.operators.data.merge.DisjointBranchInfo]
- column_mapping
Maps output column index to per-branch model names. Example: {0: {"red": "RF", "blue": "PLS"}, 1: {"red": "PLS", "blue": "RF"}}
Example
>>> metadata = DisjointMergeMetadata(
...     merge_type="disjoint_samples",
...     n_columns=2,
...     select_by="mse",
...     branches={
...         "red": DisjointBranchInfo(n_samples=50, sample_ids=[...], ...),
...         "blue": DisjointBranchInfo(n_samples=100, sample_ids=[...], ...),
...     },
...     column_mapping={
...         0: {"red": "RF", "blue": "PLS"},
...         1: {"red": "PLS", "blue": "RF"},
...     },
... )
- branches: Dict[str, DisjointBranchInfo]
- classmethod from_dict(data: Dict[str, Any]) DisjointMergeMetadata[source]
Create from dictionary representation.
- Parameters:
data – Dictionary with metadata fields.
- Returns:
DisjointMergeMetadata instance.
- get_branch_summary() str[source]
Get a summary string for logging.
- Returns:
Human-readable summary of branch statistics.
- get_column_mapping_summary() List[str][source]
Get column mapping summary for logging.
- Returns:
List of strings describing each column’s model mapping.
- log_summary(logger_func) None[source]
Log merge summary using provided logger function.
- Parameters:
logger_func – Logger function (e.g., logger.info)
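To make the role of `column_mapping` concrete, here is a small sketch of how disjoint branch predictions could be reassembled into one matrix. The `merge_disjoint` function and the shape of the `branches` input are assumptions made for illustration; the actual merge logic lives in the controller and is not shown in this reference.

```python
def merge_disjoint(branches, column_mapping, n_samples):
    """Fill an (n_samples x n_columns) matrix from disjoint branches.

    branches: name -> {"sample_ids": [...], "predictions": {model: [...]}}
    column_mapping: column index -> {branch name: model name}
    Assumes column indices run from 0 to n_columns - 1.
    """
    merged = [[None] * len(column_mapping) for _ in range(n_samples)]
    for column, per_branch in column_mapping.items():
        for branch_name, model_name in per_branch.items():
            branch = branches[branch_name]
            values = branch["predictions"][model_name]
            # Each branch only fills the rows of its own samples
            for sample_id, value in zip(branch["sample_ids"], values):
                merged[sample_id][column] = value
    return merged


branches = {
    "red": {"sample_ids": [0, 2],
            "predictions": {"RF": [1.0, 3.0], "PLS": [1.1, 3.1]}},
    "blue": {"sample_ids": [1],
             "predictions": {"PLS": [2.0], "RF": [2.2]}},
}
column_mapping = {
    0: {"red": "RF", "blue": "PLS"},
    1: {"red": "PLS", "blue": "RF"},
}
merged = merge_disjoint(branches, column_mapping, n_samples=3)
```

Each output column therefore mixes models across branches, which is why the mapping is recorded per column and per branch.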
- class nirs4all.operators.data.merge.DisjointSelectionCriterion(value)[source]
Bases: Enum
Criterion for selecting top-N models in disjoint branch merge.
When branches have different model counts, we select top-N models from each branch based on this criterion.
- MSE
Select by lowest Mean Squared Error (default for regression).
- RMSE
Select by lowest Root Mean Squared Error.
- MAE
Select by lowest Mean Absolute Error.
- R2
Select by highest R² score.
- ORDER
Select first N in definition order (no ranking).
- MAE = 'mae'
- MSE = 'mse'
- ORDER = 'order'
- R2 = 'r2'
- RMSE = 'rmse'
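The ranking rules above (lowest error, highest R², or definition order) can be sketched as follows. The `select_top_n` helper and the score dictionaries are hypothetical; they only illustrate the direction of each criterion.

```python
def select_top_n(models, n, criterion="mse"):
    """Pick n models from one branch.

    models: list of dicts with a "name" key and one key per score,
    given in definition order. Hypothetical helper for illustration.
    """
    if criterion == "order":
        # No ranking: keep the first n as defined
        return models[:n]
    # R2 is better when higher; MSE, RMSE and MAE are better when lower
    higher_is_better = criterion == "r2"
    ranked = sorted(models, key=lambda m: m[criterion],
                    reverse=higher_is_better)
    return ranked[:n]


models = [  # made-up scores
    {"name": "PLS", "mse": 0.4, "r2": 0.80},
    {"name": "RF", "mse": 0.2, "r2": 0.90},
    {"name": "SVR", "mse": 0.3, "r2": 0.85},
]
by_mse = [m["name"] for m in select_top_n(models, 2, "mse")]
by_r2 = [m["name"] for m in select_top_n(models, 1, "r2")]
by_order = [m["name"] for m in select_top_n(models, 2, "order")]
```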
- class nirs4all.operators.data.merge.MergeConfig(collect_features: bool = False, feature_branches: str | List[int] = 'all', collect_predictions: bool = False, prediction_branches: str | List[int] = 'all', prediction_configs: List[BranchPredictionConfig] | None = None, model_filter: List[str] | None = None, use_proba: bool = False, include_original: bool = False, on_missing: str = 'error', on_shape_mismatch: str = 'error', unsafe: bool = False, output_as: str = 'features', source_names: List[str] | None = None, n_columns: int | None = None, select_by: str = 'mse')[source]
Bases: object
Configuration for branch merging operations.
This dataclass provides complete configuration for the MergeController, controlling what data is collected from branches and how it is combined.
- feature_branches
Which branches to collect features from.
  - "all" (default): All branches
  - List of branch indices: [0, 2] for specific branches
- prediction_branches
Legacy simple mode: which branches for predictions. Use prediction_configs for advanced per-branch control.
- prediction_configs
Advanced per-branch prediction configuration. Takes precedence over prediction_branches when set.
- Type:
List[nirs4all.operators.data.merge.BranchPredictionConfig] | None
- model_filter
Legacy: global model filter (simple mode). List of model names to include.
- Type:
List[str] | None
- include_original
Include pre-branch features in merged output. When True, original features are prepended to merged features.
- Type:
bool
- on_missing
How to handle missing branches or predictions.
  - "error" (default): Raise an error
  - "warn": Log warning and skip
  - "skip": Silent skip
- Type:
str
- on_shape_mismatch
Reserved for 3D layout feature merging. In 2D layout (default), features are flattened and concatenated horizontally, so differing feature dimensions are normal and this parameter has no effect. For future 3D layout support:
  - "error": Raise error if processings differ
  - "allow": Flatten to 2D and concatenate
  - "pad": Pad shorter processings with zeros
  - "truncate": Truncate longer to match shortest
- Type:
str
- unsafe
If True, DISABLE OOF reconstruction for predictions. ⚠️ CAUSES DATA LEAKAGE - only for rapid prototyping.
- Type:
bool
- output_as
Where to put merged output.
  - "features" (default): Single concatenated feature matrix
  - "sources": Each branch becomes a separate source
  - "dict": Keep as structured dict for multi-head models
- Type:
str
- source_names
Custom names for output sources (when output_as="sources"). If not provided, uses "branch_0", "branch_1", etc.
- Type:
List[str] | None
Example
>>> # Simple feature merge
>>> MergeConfig(collect_features=True)
>>>
>>> # Prediction merge with OOF
>>> MergeConfig(collect_predictions=True)
>>>
>>> # Mixed merge with per-branch control
>>> MergeConfig(
...     collect_predictions=True,
...     prediction_configs=[
...         BranchPredictionConfig(branch=0, select="best"),
...         BranchPredictionConfig(branch=1, aggregate="mean")
...     ],
...     collect_features=True,
...     feature_branches=[2]
... )
>>>
>>> # Unsafe mode (with warning)
>>> MergeConfig(collect_predictions=True, unsafe=True)
>>>
>>> # Disjoint branch merge with n_columns override
>>> MergeConfig(
...     collect_predictions=True,
...     n_columns=2,
...     select_by="mse"
... )
- classmethod from_dict(data: Dict[str, Any]) MergeConfig[source]
Create MergeConfig from a dictionary.
Used for loading merge configuration from manifest in prediction mode.
- Parameters:
data – Dictionary representation of merge configuration.
- Returns:
MergeConfig instance.
- get_feature_branches(n_branches: int) List[int][source]
Get list of branch indices to collect features from.
- Parameters:
n_branches – Total number of branches available.
- Returns:
List of branch indices.
- get_merge_mode() MergeMode[source]
Determine the merge mode based on configuration.
- Returns:
MergeMode enum value.
- get_prediction_configs(n_branches: int) List[BranchPredictionConfig][source]
Get prediction configurations, normalizing legacy format if needed.
Converts legacy simple mode (prediction_branches + model_filter + use_proba) to per-branch configurations for uniform processing.
- Parameters:
n_branches – Total number of branches available.
- Returns:
List of BranchPredictionConfig for each branch to collect from.
- get_selection_criterion() DisjointSelectionCriterion[source]
Get the selection criterion enum for disjoint branch merging.
- Returns:
DisjointSelectionCriterion enum value.
- get_shape_mismatch_strategy() ShapeMismatchStrategy[source]
Get the shape mismatch strategy enum.
- Returns:
ShapeMismatchStrategy enum value.
- has_per_branch_config() bool[source]
Check if using advanced per-branch prediction configuration.
- Returns:
True if prediction_configs is set and non-empty.
- prediction_configs: List[BranchPredictionConfig] | None = None
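The legacy-to-per-branch normalization described for `get_prediction_configs` might look roughly like the sketch below. It uses plain dicts in place of `BranchPredictionConfig`, and the choice to turn `model_filter` into an explicit `select` list is an assumption rather than documented behavior.

```python
def normalize_prediction_configs(prediction_branches, n_branches,
                                 model_filter=None, use_proba=False,
                                 prediction_configs=None):
    """Return one per-branch config dict per branch to collect from.

    Hypothetical helper: dicts stand in for BranchPredictionConfig.
    """
    # Per-branch configs take precedence over the legacy fields
    if prediction_configs:
        return list(prediction_configs)
    if prediction_branches == "all":
        branches = range(n_branches)
    else:
        branches = prediction_branches
    # Assumption: a global model filter becomes an explicit selection
    select = list(model_filter) if model_filter else "all"
    return [{"branch": b, "select": select, "proba": use_proba}
            for b in branches]


legacy = normalize_prediction_configs("all", 2, model_filter=["PLS"])
advanced = normalize_prediction_configs(
    "all", 3, prediction_configs=[{"branch": 0, "select": "best"}])
```

Normalizing up front means the downstream collection code only has to handle one configuration shape.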
- class nirs4all.operators.data.merge.MergeMode(value)[source]
Bases: Enum
What to merge from branches.
- FEATURES
Merge feature matrices from branches.
- PREDICTIONS
Merge model predictions from branches (with OOF reconstruction).
- ALL
Merge both features and predictions from all branches.
- ALL = 'all'
- FEATURES = 'features'
- PREDICTIONS = 'predictions'
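Out-of-fold (OOF) reconstruction, mentioned for the PREDICTIONS mode, means that each sample's merged prediction comes from a model that never saw that sample during training. The sketch below shows the general idea with a toy "predict the training mean" model; the function names and fold layout are assumptions, and it does not describe how MergeController implements the step.

```python
from statistics import fmean


def fit_mean_model(y_train):
    """Toy model: always predicts the mean of its training targets."""
    return fmean(y_train)


def reconstruct_oof(y, folds):
    """Assemble out-of-fold predictions.

    folds: list of (train_indices, validation_indices) pairs whose
    validation parts together cover every sample exactly once.
    """
    oof = [None] * len(y)
    for train_idx, val_idx in folds:
        prediction = fit_mean_model([y[i] for i in train_idx])
        # Only validation samples receive this fold's prediction
        for i in val_idx:
            oof[i] = prediction
    return oof


y = [1.0, 2.0, 3.0, 4.0]
folds = [([2, 3], [0, 1]), ([0, 1], [2, 3])]
oof = reconstruct_oof(y, folds)  # [3.5, 3.5, 1.5, 1.5]
```

Predictions made on a model's own training samples would look overly accurate to any downstream model, which is the leakage that OOF reconstruction avoids.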
- class nirs4all.operators.data.merge.SelectionStrategy(value)[source]
Bases: Enum
How to select models within a branch for prediction merging.
When a branch contains multiple models, this controls which models’ predictions are included in the merge.
- ALL
Include all models in the branch (default).
- BEST
Single best model by specified metric.
- TOP_K
Top K models by specified metric.
- EXPLICIT
Explicit list of model names.
- ALL = 'all'
- BEST = 'best'
- EXPLICIT = 'explicit'
- TOP_K = 'top_k'
- class nirs4all.operators.data.merge.ShapeMismatchStrategy(value)[source]
Bases: Enum
How to handle shape mismatches during 3D feature merging.
This strategy only applies when using 3D layout for features, where the number of processings must be aligned across branches. In 2D layout (the default), features are simply flattened and concatenated horizontally, so different feature dimensions across branches are expected and normal.
Example
Branch 0: (200 samples, 500 features) from MinMaxScaler
Branch 1: (200 samples, 4 processings, 20 features) from multi-processing
In 2D layout: concatenates to (200, 500 + 4*20 = 580), no error
In 3D layout: needs alignment strategy since processings differ
- ERROR
Raise an error on shape mismatch (default, strictest).
- ALLOW
Flatten to 2D and concatenate regardless of differences.
- PAD
Pad shorter branches with zeros to match longest processings.
- TRUNCATE
Truncate longer branches to match shortest processings.
- ALLOW = 'allow'
- ERROR = 'error'
- PAD = 'pad'
- TRUNCATE = 'truncate'
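The 2D arithmetic in the example above can be checked with a few lines of Python. The `merged_2d_shape` helper is hypothetical and works on shape tuples only; it assumes every branch keeps samples on the first axis.

```python
from math import prod


def merged_2d_shape(branch_shapes):
    """Shape after flattening each branch to 2D and concatenating
    horizontally. Hypothetical helper operating on shape tuples."""
    sample_counts = {shape[0] for shape in branch_shapes}
    if len(sample_counts) != 1:
        raise ValueError("All branches must have the same sample count")
    # Everything after the sample axis is flattened into features
    width = sum(prod(shape[1:]) for shape in branch_shapes)
    return (sample_counts.pop(), width)


shape = merged_2d_shape([(200, 500), (200, 4, 20)])  # (200, 580)
```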
- class nirs4all.operators.data.merge.SourceBranchConfig(source_pipelines: str | ~typing.Dict[str | int, ~typing.List[~typing.Any]] = <factory>, default_pipeline: ~typing.List[~typing.Any] | None = None, merge_after: bool = True, merge_strategy: str = 'concat')[source]
Bases: object
Configuration for source branching operations.
This dataclass provides configuration for the source_branch keyword, which creates per-source pipeline execution paths. Each source in a multi-source dataset gets its own independent processing pipeline.
Unlike regular branching (branch), which creates parallel paths that all process the same data, source branching assigns each source to a specific processing pipeline based on its name or index.
- source_pipelines
Mapping of source names/indices to their pipeline steps.
  - Dict[str, List]: Named sources to steps mapping
  - Dict[int, List]: Source indices to steps mapping
  - "auto": Apply same steps to all sources independently
- default_pipeline
Default pipeline for sources not explicitly specified. Applied when a source is not listed in source_pipelines. If None, unspecified sources are passed through unchanged.
- Type:
List[Any] | None
- merge_after
Whether to automatically merge sources after branching.
  - True (default): Automatically call merge_sources afterwards
  - False: Keep sources separate (user must merge manually)
- Type:
bool
- merge_strategy
Strategy for auto-merge (when merge_after=True).
  - "concat" (default): Horizontal concatenation
  - "stack": Stack along source axis
  - "dict": Keep as dictionary
- Type:
str
Example
>>> # Different preprocessing per source
>>> {"source_branch": {
...     "NIR": [SNV(), SavitzkyGolay()],
...     "markers": [VarianceThreshold(), MinMaxScaler()],
...     "Raman": [BaselineCorrection(), StandardScaler()]
... }}
>>>
>>> # Source branching with default fallback
>>> {"source_branch": {
...     "NIR": [SNV()],
...     "_default_": [MinMaxScaler()]  # Applied to other sources
... }}
>>>
>>> # Automatic same-preprocessing per source (isolates sources)
>>> {"source_branch": "auto"}
>>>
>>> # Source branching without auto-merge
>>> {"source_branch": {
...     "NIR": [SNV()],
...     "markers": [StandardScaler()],
...     "_merge_after_": False  # Disable auto-merge
... }}
- classmethod from_dict(data: Dict[str, Any]) SourceBranchConfig[source]
Create config from dictionary.
Note: This is primarily for metadata reconstruction. The actual pipeline steps must be restored from the manifest/artifacts.
- Parameters:
data – Dictionary representation.
- Returns:
SourceBranchConfig instance (with placeholder pipelines).
- get_all_source_mappings(available_sources: List[str]) Dict[str, List[Any]][source]
Get pipeline mapping for all available sources.
- Parameters:
available_sources – List of available source names.
- Returns:
Dict mapping source names to their pipeline steps.
- get_pipeline_for_source(source_name: str, source_index: int) List[Any] | None[source]
Get pipeline steps for a specific source.
- Parameters:
source_name – Name of the source.
source_index – Index of the source.
- Returns:
List of pipeline steps for this source, or None if passthrough.
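A lookup with default fallback, as described for `get_pipeline_for_source`, could be sketched like this. The function below is hypothetical, uses strings as placeholder pipeline steps, and assumes that a name match is tried before an index match; the "auto" form is not covered.

```python
def pipeline_for_source(source_pipelines, source_name, source_index,
                        default_pipeline=None):
    """Return the steps for one source, or None for passthrough.

    Hypothetical helper; the name-before-index order is an assumption.
    """
    if source_name in source_pipelines:
        return source_pipelines[source_name]
    if source_index in source_pipelines:
        return source_pipelines[source_index]
    # Unlisted sources fall back to the default (None = unchanged)
    return default_pipeline


pipelines = {"NIR": ["SNV"], 1: ["StandardScaler"]}  # placeholder steps

nir = pipeline_for_source(pipelines, "NIR", 0)
markers = pipeline_for_source(pipelines, "markers", 1)
raman = pipeline_for_source(pipelines, "Raman", 2)
raman_default = pipeline_for_source(pipelines, "Raman", 2,
                                    default_pipeline=["MinMaxScaler"])
```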
- class nirs4all.operators.data.merge.SourceIncompatibleStrategy(value)[source]
Bases: Enum
How to handle incompatible source shapes during stacking.
When using stack strategy with sources that have different feature dimensions or processing counts, this controls the resolution.
- ERROR
Raise an error on incompatible shapes (default, strictest).
- FLATTEN
Force 2D concatenation instead of stacking.
- PAD
Pad shorter sources with zeros to match longest.
- TRUNCATE
Truncate longer sources to match shortest.
- ERROR = 'error'
- FLATTEN = 'flatten'
- PAD = 'pad'
- TRUNCATE = 'truncate'
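The pad and truncate resolutions can be illustrated on nested lists. This sketch is an assumption-laden toy: `align_sources` is a hypothetical helper, padding is applied on the right of each row, and the flatten option is omitted because it replaces stacking with 2D concatenation rather than aligning widths.

```python
def align_sources(sources, on_incompatible="error"):
    """Make all sources share one feature width before stacking.

    sources: list of 2D lists (samples x features).
    Hypothetical helper for illustration only.
    """
    widths = [len(source[0]) for source in sources]
    if len(set(widths)) == 1:
        return sources
    if on_incompatible == "error":
        raise ValueError(f"Incompatible feature widths: {widths}")
    if on_incompatible == "pad":
        target = max(widths)
        # Assumption: zeros are appended on the right of each row
        return [[row + [0.0] * (target - len(row)) for row in source]
                for source in sources]
    if on_incompatible == "truncate":
        target = min(widths)
        return [[row[:target] for row in source] for source in sources]
    raise ValueError(f"Unsupported option: {on_incompatible!r}")


wide = [[1.0, 2.0, 3.0]]
narrow = [[4.0]]
padded = align_sources([wide, narrow], "pad")
truncated = align_sources([wide, narrow], "truncate")
```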
- class nirs4all.operators.data.merge.SourceMergeConfig(strategy: str = 'concat', sources: str | List[int | str] = 'all', on_incompatible: str = 'error', output_name: str = 'merged', preserve_source_info: bool = True)[source]
Bases: object
Configuration for merging multi-source dataset features.
This dataclass provides configuration for the merge_sources keyword, which combines features from multiple data sources (e.g., NIR, markers, Raman) into a unified feature space.
Unlike branch merging (merge), source merging operates on the data provenance dimension—combining features that originated from different sensors, instruments, or data modalities.
- strategy
How to combine source features.
  - "concat" (default): Horizontal concatenation (2D result)
  - "stack": Stack along new axis (3D result, requires uniform shapes)
  - "dict": Keep as structured dictionary (for multi-input models)
- Type:
str
- sources
Which sources to include.
  - "all" (default): Include all available sources
  - List of source indices: [0, 1] for specific sources
  - List of source names: ["NIR", "markers"] for named sources
- on_incompatible
How to handle incompatible shapes (for stack strategy).
  - "error" (default): Raise error if shapes don't match
  - "flatten": Fall back to 2D concat
  - "pad": Pad shorter with zeros
  - "truncate": Truncate longer to match shortest
- Type:
str
Example
>>> # Simple concatenation (default)
>>> {"merge_sources": "concat"}
>>>
>>> # Stack for 3D models (requires same feature count per source)
>>> {"merge_sources": {"strategy": "stack"}}
>>>
>>> # Selective sources with fallback on shape mismatch
>>> {"merge_sources": {
...     "strategy": "stack",
...     "sources": ["NIR", "MIR"],
...     "on_incompatible": "flatten"
... }}
>>>
>>> # Dict output for multi-head models
>>> {"merge_sources": {"strategy": "dict"}}
- classmethod from_dict(data: Dict[str, Any]) SourceMergeConfig[source]
Create config from dictionary.
- Parameters:
data – Dictionary representation.
- Returns:
SourceMergeConfig instance.
- get_incompatible_strategy() SourceIncompatibleStrategy[source]
Get the incompatible handling strategy as an enum.
- Returns:
SourceIncompatibleStrategy enum value.
- get_source_indices(available_sources: List[str]) List[int][source]
Resolve source specification to indices.
- Parameters:
available_sources – List of available source names.
- Returns:
List of source indices to include.
- Raises:
ValueError – If a specified source is not found.
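Resolving a mixed list of names and indices to indices, with a `ValueError` for unknown sources, might be sketched as below. `resolve_source_indices` is a hypothetical stand-in, not the method's actual body; treating an out-of-range index as "not found" is an assumption.

```python
def resolve_source_indices(sources, available_sources):
    """Turn "all", indices, or names into a list of source indices.

    Hypothetical helper for illustration only.
    """
    if sources == "all":
        return list(range(len(available_sources)))
    indices = []
    for spec in sources:
        if isinstance(spec, int):
            # Assumption: out-of-range indices count as missing sources
            if not 0 <= spec < len(available_sources):
                raise ValueError(f"Source index out of range: {spec}")
            indices.append(spec)
        elif spec in available_sources:
            indices.append(available_sources.index(spec))
        else:
            raise ValueError(f"Source not found: {spec!r}")
    return indices


available = ["NIR", "markers", "Raman"]
everything = resolve_source_indices("all", available)
mixed = resolve_source_indices(["Raman", 0], available)
```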
- get_strategy() SourceMergeStrategy[source]
Get the merge strategy as an enum.
- Returns:
SourceMergeStrategy enum value.
- class nirs4all.operators.data.merge.SourceMergeStrategy(value)[source]
Bases: Enum
How to combine features from multiple data sources.
Used by the merge_sources keyword to control how multi-source datasets are unified into a single feature space.
- CONCAT
Horizontal concatenation of all source features (default). Results in 2D array: (samples, sum_of_all_source_features). Different feature dimensions per source are expected.
- STACK
Stack sources along a new axis to create 3D tensor. Results in 3D array: (samples, n_sources, n_features). Requires all sources to have the same feature dimension.
- DICT
Keep sources as a structured dictionary. Results in Dict[str, ndarray] for multi-input models. Each source is accessible by name.
- CONCAT = 'concat'
- DICT = 'dict'
- STACK = 'stack'
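The three result shapes can be compared with a short sketch on nested lists. `combine_sources` is a hypothetical helper written for this illustration; it assumes all sources share the same sample order and uses lists where the library would return arrays.

```python
def combine_sources(sources, strategy="concat"):
    """Combine per-source feature blocks.

    sources: dict mapping source name -> 2D list (samples x features).
    Hypothetical helper for illustration only.
    """
    if strategy == "dict":
        # Keep each source accessible by name
        return dict(sources)
    blocks = list(sources.values())
    if strategy == "concat":
        # (samples, sum of feature counts)
        return [[value for row in rows for value in row]
                for rows in zip(*blocks)]
    if strategy == "stack":
        widths = {len(block[0]) for block in blocks}
        if len(widths) != 1:
            raise ValueError("stack requires equal feature counts")
        # (samples, n_sources, n_features)
        return [list(rows) for rows in zip(*blocks)]
    raise ValueError(f"Unsupported strategy: {strategy!r}")


uneven = {"NIR": [[1, 2], [3, 4]], "markers": [[5], [6]]}
even = {"NIR": [[1, 2], [3, 4]], "MIR": [[5, 6], [7, 8]]}

concatenated = combine_sources(uneven, "concat")
stacked = combine_sources(even, "stack")
```

Concatenation tolerates sources of different widths, while stacking raises unless every source has the same feature count.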