nirs4all.operators.data.merge module

Merge operator configuration for branch and source merging.

This module provides configuration dataclasses and enums for the MergeController, which handles combining branch outputs (features and/or predictions) and exiting branch mode.

The merge operator is the core primitive for all branch combination operations. It provides:

  • Feature merging from branches (horizontal concatenation)

  • Prediction merging with OOF reconstruction (data leakage prevention)

  • Per-branch model selection and aggregation strategies

  • Mixed merging (features from some branches, predictions from others)
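Why OOF reconstruction prevents leakage can be shown with a small, library-independent sketch: each sample's prediction comes from a model fitted on the other folds only, so a merged prediction feature never encodes that sample's own target. The function name `oof_predictions`, the toy mean-of-targets "model", and the interleaved fold split are illustrative assumptions, not a description of how the MergeController actually reconstructs OOF predictions.

```python
from statistics import mean
from typing import List


def oof_predictions(y: List[float], n_folds: int = 3) -> List[float]:
    """Out-of-fold predictions from a toy model that predicts the training mean."""
    n = len(y)
    oof = [0.0] * n
    for fold in range(n_folds):
        # Interleaved fold assignment: sample i belongs to fold i % n_folds.
        val_idx = [i for i in range(n) if i % n_folds == fold]
        train_idx = [i for i in range(n) if i % n_folds != fold]
        # "Fit" on the training folds only; held-out targets are never seen.
        prediction = mean(y[i] for i in train_idx)
        # Fill in predictions for the held-out samples of this fold.
        for i in val_idx:
            oof[i] = prediction
    return oof


print(oof_predictions([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))  # [4.0, 3.5, 3.0, 4.0, 3.5, 3.0]
```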

Example

>>> # Simple feature merge
>>> {"merge": "features"}
>>>
>>> # Prediction merge with OOF safety
>>> {"merge": "predictions"}
>>>
>>> # Mixed merge with per-branch control
>>> {"merge": {
...     "predictions": [
...         {"branch": 0, "select": "best", "metric": "rmse"},
...         {"branch": 1, "aggregate": "mean"}
...     ],
...     "features": [2]
... }}
class nirs4all.operators.data.merge.AggregationStrategy(value)

Bases: Enum

How to aggregate predictions from selected models within a branch.

After model selection, this controls how the selected predictions are combined into features for the merged output.

SEPARATE

Keep each model’s predictions as separate features (default). Results in N features (one per selected model).

MEAN

Simple average of all selected model predictions. Results in 1 feature.

WEIGHTED_MEAN

Weighted average by validation score. Results in 1 feature.

PROBA_MEAN

Average class probabilities (classification only). Results in K features (one per class).

MEAN = 'mean'
PROBA_MEAN = 'proba_mean'
SEPARATE = 'separate'
WEIGHTED_MEAN = 'weighted_mean'
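The four strategies differ mainly in how many output columns they produce. The following sketch shows this on plain Python lists; the helpers `aggregate` and `proba_mean` are hypothetical, and the weighting rule for weighted_mean (weights proportional to a higher-is-better validation score) is an assumption, not the library's exact formula.

```python
from typing import List, Optional

# preds[m][i] is model m's prediction for sample i.
Matrix = List[List[float]]


def aggregate(preds: Matrix, strategy: str = "separate",
              scores: Optional[List[float]] = None) -> Matrix:
    """Combine per-model predictions into a samples x columns matrix."""
    n_samples = len(preds[0])
    if strategy == "separate":
        # N columns: one per selected model.
        return [[model[i] for model in preds] for i in range(n_samples)]
    if strategy == "mean":
        # 1 column: plain average over models.
        return [[sum(model[i] for model in preds) / len(preds)]
                for i in range(n_samples)]
    if strategy == "weighted_mean":
        if scores is None or len(scores) != len(preds):
            raise ValueError("weighted_mean needs one score per model")
        # 1 column: assumed weights proportional to a higher-is-better score.
        total = sum(scores)
        weights = [s / total for s in scores]
        return [[sum(w * model[i] for w, model in zip(weights, preds))]
                for i in range(n_samples)]
    raise ValueError(f"unsupported strategy: {strategy!r}")


def proba_mean(probas: List[Matrix]) -> Matrix:
    """K columns: average class probabilities, where probas[m][i][k]."""
    n_models = len(probas)
    return [[sum(p[i][k] for p in probas) / n_models
             for k in range(len(probas[0][i]))]
            for i in range(len(probas[0]))]


preds = [[1.0, 2.0], [3.0, 4.0]]  # 2 models, 2 samples
print(aggregate(preds, "separate"))                          # [[1.0, 3.0], [2.0, 4.0]]
print(aggregate(preds, "mean"))                              # [[2.0], [3.0]]
print(aggregate(preds, "weighted_mean", scores=[1.0, 3.0]))  # [[2.5], [3.5]]
```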
class nirs4all.operators.data.merge.BranchPredictionConfig(branch: int | str, select: str | Dict[str, Any] | List[str] = 'all', metric: str | None = None, aggregate: str = 'separate', weight_metric: str | None = None, proba: bool = False, sources: str | List[int | str] = 'all')

Bases: object

Configuration for prediction collection from a single branch.

This dataclass specifies how to collect and process predictions from a specific branch during merge operations.

branch

Branch index or name to collect from.

Type:

int | str

select

Model selection strategy:

  • "all" (default): All models in the branch

  • "best": Single best model by metric

  • {"top_k": N}: Top N models by metric

  • ["ModelA", "ModelB"]: Explicit model names

Type:

str | Dict[str, Any] | List[str]

metric

Metric for selection (rmse, mae, r2, accuracy, f1). Default is task-appropriate (rmse for regression, accuracy for classification).

Type:

str | None

aggregate

How to combine predictions from selected models:

  • "separate" (default): Each model is a separate feature

  • "mean": Simple average of predictions

  • "weighted_mean": Weight by validation score

  • "proba_mean": Average class probabilities (classification)

Type:

str

weight_metric

Metric for weighted aggregation (default: same as metric).

Type:

str | None

proba

Use class probabilities instead of predictions (classification only).

Type:

bool

sources

Source filter for multi-source datasets:

  • "all" (default): Include all sources

  • List of source indices or names

Type:

str | List[int | str]

Example

>>> # Best model from branch 0 by RMSE
>>> BranchPredictionConfig(branch=0, select="best", metric="rmse")
>>>
>>> # Top 3 models from branch 1, averaged
>>> BranchPredictionConfig(
...     branch=1,
...     select={"top_k": 3},
...     metric="r2",
...     aggregate="mean"
... )
>>>
>>> # Explicit models with weighted average
>>> BranchPredictionConfig(
...     branch="spectral_path",
...     select=["PLS", "RF"],
...     aggregate="weighted_mean",
...     weight_metric="r2"
... )
__post_init__()

Validate configuration after initialization.

aggregate: str = 'separate'
branch: int | str
get_aggregation_strategy() → AggregationStrategy

Get the aggregation strategy enum for this configuration.

Returns:

AggregationStrategy enum value based on aggregate field.

get_selection_strategy() → SelectionStrategy

Get the selection strategy enum for this configuration.

Returns:

SelectionStrategy enum value based on select field.
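As a rough illustration, the documented select forms could map onto SelectionStrategy values and then onto a list of chosen models as sketched below. The helpers `selection_strategy` and `select_models`, the {model_name: score} input, and the assumption that rmse/mae/mse are lower-is-better while other metrics are higher-is-better are all illustrative.

```python
from typing import Any, Dict, List, Union

Select = Union[str, Dict[str, Any], List[str]]

# Assumption: error metrics are lower-is-better; r2, accuracy, f1 are higher-is-better.
LOWER_IS_BETTER = {"rmse", "mae", "mse"}


def selection_strategy(select: Select) -> str:
    """Map a documented `select` value to a SelectionStrategy value string."""
    if isinstance(select, list):
        return "explicit"
    if isinstance(select, dict):
        if "top_k" in select:
            return "top_k"
        raise ValueError(f"unsupported select dict: {select!r}")
    if select in ("all", "best"):
        return select
    raise ValueError(f"invalid select value: {select!r}")


def select_models(scores: Dict[str, float], select: Select, metric: str) -> List[str]:
    """Choose model names from {model_name: validation_score}."""
    strategy = selection_strategy(select)
    if strategy == "all":
        return list(scores)
    if strategy == "explicit":
        # Names not present in the branch are skipped in this sketch.
        return [name for name in select if name in scores]
    # Sort best-first according to the metric's direction.
    ranked = sorted(scores, key=lambda name: scores[name],
                    reverse=metric not in LOWER_IS_BETTER)
    k = 1 if strategy == "best" else int(select["top_k"])
    return ranked[:k]


rmse = {"PLS": 0.30, "RF": 0.25, "SVR": 0.40}
print(select_models(rmse, "best", "rmse"))          # ['RF']
print(select_models(rmse, {"top_k": 2}, "rmse"))    # ['RF', 'PLS']
print(select_models(rmse, ["SVR", "XGB"], "rmse"))  # ['SVR']
```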

metric: str | None = None
proba: bool = False
select: str | Dict[str, Any] | List[str] = 'all'
sources: str | List[int | str] = 'all'
weight_metric: str | None = None
class nirs4all.operators.data.merge.BranchType(value)

Bases: Enum

Type of branch based on sample handling.

COPY

All branches see all samples (default branching behavior).

METADATA_PARTITIONER

Branches partition samples by metadata column.

SAMPLE_PARTITIONER

Branches partition samples by filter (e.g., outlier).

COPY = 'copy'
METADATA_PARTITIONER = 'metadata_partitioner'
SAMPLE_PARTITIONER = 'sample_partitioner'
class nirs4all.operators.data.merge.DisjointBranchInfo(n_samples: int, sample_ids: List[int], n_models_original: int = 0, n_models_selected: int = 0, selected_models: List[Dict[str, Any]] = <factory>, dropped_models: List[Dict[str, Any]] = <factory>)

Bases: object

Information about a single branch in a disjoint merge.

Captures per-branch statistics and model selection details for comprehensive merge metadata.

n_samples

Number of samples in this branch partition.

Type:

int

sample_ids

List of sample indices belonging to this branch.

Type:

List[int]

n_models_original

Original number of models in the branch.

Type:

int

n_models_selected

Number of models selected for merge.

Type:

int

selected_models

List of selected model details with name, score, column.

Type:

List[Dict[str, Any]]

dropped_models

List of dropped model details with name, score.

Type:

List[Dict[str, Any]]

dropped_models: List[Dict[str, Any]]
n_models_original: int = 0
n_models_selected: int = 0
n_samples: int
sample_ids: List[int]
selected_models: List[Dict[str, Any]]
to_dict() → Dict[str, Any]

Convert to dictionary for serialization.

class nirs4all.operators.data.merge.DisjointMergeMetadata(merge_type: str = 'disjoint_samples', n_columns: int = 0, select_by: str = 'mse', branches: Dict[str, DisjointBranchInfo] = <factory>, column_mapping: Dict[int, Dict[str, str]] = <factory>, is_heterogeneous: bool = False, feature_dim: int | None = None)

Bases: object

Complete metadata for a disjoint sample branch merge.

This dataclass captures all information about a disjoint merge operation for logging, debugging, and downstream use. Matches the specification in docs/reports/disjoint_sample_branch_merging.md Section 6.

merge_type

Always "disjoint_samples" for disjoint merges.

Type:

str

n_columns

Number of output columns (prediction features).

Type:

int

select_by

Selection criterion used (mse, rmse, mae, r2, order).

Type:

str

branches

Per-branch information as Dict[branch_name, DisjointBranchInfo].

Type:

Dict[str, nirs4all.operators.data.merge.DisjointBranchInfo]

column_mapping

Maps output column index to per-branch model names. Example: {0: {"red": "RF", "blue": "PLS"}, 1: {"red": "PLS", "blue": "RF"}}

Type:

Dict[int, Dict[str, str]]

is_heterogeneous

True if different branches have different models per column.

Type:

bool

feature_dim

Feature dimension (for feature merges).

Type:

int | None

Example

>>> from nirs4all.operators.data.merge import DisjointBranchInfo, DisjointMergeMetadata
>>> metadata = DisjointMergeMetadata(
...     merge_type="disjoint_samples",
...     n_columns=2,
...     select_by="mse",
...     branches={
...         "red": DisjointBranchInfo(n_samples=50, sample_ids=list(range(50))),
...         "blue": DisjointBranchInfo(n_samples=100, sample_ids=list(range(50, 150))),
...     },
...     column_mapping={
...         0: {"red": "RF", "blue": "PLS"},
...         1: {"red": "PLS", "blue": "RF"},
...     },
...     is_heterogeneous=True,  # branches use different models per column
... )
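One way to read the is_heterogeneous flag is as a per-column check on column_mapping: a column is heterogeneous when the branches feed it from different models. The helpers below are hypothetical and only sketch that check; they are not part of the documented API.

```python
from typing import Dict, List

ColumnMapping = Dict[int, Dict[str, str]]


def heterogeneous_columns(column_mapping: ColumnMapping) -> List[int]:
    """Output columns whose model name differs between branches."""
    return [col for col, per_branch in sorted(column_mapping.items())
            if len(set(per_branch.values())) > 1]


def is_heterogeneous(column_mapping: ColumnMapping) -> bool:
    """True if at least one column mixes models across branches."""
    return bool(heterogeneous_columns(column_mapping))


mapping = {0: {"red": "RF", "blue": "PLS"}, 1: {"red": "PLS", "blue": "RF"}}
print(heterogeneous_columns(mapping))  # [0, 1]
print(is_heterogeneous(mapping))       # True
```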
branches: Dict[str, DisjointBranchInfo]
column_mapping: Dict[int, Dict[str, str]]
feature_dim: int | None = None
classmethod from_dict(data: Dict[str, Any]) → DisjointMergeMetadata

Create from dictionary representation.

Parameters:

data – Dictionary with metadata fields.

Returns:

DisjointMergeMetadata instance.

get_branch_summary() → str

Get a summary string for logging.

Returns:

Human-readable summary of branch statistics.

get_column_mapping_summary() → List[str]

Get column mapping summary for logging.

Returns:

List of strings describing each column’s model mapping.

is_heterogeneous: bool = False
log_summary(logger_func) → None

Log merge summary using provided logger function.

Parameters:

logger_func – Logger function (e.g., logger.info)

log_warnings(logger_warning_func) → None

Log warnings for heterogeneous columns and dropped models.

Parameters:

logger_warning_func – Logger warning function (e.g., logger.warning)

merge_type: str = 'disjoint_samples'
n_columns: int = 0
select_by: str = 'mse'
to_dict() → Dict[str, Any]

Convert to dictionary for serialization/logging.

Returns:

Dictionary representation suitable for YAML/JSON serialization.

class nirs4all.operators.data.merge.DisjointSelectionCriterion(value)

Bases: Enum

Criterion for selecting top-N models in disjoint branch merge.

When branches have different model counts, the top-N models are selected from each branch based on this criterion.

MSE

Select by lowest Mean Squared Error (default for regression).

RMSE

Select by lowest Root Mean Squared Error.

MAE

Select by lowest Mean Absolute Error.

R2

Select by highest R² score.

ORDER

Select first N in definition order (no ranking).

MAE = 'mae'
MSE = 'mse'
ORDER = 'order'
R2 = 'r2'
RMSE = 'rmse'
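A sketch of top-N selection within one branch under each criterion, following the directions given above (lowest MSE/RMSE/MAE, highest R², definition order for ORDER). The `top_n_models` helper and its list-of-dicts input format are assumptions for illustration.

```python
from typing import Any, Dict, List

# Ranking direction per criterion, taken from the enum descriptions above.
HIGHER_IS_BETTER = {"mse": False, "rmse": False, "mae": False, "r2": True}


def top_n_models(models: List[Dict[str, Any]], n: int, criterion: str = "mse") -> List[str]:
    """Return the names of n models; `models` is given in definition order."""
    if criterion == "order":
        ranked = models  # no ranking: keep definition order
    elif criterion in HIGHER_IS_BETTER:
        ranked = sorted(models, key=lambda m: m[criterion],
                        reverse=HIGHER_IS_BETTER[criterion])
    else:
        raise ValueError(f"unknown criterion: {criterion!r}")
    return [m["name"] for m in ranked[:n]]


red = [
    {"name": "PLS", "mse": 0.40, "r2": 0.70},
    {"name": "RF", "mse": 0.25, "r2": 0.85},
    {"name": "SVR", "mse": 0.30, "r2": 0.60},
]
print(top_n_models(red, 2, "mse"))    # ['RF', 'SVR']
print(top_n_models(red, 2, "r2"))     # ['RF', 'PLS']
print(top_n_models(red, 2, "order"))  # ['PLS', 'RF']
```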
class nirs4all.operators.data.merge.MergeConfig(collect_features: bool = False, feature_branches: str | List[int] = 'all', collect_predictions: bool = False, prediction_branches: str | List[int] = 'all', prediction_configs: List[BranchPredictionConfig] | None = None, model_filter: List[str] | None = None, use_proba: bool = False, include_original: bool = False, on_missing: str = 'error', on_shape_mismatch: str = 'error', unsafe: bool = False, output_as: str = 'features', source_names: List[str] | None = None, n_columns: int | None = None, select_by: str = 'mse')

Bases: object

Configuration for branch merging operations.

This dataclass provides complete configuration for the MergeController, controlling what data is collected from branches and how it is combined.

collect_features

Whether to collect features from branches.

Type:

bool

feature_branches

Which branches to collect features from:

  • "all" (default): All branches

  • List of branch indices: [0, 2] for specific branches

Type:

str | List[int]

collect_predictions

Whether to collect predictions from branches.

Type:

bool

prediction_branches

Legacy simple mode: which branches to collect predictions from. Use prediction_configs for advanced per-branch control.

Type:

str | List[int]

prediction_configs

Advanced per-branch prediction configuration. Takes precedence over prediction_branches when set.

Type:

List[nirs4all.operators.data.merge.BranchPredictionConfig] | None

model_filter

Legacy: global model filter (simple mode). List of model names to include.

Type:

List[str] | None

use_proba

Legacy: global proba setting for classification.

Type:

bool

include_original

Include pre-branch features in merged output. When True, original features are prepended to merged features.

Type:

bool

on_missing

How to handle missing branches or predictions:

  • "error" (default): Raise an error

  • "warn": Log a warning and skip

  • "skip": Skip silently

Type:

str
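The three on_missing policies can be sketched as a loop over the requested branches. The `collect_branches` helper, its branch-index-to-output dictionary, and the choice of KeyError for the "error" policy are hypothetical.

```python
import logging
from typing import Dict, List

logger = logging.getLogger("merge_sketch")


def collect_branches(available: Dict[int, List[float]], requested: List[int],
                     on_missing: str = "error") -> Dict[int, List[float]]:
    """Gather requested branch outputs according to an on_missing policy."""
    if on_missing not in ("error", "warn", "skip"):
        raise ValueError(f"invalid on_missing: {on_missing!r}")
    collected: Dict[int, List[float]] = {}
    for idx in requested:
        if idx in available:
            collected[idx] = available[idx]
        elif on_missing == "error":
            raise KeyError(f"branch {idx} produced no output")
        elif on_missing == "warn":
            logger.warning("branch %d produced no output; skipping", idx)
        # on_missing == "skip": drop the branch silently
    return collected


outputs = {0: [0.1, 0.2], 2: [0.3, 0.4]}
print(collect_branches(outputs, [0, 1, 2], on_missing="skip"))  # {0: [0.1, 0.2], 2: [0.3, 0.4]}
```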

on_shape_mismatch

Reserved for 3D layout feature merging. In 2D layout (the default), features are flattened and concatenated horizontally, so differing feature dimensions are normal and this parameter has no effect. For future 3D layout support:

  • "error": Raise an error if processings differ

  • "allow": Flatten to 2D and concatenate

  • "pad": Pad shorter processings with zeros

  • "truncate": Truncate longer processings to match the shortest

Type:

str

unsafe

If True, disables OOF reconstruction for predictions. ⚠️ CAUSES DATA LEAKAGE: use only for rapid prototyping.

Type:

bool

output_as

Where to put the merged output:

  • "features" (default): Single concatenated feature matrix

  • "sources": Each branch becomes a separate source

  • "dict": Keep as a structured dict for multi-head models

Type:

str

source_names

Custom names for output sources (when output_as="sources"). If not provided, "branch_0", "branch_1", etc. are used.

Type:

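List[str] | None

n_columns

Number of output prediction columns for disjoint branch merges (optional override).

Type:

int | None

select_by

Criterion for selecting the top-N models from each branch in disjoint branch merges: mse (default), rmse, mae, r2, or order. See DisjointSelectionCriterion.

Type:

str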

Example

>>> # Simple feature merge
>>> MergeConfig(collect_features=True)
>>>
>>> # Prediction merge with OOF
>>> MergeConfig(collect_predictions=True)
>>>
>>> # Mixed merge with per-branch control
>>> MergeConfig(
...     collect_predictions=True,
...     prediction_configs=[
...         BranchPredictionConfig(branch=0, select="best"),
...         BranchPredictionConfig(branch=1, aggregate="mean")
...     ],
...     collect_features=True,
...     feature_branches=[2]
... )
>>>
>>> # Unsafe mode (with warning)
>>> MergeConfig(collect_predictions=True, unsafe=True)
>>>
>>> # Disjoint branch merge with n_columns override
>>> MergeConfig(
...     collect_predictions=True,
...     n_columns=2,
...     select_by="mse"
... )
__post_init__()

Validate configuration after initialization.

collect_features: bool = False
collect_predictions: bool = False
feature_branches: str | List[int] = 'all'
classmethod from_dict(data: Dict[str, Any]) → MergeConfig

Create MergeConfig from a dictionary.

Used for loading merge configuration from manifest in prediction mode.

Parameters:

data – Dictionary representation of merge configuration.

Returns:

MergeConfig instance.

get_feature_branches(n_branches: int) → List[int]

Get list of branch indices to collect features from.

Parameters:

n_branches – Total number of branches available.

Returns:

List of branch indices.

get_merge_mode() → MergeMode

Determine the merge mode based on configuration.

Returns:

MergeMode enum value.
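A plausible reading of get_merge_mode() is a mapping from the two collect_* flags to a MergeMode value. The sketch below re-declares the enum values documented in this module; the mapping itself, and raising when neither flag is set, are assumptions rather than confirmed behavior.

```python
from enum import Enum


class MergeMode(Enum):
    # Values copied from the MergeMode enum documented in this module.
    FEATURES = "features"
    PREDICTIONS = "predictions"
    ALL = "all"


def merge_mode(collect_features: bool, collect_predictions: bool) -> MergeMode:
    """Assumed mapping from the collect_* flags to a merge mode."""
    if collect_features and collect_predictions:
        return MergeMode.ALL
    if collect_predictions:
        return MergeMode.PREDICTIONS
    if collect_features:
        return MergeMode.FEATURES
    # This sketch treats "nothing to collect" as a configuration error.
    raise ValueError("enable collect_features and/or collect_predictions")


print(merge_mode(True, True))   # MergeMode.ALL
print(merge_mode(False, True))  # MergeMode.PREDICTIONS
```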

get_prediction_configs(n_branches: int) → List[BranchPredictionConfig]

Get prediction configurations, normalizing legacy format if needed.

Converts legacy simple mode (prediction_branches + model_filter + use_proba) to per-branch configurations for uniform processing.

Parameters:

n_branches – Total number of branches available.

Returns:

List of BranchPredictionConfig for each branch to collect from.
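The legacy-to-per-branch normalization can be sketched as follows. `BranchPrediction` is a hypothetical stand-in for BranchPredictionConfig, and mapping model_filter to an explicit select list and use_proba to proba is an assumption based on the field descriptions above.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class BranchPrediction:
    """Hypothetical stand-in for BranchPredictionConfig (subset of fields)."""
    branch: int
    select: Union[str, List[str]] = "all"
    proba: bool = False


def normalize_legacy(prediction_branches: Union[str, List[int]], n_branches: int,
                     model_filter: Optional[List[str]] = None,
                     use_proba: bool = False) -> List[BranchPrediction]:
    """Expand global legacy settings into one config per branch."""
    if prediction_branches == "all":
        branches = list(range(n_branches))
    else:
        branches = list(prediction_branches)
    return [
        BranchPrediction(
            branch=b,
            # Assumed: a global model_filter becomes an explicit model list
            # (copied per branch so configs do not share one list object).
            select=list(model_filter) if model_filter else "all",
            proba=use_proba,
        )
        for b in branches
    ]


print(normalize_legacy([0, 2], n_branches=3, model_filter=["PLS"], use_proba=True))
```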

get_selection_criterion() → DisjointSelectionCriterion

Get the selection criterion enum for disjoint branch merging.

Returns:

DisjointSelectionCriterion enum value.

get_shape_mismatch_strategy() → ShapeMismatchStrategy

Get the shape mismatch strategy enum.

Returns:

ShapeMismatchStrategy enum value.

has_per_branch_config() → bool

Check if using advanced per-branch prediction configuration.

Returns:

True if prediction_configs is set and non-empty.

include_original: bool = False
model_filter: List[str] | None = None
n_columns: int | None = None
on_missing: str = 'error'
on_shape_mismatch: str = 'error'
output_as: str = 'features'
prediction_branches: str | List[int] = 'all'
prediction_configs: List[BranchPredictionConfig] | None = None
select_by: str = 'mse'
source_names: List[str] | None = None
to_dict() → Dict[str, Any]

Serialize merge configuration to a dictionary.

Used for saving merge configuration to manifest for reproducibility in prediction mode and bundle export.

Returns:

Dictionary representation suitable for YAML/JSON serialization.

unsafe: bool = False
use_proba: bool = False
class nirs4all.operators.data.merge.MergeMode(value)

Bases: Enum

What to merge from branches.

FEATURES

Merge feature matrices from branches.

PREDICTIONS

Merge model predictions from branches (with OOF reconstruction).

ALL

Merge both features and predictions from all branches.

ALL = 'all'
FEATURES = 'features'
PREDICTIONS = 'predictions'
class nirs4all.operators.data.merge.SelectionStrategy(value)

Bases: Enum

How to select models within a branch for prediction merging.

When a branch contains multiple models, this controls which models’ predictions are included in the merge.

ALL

Include all models in the branch (default).

BEST

Single best model by specified metric.

TOP_K

Top K models by specified metric.

EXPLICIT

Explicit list of model names.

ALL = 'all'
BEST = 'best'
EXPLICIT = 'explicit'
TOP_K = 'top_k'
class nirs4all.operators.data.merge.ShapeMismatchStrategy(value)

Bases: Enum

How to handle shape mismatches during 3D feature merging.

This strategy only applies when using 3D layout for features, where the number of processings must be aligned across branches. In 2D layout (the default), features are simply flattened and concatenated horizontally, so differing feature dimensions across branches are expected and normal.

Example

  • Branch 0: (200 samples, 500 features) from MinMaxScaler

  • Branch 1: (200 samples, 4 processings, 20 features) from multi-processing

In 2D layout: concatenates to (200, 500 + 4*20 = 580), no error.

In 3D layout: needs an alignment strategy, since the processing counts differ.
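The 2D arithmetic in the example above (500 + 4*20 = 580) amounts to flattening every non-sample axis and summing the widths. A minimal shape-only sketch, with hypothetical helper names:

```python
from math import prod
from typing import List, Tuple


def flattened_width(shape: Tuple[int, ...]) -> int:
    """Number of 2D columns once all per-sample axes are flattened."""
    return prod(shape[1:])


def concat_2d_shape(shapes: List[Tuple[int, ...]]) -> Tuple[int, int]:
    """Shape of the horizontal concatenation of flattened branch outputs."""
    n_samples = {s[0] for s in shapes}
    if len(n_samples) != 1:
        raise ValueError("branches must share the sample dimension")
    return (n_samples.pop(), sum(flattened_width(s) for s in shapes))


print(concat_2d_shape([(200, 500), (200, 4, 20)]))  # (200, 580)
```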

ERROR

Raise an error on shape mismatch (default, strictest).

ALLOW

Flatten to 2D and concatenate regardless of differences.

PAD

Pad shorter branches with zeros to match the largest processing count.

TRUNCATE

Truncate longer branches to match the smallest processing count.

ALLOW = 'allow'
ERROR = 'error'
PAD = 'pad'
TRUNCATE = 'truncate'
class nirs4all.operators.data.merge.SourceBranchConfig(source_pipelines: str | Dict[str | int, List[Any]] = <factory>, default_pipeline: List[Any] | None = None, merge_after: bool = True, merge_strategy: str = 'concat')

Bases: object

Configuration for source branching operations.

This dataclass provides configuration for the source_branch keyword, which creates per-source pipeline execution paths. Each source in a multi-source dataset gets its own independent processing pipeline.

Unlike regular branching (branch), which creates parallel paths that all process the same data, source branching assigns each source to a specific processing pipeline based on its name or index.

source_pipelines

Mapping of source names/indices to their pipeline steps:

  • Dict[str, List]: Named sources to steps mapping

  • Dict[int, List]: Source indices to steps mapping

  • "auto": Apply the same steps to all sources independently

Type:

str | Dict[str | int, List[Any]]

default_pipeline

Default pipeline for sources not explicitly specified. Applied when a source is not listed in source_pipelines. If None, unspecified sources are passed through unchanged.

Type:

List[Any] | None

merge_after

Whether to automatically merge sources after branching:

  • True (default): Automatically call merge_sources afterwards

  • False: Keep sources separate (user must merge manually)

Type:

bool

merge_strategy

Strategy for auto-merge (when merge_after=True):

  • "concat" (default): Horizontal concatenation

  • "stack": Stack along the source axis

  • "dict": Keep as a dictionary

Type:

str

Example

>>> # Different preprocessing per source
>>> {"source_branch": {
...     "NIR": [SNV(), SavitzkyGolay()],
...     "markers": [VarianceThreshold(), MinMaxScaler()],
...     "Raman": [BaselineCorrection(), StandardScaler()]
... }}
>>>
>>> # Source branching with default fallback
>>> {"source_branch": {
...     "NIR": [SNV()],
...     "_default_": [MinMaxScaler()]  # Applied to other sources
... }}
>>>
>>> # Automatic same-preprocessing per source (isolates sources)
>>> {"source_branch": "auto"}
>>>
>>> # Source branching without auto-merge
>>> {"source_branch": {
...     "NIR": [SNV()],
...     "markers": [StandardScaler()],
...     "_merge_after_": False  # Disable auto-merge
... }}
__post_init__()

Validate configuration after initialization.

default_pipeline: List[Any] | None = None
classmethod from_dict(data: Dict[str, Any]) → SourceBranchConfig

Create config from dictionary.

Note: This is primarily for metadata reconstruction. The actual pipeline steps must be restored from the manifest/artifacts.

Parameters:

data – Dictionary representation.

Returns:

SourceBranchConfig instance (with placeholder pipelines).

get_all_source_mappings(available_sources: List[str]) → Dict[str, List[Any]]

Get pipeline mapping for all available sources.

Parameters:

available_sources – List of available source names.

Returns:

Dict mapping source names to their pipeline steps.

get_pipeline_for_source(source_name: str, source_index: int) → List[Any] | None

Get pipeline steps for a specific source.

Parameters:
  • source_name – Name of the source.

  • source_index – Index of the source.

Returns:

List of pipeline steps for this source, or None if passthrough.
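A sketch of how a per-source lookup might resolve. The lookup order (name, then index, then default_pipeline, else passthrough) is assumed rather than confirmed, `pipeline_for_source` is a hypothetical helper, strings stand in for the actual step objects, and "auto" mode is not covered.

```python
from typing import Any, Dict, List, Optional, Union


def pipeline_for_source(source_pipelines: Dict[Union[str, int], List[Any]],
                        source_name: str, source_index: int,
                        default_pipeline: Optional[List[Any]] = None) -> Optional[List[Any]]:
    """Resolve steps for one source; None means pass the source through unchanged."""
    # Assumed lookup order: explicit name, then index, then the default pipeline.
    if source_name in source_pipelines:
        return source_pipelines[source_name]
    if source_index in source_pipelines:
        return source_pipelines[source_index]
    return default_pipeline


pipelines = {"NIR": ["SNV", "SavitzkyGolay"], 1: ["StandardScaler"]}
print(pipeline_for_source(pipelines, "NIR", 0))      # ['SNV', 'SavitzkyGolay']
print(pipeline_for_source(pipelines, "markers", 1))  # ['StandardScaler']
print(pipeline_for_source(pipelines, "Raman", 2))    # None
```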

is_auto_mode() → bool

Check if using automatic source branching.

Returns:

True if source_pipelines is "auto".

merge_after: bool = True
merge_strategy: str = 'concat'
source_pipelines: str | Dict[str | int, List[Any]]
to_dict() → Dict[str, Any]

Serialize configuration to dictionary.

Returns:

Dictionary representation for manifest storage.

class nirs4all.operators.data.merge.SourceIncompatibleStrategy(value)

Bases: Enum

How to handle incompatible source shapes during stacking.

When using the stack strategy with sources that have different feature dimensions or processing counts, this controls how the mismatch is resolved.

ERROR

Raise an error on incompatible shapes (default, strictest).

FLATTEN

Force 2D concatenation instead of stacking.

PAD

Pad shorter sources with zeros to match longest.

TRUNCATE

Truncate longer sources to match shortest.

ERROR = 'error'
FLATTEN = 'flatten'
PAD = 'pad'
TRUNCATE = 'truncate'
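For a single sample's per-source feature vectors, the four strategies can be sketched as below; `align_sources` is a hypothetical helper operating on plain lists, not the library's implementation.

```python
from typing import List, Union


def align_sources(sources: List[List[float]],
                  strategy: str = "error") -> Union[List[List[float]], List[float]]:
    """Align one sample's per-source feature vectors for stacking."""
    if strategy not in ("error", "flatten", "pad", "truncate"):
        raise ValueError(f"unknown strategy: {strategy!r}")
    widths = {len(s) for s in sources}
    if len(widths) == 1:
        return sources  # already compatible: stack as-is
    if strategy == "pad":
        target = max(widths)
        return [s + [0.0] * (target - len(s)) for s in sources]
    if strategy == "truncate":
        target = min(widths)
        return [s[:target] for s in sources]
    if strategy == "flatten":
        # Give up on stacking and concatenate into one flat 2D-style row.
        return [x for s in sources for x in s]
    raise ValueError(f"incompatible source widths: {sorted(widths)}")


sample = [[1.0, 2.0, 3.0], [4.0]]
print(align_sources(sample, "pad"))       # [[1.0, 2.0, 3.0], [4.0, 0.0, 0.0]]
print(align_sources(sample, "truncate"))  # [[1.0], [4.0]]
print(align_sources(sample, "flatten"))   # [1.0, 2.0, 3.0, 4.0]
```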
class nirs4all.operators.data.merge.SourceMergeConfig(strategy: str = 'concat', sources: str | List[int | str] = 'all', on_incompatible: str = 'error', output_name: str = 'merged', preserve_source_info: bool = True)

Bases: object

Configuration for merging multi-source dataset features.

This dataclass provides configuration for the merge_sources keyword, which combines features from multiple data sources (e.g., NIR, markers, Raman) into a unified feature space.

Unlike branch merging (merge), source merging operates on the data provenance dimension, combining features that originated from different sensors, instruments, or data modalities.

strategy

How to combine source features:

  • "concat" (default): Horizontal concatenation (2D result)

  • "stack": Stack along a new axis (3D result, requires uniform shapes)

  • "dict": Keep as a structured dictionary (for multi-input models)

Type:

str

sources

Which sources to include:

  • "all" (default): Include all available sources

  • List of source indices: [0, 1] for specific sources

  • List of source names: ["NIR", "markers"] for named sources

Type:

str | List[int | str]

on_incompatible

How to handle incompatible shapes (for the stack strategy):

  • "error" (default): Raise an error if shapes don't match

  • "flatten": Fall back to 2D concatenation

  • "pad": Pad shorter sources with zeros

  • "truncate": Truncate longer sources to match the shortest

Type:

str

output_name

Name for the merged output source (default: "merged").

Type:

str

preserve_source_info

Whether to store source metadata for debugging.

Type:

bool

Example

>>> # Simple concatenation (default)
>>> {"merge_sources": "concat"}
>>>
>>> # Stack for 3D models (requires same feature count per source)
>>> {"merge_sources": {"strategy": "stack"}}
>>>
>>> # Selective sources with fallback on shape mismatch
>>> {"merge_sources": {
...     "strategy": "stack",
...     "sources": ["NIR", "MIR"],
...     "on_incompatible": "flatten"
... }}
>>>
>>> # Dict output for multi-head models
>>> {"merge_sources": {"strategy": "dict"}}
__post_init__()

Validate configuration after initialization.

classmethod from_dict(data: Dict[str, Any]) → SourceMergeConfig

Create config from dictionary.

Parameters:

data – Dictionary representation.

Returns:

SourceMergeConfig instance.

get_incompatible_strategy() → SourceIncompatibleStrategy

Get the incompatible handling strategy as an enum.

Returns:

SourceIncompatibleStrategy enum value.

get_source_indices(available_sources: List[str]) → List[int]

Resolve source specification to indices.

Parameters:

available_sources – List of available source names.

Returns:

List of source indices to include.

Raises:

ValueError – If a specified source is not found.
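A sketch of resolving the sources specification ("all", indices, or names) to indices, raising ValueError for unknown entries as documented above. The `source_indices` helper and its index range check are assumptions, not the method's actual code.

```python
from typing import List, Union


def source_indices(spec: Union[str, List[Union[int, str]]],
                   available_sources: List[str]) -> List[int]:
    """Resolve "all", source indices, or source names to indices."""
    if spec == "all":
        return list(range(len(available_sources)))
    if isinstance(spec, str):
        raise ValueError(f"expected 'all' or a list, got {spec!r}")
    indices: List[int] = []
    for item in spec:
        if isinstance(item, int):
            if not 0 <= item < len(available_sources):
                raise ValueError(f"source index {item} out of range")
            indices.append(item)
        elif item in available_sources:
            indices.append(available_sources.index(item))
        else:
            raise ValueError(f"source {item!r} not found in {available_sources}")
    return indices


names = ["NIR", "markers", "Raman"]
print(source_indices("all", names))         # [0, 1, 2]
print(source_indices(["Raman", 0], names))  # [2, 0]
```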

get_strategy() → SourceMergeStrategy

Get the merge strategy as an enum.

Returns:

SourceMergeStrategy enum value.

on_incompatible: str = 'error'
output_name: str = 'merged'
preserve_source_info: bool = True
sources: str | List[int | str] = 'all'
strategy: str = 'concat'
to_dict() → Dict[str, Any]

Serialize configuration to dictionary.

Returns:

Dictionary representation for manifest storage.

class nirs4all.operators.data.merge.SourceMergeStrategy(value)

Bases: Enum

How to combine features from multiple data sources.

Used by the merge_sources keyword to control how multi-source datasets are unified into a single feature space.

CONCAT

Horizontal concatenation of all source features (default). Results in a 2D array: (samples, sum_of_all_source_features). Different feature dimensions per source are expected.

STACK

Stack sources along a new axis to create 3D tensor. Results in 3D array: (samples, n_sources, n_features). Requires all sources to have the same feature dimension.

DICT

Keep sources as a structured dictionary. Results in Dict[str, ndarray] for multi-input models. Each source is accessible by name.

CONCAT = 'concat'
DICT = 'dict'
STACK = 'stack'
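The output shapes described for CONCAT, STACK, and DICT can be checked with a small shape-only sketch. `merged_shape` is a hypothetical helper that works on (samples, features) tuples rather than arrays.

```python
from typing import Dict, Tuple, Union


def merged_shape(source_shapes: Dict[str, Tuple[int, int]],
                 strategy: str = "concat") -> Union[Tuple[int, ...], Dict[str, Tuple[int, int]]]:
    """Output shape of merging 2D (samples, features) sources."""
    n_samples = {shape[0] for shape in source_shapes.values()}
    if len(n_samples) != 1:
        raise ValueError("all sources must have the same number of samples")
    n = n_samples.pop()
    if strategy == "concat":
        # (samples, sum of per-source feature counts)
        return (n, sum(shape[1] for shape in source_shapes.values()))
    if strategy == "stack":
        widths = {shape[1] for shape in source_shapes.values()}
        if len(widths) != 1:
            raise ValueError("stack requires the same feature dimension for every source")
        # (samples, n_sources, n_features)
        return (n, len(source_shapes), widths.pop())
    if strategy == "dict":
        # One array per source, keyed by source name.
        return dict(source_shapes)
    raise ValueError(f"unknown strategy: {strategy!r}")


shapes = {"NIR": (100, 700), "markers": (100, 12)}
print(merged_shape(shapes, "concat"))  # (100, 712)
print(merged_shape(shapes, "dict"))    # {'NIR': (100, 700), 'markers': (100, 12)}
```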