nirs4all.operators.data package

Module contents

Data operators for merge and source handling.

This module provides operators for branch merging, source handling, and related data manipulation operations.

class nirs4all.operators.data.AggregationStrategy(value)

Bases: Enum

How to aggregate predictions from selected models within a branch.

After model selection, this controls how the selected predictions are combined into features for the merged output.

SEPARATE

Keep each model’s predictions as separate features (default). Results in N features (one per selected model).

MEAN

Simple average of all selected model predictions. Results in 1 feature.

WEIGHTED_MEAN

Weighted average by validation score. Results in 1 feature.

PROBA_MEAN

Average class probabilities (classification only). Results in K features (one per class).

MEAN = 'mean'
PROBA_MEAN = 'proba_mean'
SEPARATE = 'separate'
WEIGHTED_MEAN = 'weighted_mean'
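The NumPy sketch below shows how each strategy changes the number of output features. It assumes predictions arrive as an (n_samples, n_models) array and probabilities as (n_models, n_samples, n_classes). It also assumes that WEIGHTED_MEAN normalizes higher-is-better validation scores into weights that sum to 1; the library's actual weighting may differ. `aggregate_predictions` is a hypothetical helper and is not part of nirs4all.

```python
import numpy as np

def aggregate_predictions(preds, strategy, scores=None, probas=None):
    """Hypothetical helper illustrating AggregationStrategy semantics.

    preds:  (n_samples, n_models) predictions from the selected models
    scores: (n_models,) validation scores, higher is better (assumption)
    probas: (n_models, n_samples, n_classes) class probabilities
    """
    if strategy == "separate":
        # One feature column per selected model.
        return preds
    if strategy == "mean":
        # Simple average across models -> a single feature column.
        return preds.mean(axis=1, keepdims=True)
    if strategy == "weighted_mean":
        # Assumed scheme: normalize scores into weights summing to 1.
        w = np.asarray(scores, dtype=float)
        w = w / w.sum()
        return (preds @ w)[:, None]
    if strategy == "proba_mean":
        # Average probabilities across models -> one feature per class.
        return np.asarray(probas).mean(axis=0)
    raise ValueError(f"Unknown strategy: {strategy}")

preds = np.array([[1.0, 3.0], [2.0, 4.0], [0.0, 2.0]])
print(aggregate_predictions(preds, "separate").shape)  # (3, 2)
print(aggregate_predictions(preds, "mean").ravel())    # [2. 3. 1.]
print(aggregate_predictions(preds, "weighted_mean", scores=[3.0, 1.0]).ravel())  # [1.5 2.5 0.5]
```

SEPARATE keeps N columns, MEAN and WEIGHTED_MEAN collapse them to 1, and PROBA_MEAN yields K columns (one per class).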
class nirs4all.operators.data.BranchPredictionConfig(branch: int | str, select: str | Dict[str, Any] | List[str] = 'all', metric: str | None = None, aggregate: str = 'separate', weight_metric: str | None = None, proba: bool = False, sources: str | List[int | str] = 'all')

Bases: object

Configuration for prediction collection from a single branch.

This dataclass specifies how to collect and process predictions from a specific branch during merge operations.

branch

Branch index or name to collect from.

Type:

int | str

select

Model selection strategy:
  • “all” (default): all models in the branch
  • “best”: the single best model by metric
  • {“top_k”: N}: the top N models by metric
  • [“ModelA”, “ModelB”]: an explicit list of model names

Type:

str | Dict[str, Any] | List[str]

metric

Metric for selection (rmse, mae, r2, accuracy, f1). Default is task-appropriate (rmse for regression, accuracy for classification).

Type:

str | None

aggregate

How to combine predictions from the selected models:
  • “separate” (default): each model is a separate feature
  • “mean”: simple average of predictions
  • “weighted_mean”: average weighted by validation score
  • “proba_mean”: average of class probabilities (classification only)

Type:

str

weight_metric

Metric for weighted aggregation (default: same as metric).

Type:

str | None

proba

Use class probabilities instead of predictions (classification only).

Type:

bool

sources

Source filter for multi-source datasets:
  • “all” (default): include all sources
  • a list of source indices or names

Type:

str | List[int | str]

Example

>>> # Best model from branch 0 by RMSE
>>> BranchPredictionConfig(branch=0, select="best", metric="rmse")
>>>
>>> # Top 3 models from branch 1, averaged
>>> BranchPredictionConfig(
...     branch=1,
...     select={"top_k": 3},
...     metric="r2",
...     aggregate="mean"
... )
>>>
>>> # Explicit models with weighted average
>>> BranchPredictionConfig(
...     branch="spectral_path",
...     select=["PLS", "RF"],
...     aggregate="weighted_mean",
...     weight_metric="r2"
... )
__post_init__()

Validate configuration after initialization.

aggregate: str = 'separate'
branch: int | str
get_aggregation_strategy() → AggregationStrategy

Get the aggregation strategy enum for this configuration.

Returns:

AggregationStrategy enum value based on aggregate field.

get_selection_strategy() → SelectionStrategy

Get the selection strategy enum for this configuration.

Returns:

SelectionStrategy enum value based on select field.
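The standalone sketch below shows how a `select` value could map to a SelectionStrategy and then to a list of model names. The enum is redefined locally so the block runs without nirs4all. `classify_select` and `select_models` are hypothetical helpers. The ranking direction is an assumption: lower is better for rmse, mae, and mse, and higher is better for other metrics.

```python
from enum import Enum

class SelectionStrategy(Enum):
    # Local copy of the documented enum values.
    ALL = "all"
    BEST = "best"
    TOP_K = "top_k"
    EXPLICIT = "explicit"

# Assumption: error metrics are minimized, score metrics are maximized.
LOWER_IS_BETTER = {"rmse", "mae", "mse"}

def classify_select(select):
    """Map a `select` value to a SelectionStrategy (hypothetical sketch)."""
    if isinstance(select, list):
        return SelectionStrategy.EXPLICIT
    if isinstance(select, dict) and "top_k" in select:
        return SelectionStrategy.TOP_K
    if select in ("all", "best"):
        return SelectionStrategy(select)
    raise ValueError(f"Unsupported select value: {select!r}")

def select_models(scores, select, metric):
    """Return the model names chosen from a {name: score} mapping."""
    strategy = classify_select(select)
    if strategy is SelectionStrategy.ALL:
        return list(scores)
    if strategy is SelectionStrategy.EXPLICIT:
        # Keep the user's order, ignoring names absent from the branch.
        return [name for name in select if name in scores]
    ranked = sorted(scores, key=scores.get, reverse=metric not in LOWER_IS_BETTER)
    k = 1 if strategy is SelectionStrategy.BEST else select["top_k"]
    return ranked[:k]

rmse = {"PLS": 0.42, "RF": 0.38, "SVR": 0.51}
print(select_models(rmse, "best", "rmse"))        # ['RF']
print(select_models(rmse, {"top_k": 2}, "rmse"))  # ['RF', 'PLS']
```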

metric: str | None = None
proba: bool = False
select: str | Dict[str, Any] | List[str] = 'all'
sources: str | List[int | str] = 'all'
weight_metric: str | None = None
class nirs4all.operators.data.BranchType(value)

Bases: Enum

Type of branch based on sample handling.

COPY

All branches see all samples (default branching behavior).

METADATA_PARTITIONER

Branches partition samples by metadata column.

SAMPLE_PARTITIONER

Branches partition samples by a sample filter (e.g., an outlier filter).

COPY = 'copy'
METADATA_PARTITIONER = 'metadata_partitioner'
SAMPLE_PARTITIONER = 'sample_partitioner'
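The sketch below compares COPY with metadata-based partitioning in terms of the sample indices each branch sees. The helpers and the `site` column are hypothetical; the actual branching controllers in nirs4all are not shown.

```python
from collections import defaultdict

def partition_by_metadata(values):
    """METADATA_PARTITIONER-style grouping of sample indices (hypothetical).

    Each distinct metadata value becomes one branch that sees only its own
    samples, so the branches are disjoint.
    """
    branches = defaultdict(list)
    for idx, value in enumerate(values):
        branches[value].append(idx)
    return dict(branches)

def copy_branches(n_samples, n_branches):
    """COPY-style branching: every branch sees every sample."""
    return {b: list(range(n_samples)) for b in range(n_branches)}

site = ["A", "B", "A", "C", "B"]
print(partition_by_metadata(site))  # {'A': [0, 2], 'B': [1, 4], 'C': [3]}
print(copy_branches(3, 2))          # {0: [0, 1, 2], 1: [0, 1, 2]}
```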
class nirs4all.operators.data.DisjointSelectionCriterion(value)

Bases: Enum

Criterion for selecting top-N models in disjoint branch merge.

When branches have different model counts, the top-N models from each branch are selected according to this criterion.

MSE

Select by lowest Mean Squared Error (default for regression).

RMSE

Select by lowest Root Mean Squared Error.

MAE

Select by lowest Mean Absolute Error.

R2

Select by highest R² score.

ORDER

Select first N in definition order (no ranking).

MAE = 'mae'
MSE = 'mse'
ORDER = 'order'
R2 = 'r2'
RMSE = 'rmse'
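The sketch below applies the documented ranking rules: error metrics are minimized, R² is maximized, and ORDER keeps definition order. `top_n_models` and its (name, metrics) input layout are hypothetical; this is not the MergeController's actual code.

```python
def top_n_models(results, n, criterion="mse"):
    """Pick n models from one branch (hypothetical sketch).

    results: list of (model_name, metrics_dict) in definition order.
    """
    if criterion == "order":
        # No ranking: keep the first n models as defined.
        return [name for name, _ in results[:n]]
    higher_is_better = criterion == "r2"
    ranked = sorted(results, key=lambda item: item[1][criterion],
                    reverse=higher_is_better)
    return [name for name, _ in ranked[:n]]

branch = [
    ("PLS", {"mse": 0.20, "r2": 0.81}),
    ("RF", {"mse": 0.15, "r2": 0.86}),
    ("SVR", {"mse": 0.30, "r2": 0.72}),
]
print(top_n_models(branch, 2))                  # ['RF', 'PLS']
print(top_n_models(branch, 2, criterion="r2"))  # ['RF', 'PLS']
print(top_n_models(branch, 2, "order"))         # ['PLS', 'RF']
```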
class nirs4all.operators.data.MergeConfig(collect_features: bool = False, feature_branches: str | List[int] = 'all', collect_predictions: bool = False, prediction_branches: str | List[int] = 'all', prediction_configs: List[BranchPredictionConfig] | None = None, model_filter: List[str] | None = None, use_proba: bool = False, include_original: bool = False, on_missing: str = 'error', on_shape_mismatch: str = 'error', unsafe: bool = False, output_as: str = 'features', source_names: List[str] | None = None, n_columns: int | None = None, select_by: str = 'mse')

Bases: object

Configuration for branch merging operations.

This dataclass provides complete configuration for the MergeController, controlling what data is collected from branches and how it is combined.

collect_features

Whether to collect features from branches.

Type:

bool

feature_branches

Which branches to collect features from:
  • “all” (default): all branches
  • a list of branch indices, e.g. [0, 2] for specific branches

Type:

str | List[int]

collect_predictions

Whether to collect predictions from branches.

Type:

bool

prediction_branches

Legacy simple mode: which branches to collect predictions from. Use prediction_configs for advanced per-branch control.

Type:

str | List[int]

prediction_configs

Advanced per-branch prediction configuration. Takes precedence over prediction_branches when set.

Type:

List[nirs4all.operators.data.merge.BranchPredictionConfig] | None

model_filter

Legacy: global model filter (simple mode). List of model names to include.

Type:

List[str] | None

use_proba

Legacy: global proba setting for classification.

Type:

bool

include_original

Include pre-branch features in merged output. When True, original features are prepended to merged features.

Type:

bool

on_missing

How to handle missing branches or predictions:
  • “error” (default): raise an error
  • “warn”: log a warning and skip
  • “skip”: skip silently

Type:

str

on_shape_mismatch

Reserved for 3D-layout feature merging. In the 2D layout (default), features are flattened and concatenated horizontally, so differing feature dimensions are normal and this parameter has no effect. Values for future 3D layout support:
  • “error”: raise an error if processings differ
  • “allow”: flatten to 2D and concatenate
  • “pad”: pad shorter processings with zeros
  • “truncate”: truncate longer processings to match the shortest

Type:

str

unsafe

If True, disables out-of-fold (OOF) reconstruction for predictions. ⚠️ This CAUSES DATA LEAKAGE; use it only for rapid prototyping.

Type:

bool

output_as

Where to put the merged output:
  • “features” (default): a single concatenated feature matrix
  • “sources”: each branch becomes a separate source
  • “dict”: keep as a structured dict for multi-head models

Type:

str

source_names

Custom names for output sources (when output_as="sources"). If not provided, “branch_0”, “branch_1”, etc. are used.

Type:

List[str] | None

Example

>>> # Simple feature merge
>>> MergeConfig(collect_features=True)
>>>
>>> # Prediction merge with OOF
>>> MergeConfig(collect_predictions=True)
>>>
>>> # Mixed merge with per-branch control
>>> MergeConfig(
...     collect_predictions=True,
...     prediction_configs=[
...         BranchPredictionConfig(branch=0, select="best"),
...         BranchPredictionConfig(branch=1, aggregate="mean")
...     ],
...     collect_features=True,
...     feature_branches=[2]
... )
>>>
>>> # Unsafe mode (with warning)
>>> MergeConfig(collect_predictions=True, unsafe=True)
>>>
>>> # Disjoint branch merge with n_columns override
>>> MergeConfig(
...     collect_predictions=True,
...     n_columns=2,
...     select_by="mse"
... )
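The NumPy sketch below shows why `unsafe=True` leaks information by illustrating out-of-fold (OOF) reconstruction. Least squares and contiguous, unshuffled folds stand in for whatever models and splitter the pipeline actually uses, and `oof_predictions` is hypothetical. Each sample's prediction comes from a model fitted without that sample. A downstream model that stacks these predictions therefore never sees features derived from its own target.

```python
import numpy as np

def oof_predictions(X, y, n_folds=5):
    """Out-of-fold predictions with ordinary least squares (sketch)."""
    n = len(y)
    oof = np.empty(n)
    folds = np.array_split(np.arange(n), n_folds)
    for held_out in folds:
        train = np.setdiff1d(np.arange(n), held_out)
        # Fit an intercept + linear model on the other folds only.
        A = np.c_[np.ones(len(train)), X[train]]
        coef, *_ = np.linalg.lstsq(A, y[train], rcond=None)
        # Predict the held-out fold with a model that never saw it.
        oof[held_out] = np.c_[np.ones(len(held_out)), X[held_out]] @ coef
    return oof

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=50)
print(oof_predictions(X, y).shape)  # (50,)
```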
__post_init__()

Validate configuration after initialization.

collect_features: bool = False
collect_predictions: bool = False
feature_branches: str | List[int] = 'all'
classmethod from_dict(data: Dict[str, Any]) → MergeConfig

Create MergeConfig from a dictionary.

Used to load the merge configuration from the manifest in prediction mode.

Parameters:

data – Dictionary representation of merge configuration.

Returns:

MergeConfig instance.

get_feature_branches(n_branches: int) → List[int]

Get list of branch indices to collect features from.

Parameters:

n_branches – Total number of branches available.

Returns:

List of branch indices.

get_merge_mode() → MergeMode

Determine the merge mode based on configuration.

Returns:

MergeMode enum value.
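The sketch below shows one plausible mapping from the two collect_* flags to a MergeMode. The enum is copied locally, and `infer_merge_mode` is hypothetical. Rejecting a merge that collects nothing is an assumption; the documentation above does not say how that case is handled.

```python
from enum import Enum

class MergeMode(Enum):
    # Local copy of the documented enum values.
    FEATURES = "features"
    PREDICTIONS = "predictions"
    ALL = "all"

def infer_merge_mode(collect_features, collect_predictions):
    """Hypothetical mapping from collect_* flags to a MergeMode."""
    if collect_features and collect_predictions:
        return MergeMode.ALL
    if collect_features:
        return MergeMode.FEATURES
    if collect_predictions:
        return MergeMode.PREDICTIONS
    # Assumption: a merge that collects nothing is rejected.
    raise ValueError("Enable collect_features and/or collect_predictions")

print(infer_merge_mode(True, True))  # MergeMode.ALL
```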

get_prediction_configs(n_branches: int) → List[BranchPredictionConfig]

Get prediction configurations, normalizing legacy format if needed.

Converts legacy simple mode (prediction_branches + model_filter + use_proba) to per-branch configurations for uniform processing.

Parameters:

n_branches – Total number of branches available.

Returns:

List of BranchPredictionConfig for each branch to collect from.
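The sketch below shows how legacy simple-mode settings could be expanded into one per-branch config. It assumes that "all" expands to range(n_branches), the same rule get_feature_branches describes, and that a global model_filter becomes an explicit `select` list. `BranchPredictionSketch` is a trimmed, hypothetical stand-in for BranchPredictionConfig, and `normalize_legacy` is hypothetical.

```python
from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class BranchPredictionSketch:
    # Trimmed stand-in for BranchPredictionConfig (only the fields used here).
    branch: Union[int, str]
    select: Union[str, List[str]] = "all"
    proba: bool = False

def normalize_legacy(prediction_branches, model_filter: Optional[List[str]],
                     use_proba: bool, n_branches: int):
    """Expand legacy simple-mode settings into one config per branch."""
    if prediction_branches == "all":
        branches = list(range(n_branches))
    else:
        branches = list(prediction_branches)
    return [
        BranchPredictionSketch(
            branch=b,
            # Assumption: model_filter becomes an explicit model list
            # (copied per branch so configs do not share one list).
            select=list(model_filter) if model_filter else "all",
            proba=use_proba,
        )
        for b in branches
    ]

configs = normalize_legacy("all", ["PLS"], False, n_branches=3)
print([c.branch for c in configs])  # [0, 1, 2]
```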

get_selection_criterion() → DisjointSelectionCriterion

Get the selection criterion enum for disjoint branch merging.

Returns:

DisjointSelectionCriterion enum value.

get_shape_mismatch_strategy() → ShapeMismatchStrategy

Get the shape mismatch strategy enum.

Returns:

ShapeMismatchStrategy enum value.

has_per_branch_config() → bool

Check if using advanced per-branch prediction configuration.

Returns:

True if prediction_configs is set and non-empty.

include_original: bool = False
model_filter: List[str] | None = None
n_columns: int | None = None
on_missing: str = 'error'
on_shape_mismatch: str = 'error'
output_as: str = 'features'
prediction_branches: str | List[int] = 'all'
prediction_configs: List[BranchPredictionConfig] | None = None
select_by: str = 'mse'
source_names: List[str] | None = None
to_dict() → Dict[str, Any]

Serialize merge configuration to a dictionary.

Used to save the merge configuration to the manifest, for reproducibility in prediction mode and bundle export.

Returns:

Dictionary representation suitable for YAML/JSON serialization.
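The manifest round trip can be pictured with a trimmed stand-in dataclass and JSON. `MergeConfigSketch` is hypothetical and keeps only four fields with their documented defaults. Ignoring unknown keys in from_dict is an assumption. The real to_dict/from_dict may serialize nested BranchPredictionConfig objects differently.

```python
import json
from dataclasses import asdict, dataclass, fields

@dataclass
class MergeConfigSketch:
    # A few MergeConfig fields, with the documented defaults.
    collect_features: bool = False
    collect_predictions: bool = False
    on_missing: str = "error"
    output_as: str = "features"

    def to_dict(self):
        return asdict(self)

    @classmethod
    def from_dict(cls, data):
        # Assumption: unknown keys are ignored so other manifest versions load.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})

cfg = MergeConfigSketch(collect_predictions=True, on_missing="warn")
payload = json.dumps(cfg.to_dict())  # what a manifest could store
restored = MergeConfigSketch.from_dict(json.loads(payload))
print(restored == cfg)  # True
```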

unsafe: bool = False
use_proba: bool = False
class nirs4all.operators.data.MergeMode(value)

Bases: Enum

What to merge from branches.

FEATURES

Merge feature matrices from branches.

PREDICTIONS

Merge model predictions from branches (with OOF reconstruction).

ALL

Merge both features and predictions from all branches.

ALL = 'all'
FEATURES = 'features'
PREDICTIONS = 'predictions'
class nirs4all.operators.data.RepetitionConfig(column: str | None = None, on_unequal: str = 'error', expected_reps: int | None = None, source_names: str | List[str] | None = None, pp_names: str | List[str] | None = None, preserve_order: bool = True, aggregate_metadata: str = 'first')

Bases: object

Configuration for repetition transformation operations.

This dataclass provides configuration for the rep_to_sources and rep_to_pp keywords, which reshape datasets based on sample repetitions.

Repetitions are identified by a metadata column (e.g., “Sample_ID”) that groups multiple spectra belonging to the same physical sample.
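The NumPy sketch below shows what rep_to_sources does conceptually. Rows grouped by a sample ID become one source per repetition, with repetitions kept in row order, as with preserve_order=True. `reps_to_sources` is hypothetical and implements only the “error” behaviour for unequal counts.

```python
import numpy as np

def reps_to_sources(X, groups):
    """Reshape repeated spectra into one source per repetition (sketch).

    X:      (n_spectra, n_features) array, one row per measured spectrum
    groups: sample ID for each row
    Returns (sample_ids, sources) where sources[r] is (n_samples, n_features).
    """
    sample_ids = list(dict.fromkeys(groups))  # first-appearance order
    rows = {sid: [i for i, g in enumerate(groups) if g == sid]
            for sid in sample_ids}
    counts = {len(r) for r in rows.values()}
    if len(counts) != 1:
        # Mirrors on_unequal="error"; other strategies are not sketched here.
        raise ValueError(f"Unequal repetition counts: {sorted(counts)}")
    n_reps = counts.pop()
    # Source r holds the r-th repetition of every sample, in row order.
    sources = [X[[rows[sid][r] for sid in sample_ids]] for r in range(n_reps)]
    return sample_ids, sources

X = np.arange(12, dtype=float).reshape(6, 2)
ids, sources = reps_to_sources(X, ["s1", "s2", "s1", "s2", "s1", "s2"])
print(ids, len(sources), sources[0].shape)  # ['s1', 's2'] 3 (2, 2)
```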

column

Metadata column identifying sample groups, or a special value:
  • None (default): use the dataset’s aggregate column from DatasetConfigs
  • “y”: group by target values
  • str: an explicit metadata column name

Type:

str | None

on_unequal

Strategy when samples have different repetition counts:
  • “error” (default): raise an error if counts differ
  • “pad”: pad shorter groups with NaN to match the longest
  • “drop”: drop samples without the expected repetition count
  • “truncate”: use the minimum count across all samples

Type:

str

expected_reps

Expected number of repetitions per sample. If None (default), inferred from data (mode of group sizes). If specified, validates all groups match this count.

Type:

int | None

source_names

Naming template for new sources (rep_to_sources only):
  • None (default): uses “rep_0”, “rep_1”, etc.
  • str with {i}: a template such as “rep_{i}” or “spectrum_{i}”
  • List[str]: explicit names for each repetition

Type:

str | List[str] | None

pp_names

Naming template for new preprocessings (rep_to_pp only):
  • None (default): uses the “{original}_rep{i}” format
  • str with {i} and {pp}: a template such as “{pp}_r{i}”
  • List[str]: explicit names (length = n_reps * n_existing_pp)

Type:

str | List[str] | None

preserve_order

Whether to preserve sample order within groups. If True (default), repetitions are ordered by their row position. If False, order within groups is undefined.

Type:

bool

aggregate_metadata

How to handle metadata after grouping:
  • “first” (default): keep metadata from the first repetition
  • “validate”: ensure all repetitions have identical metadata; raise an error if not
  • “drop”: remove metadata columns that differ across repetitions

Type:

str

Example

>>> # Use dataset's aggregate column (simplest)
>>> RepetitionConfig()
>>>
>>> # Simple column-based grouping
>>> RepetitionConfig(column="Sample_ID")
>>>
>>> # Group by target value with padding
>>> RepetitionConfig(column="y", on_unequal="pad")
>>>
>>> # Explicit repetition count validation
>>> RepetitionConfig(
...     column="Leaf_ID",
...     expected_reps=4,
...     on_unequal="error"
... )
>>>
>>> # Custom source naming
>>> RepetitionConfig(
...     column="Sample_ID",
...     source_names="measurement_{i}"
... )
__post_init__()

Validate configuration after initialization.

aggregate_metadata: str = 'first'
column: str | None = None
expected_reps: int | None = None
classmethod from_dict(data: Dict[str, Any]) → RepetitionConfig

Create config from dictionary.

Parameters:

data – Dictionary representation. If ‘column’ is missing, it defaults to None (use the dataset’s aggregate column).

Returns:

RepetitionConfig instance.

classmethod from_step_value(value: str | bool | Dict[str, Any] | None) → RepetitionConfig

Create config from step value (string, bool, or dict).

Handles multiple syntax styles:
  • None or True: use the dataset’s aggregate column
  • str: an explicit column name (or “y” for target grouping)
  • dict: full configuration with options

Parameters:

value – Step value - column name, True/None for aggregate, or config dict.

Returns:

RepetitionConfig instance.

Example

>>> # Use dataset aggregate (simplest)
>>> RepetitionConfig.from_step_value(True)
>>> RepetitionConfig.from_step_value(None)
>>>
>>> # Explicit column
>>> RepetitionConfig.from_step_value("Sample_ID")
>>>
>>> # Advanced syntax
>>> RepetitionConfig.from_step_value({
...     "column": "Sample_ID",
...     "on_unequal": "drop"
... })
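The sketch below normalizes the three documented step-value styles into a plain dict. `normalize_step_value` is hypothetical. Rejecting other values such as False or numbers is an assumption, since the behaviour for those inputs is not documented above.

```python
def normalize_step_value(value):
    """Turn a rep_to_sources / rep_to_pp step value into a config dict."""
    if value is None or value is True:
        # Defer to the dataset's aggregate column at runtime.
        return {"column": None}
    if isinstance(value, str):
        # Explicit column name, or "y" for target grouping.
        return {"column": value}
    if isinstance(value, dict):
        # A missing 'column' falls back to the dataset aggregate.
        return {"column": None, **value}
    # Assumption: other values (e.g. False, numbers) are rejected.
    raise TypeError(f"Unsupported step value: {value!r}")

print(normalize_step_value("Sample_ID"))  # {'column': 'Sample_ID'}
```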
get_pp_name(rep_index: int, original_pp: str) → str

Generate preprocessing name for a given repetition and original processing.

Parameters:
  • rep_index – Zero-based repetition index.

  • original_pp – Original preprocessing name (e.g., “raw”, “snv”).

Returns:

New preprocessing name string.

get_source_name(rep_index: int) → str

Generate source name for a given repetition index.

Parameters:

rep_index – Zero-based repetition index.

Returns:

Source name string.
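The sketch below re-implements the documented naming rules for get_source_name and get_pp_name. Both helpers are hypothetical. The list form of pp_names is omitted because the ordering of its n_reps * n_existing_pp entries is not specified above.

```python
def source_name(rep_index, template=None):
    """Documented rules: default 'rep_{i}', a '{i}' template, or a list."""
    if template is None:
        return f"rep_{rep_index}"
    if isinstance(template, list):
        return template[rep_index]
    return template.format(i=rep_index)

def pp_name(rep_index, original_pp, template=None):
    """Documented default '{original}_rep{i}'; templates use {pp} and {i}."""
    if template is None:
        return f"{original_pp}_rep{rep_index}"
    return template.format(pp=original_pp, i=rep_index)

print(source_name(0))                     # rep_0
print(source_name(2, "measurement_{i}"))  # measurement_2
print(pp_name(1, "snv"))                  # snv_rep1
print(pp_name(1, "snv", "{pp}_r{i}"))     # snv_r1
```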

get_unequal_strategy() → UnequelRepsStrategy

Get the unequal handling strategy as an enum.

Returns:

UnequelRepsStrategy enum value.

property is_y_grouping: bool

Check if grouping by target values.

Returns:

True if column is “y” (case-insensitive).

on_unequal: str = 'error'
pp_names: str | List[str] | None = None
preserve_order: bool = True
resolve_column(dataset_aggregate: str | None) → str

Resolve the actual column to use for grouping.

Parameters:

dataset_aggregate – The aggregate value from dataset (column name, “y”, or None).

Returns:

The resolved column name to use.

Raises:

ValueError – If no column specified and dataset has no aggregate setting.
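The sketch below follows the resolution order documented for resolve_column and is_y_grouping. An explicit column wins, otherwise the dataset's aggregate setting is used, and a ValueError is raised when neither exists. Both functions are hypothetical stand-ins for the methods.

```python
def resolve_column(column, dataset_aggregate):
    """Sketch of the documented resolution order."""
    if column is not None:
        return column  # explicit column (or "y") wins
    if dataset_aggregate is None:
        raise ValueError(
            "No repetition column given and the dataset has no aggregate setting"
        )
    return dataset_aggregate  # fall back to the dataset's aggregate

def is_y_grouping(column):
    # Case-insensitive check for target-value grouping, as documented.
    return isinstance(column, str) and column.lower() == "y"

print(resolve_column(None, "Sample_ID"))  # Sample_ID
```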

source_names: str | List[str] | None = None
to_dict() → Dict[str, Any]

Serialize configuration to dictionary.

Returns:

Dictionary representation for manifest storage.

property uses_dataset_aggregate: bool

Check if using dataset’s aggregate column.

Returns:

True if column is None (will use dataset.aggregate at runtime).

class nirs4all.operators.data.SelectionStrategy(value)

Bases: Enum

How to select models within a branch for prediction merging.

When a branch contains multiple models, this controls which models’ predictions are included in the merge.

ALL

Include all models in the branch (default).

BEST

Single best model by specified metric.

TOP_K

Top K models by specified metric.

EXPLICIT

Explicit list of model names.

ALL = 'all'
BEST = 'best'
EXPLICIT = 'explicit'
TOP_K = 'top_k'
class nirs4all.operators.data.ShapeMismatchStrategy(value)

Bases: Enum

How to handle shape mismatches during 3D feature merging.

This strategy only applies when using the 3D layout for features, where the number of processings must be aligned across branches. In the 2D layout (the default), features are simply flattened and concatenated horizontally, so differing feature dimensions across branches are expected and normal.

Example

  • Branch 0: (200 samples, 500 features) from MinMaxScaler

  • Branch 1: (200 samples, 4 processings, 20 features) from multi-processing

In 2D layout: concatenates to (200, 500 + 4*20 = 580), with no error.
In 3D layout: an alignment strategy is needed because the processing counts differ.
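The 2D-layout arithmetic above can be checked directly with NumPy. The zero arrays stand in for real branch outputs.

```python
import numpy as np

branch0 = np.zeros((200, 500))    # e.g. MinMaxScaler output
branch1 = np.zeros((200, 4, 20))  # 4 processings x 20 features

# 2D layout: flatten each branch to (samples, -1), then concatenate horizontally.
flat = [b.reshape(b.shape[0], -1) for b in (branch0, branch1)]
merged = np.hstack(flat)
print(merged.shape)  # (200, 580)
```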

ERROR

Raise an error on shape mismatch (default, strictest).

ALLOW

Flatten to 2D and concatenate regardless of differences.

PAD

Pad shorter branches with zeros to match the longest processing count.

TRUNCATE

Truncate longer branches to match the shortest processing count.

ALLOW = 'allow'
ERROR = 'error'
PAD = 'pad'
TRUNCATE = 'truncate'
class nirs4all.operators.data.SourceIncompatibleStrategy(value)

Bases: Enum

How to handle incompatible source shapes during stacking.

When using the stack strategy with sources that have different feature dimensions or processing counts, this controls the resolution.

ERROR

Raise an error on incompatible shapes (default, strictest).

FLATTEN

Force 2D concatenation instead of stacking.

PAD

Pad shorter sources with zeros to match longest.

TRUNCATE

Truncate longer sources to match shortest.

ERROR = 'error'
FLATTEN = 'flatten'
PAD = 'pad'
TRUNCATE = 'truncate'
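The NumPy sketch below applies the four strategies to 2D sources whose feature counts differ before stacking. `align_sources` is hypothetical. It handles only the feature axis and assumes that padding is appended at the end of that axis; processing-count mismatches are not covered.

```python
import numpy as np

def align_sources(sources, strategy):
    """Make 2D sources stackable along a new axis (hypothetical sketch).

    sources: list of (n_samples, n_features_i) arrays.
    """
    widths = [s.shape[1] for s in sources]
    if len(set(widths)) == 1:
        return np.stack(sources, axis=1)  # (samples, n_sources, features)
    if strategy == "error":
        raise ValueError(f"Incompatible feature counts: {widths}")
    if strategy == "flatten":
        return np.hstack(sources)  # fall back to 2D concatenation
    if strategy == "pad":
        target = max(widths)
        # Zero-pad the end of the feature axis (assumed placement).
        padded = [np.pad(s, ((0, 0), (0, target - s.shape[1]))) for s in sources]
        return np.stack(padded, axis=1)
    if strategy == "truncate":
        target = min(widths)
        return np.stack([s[:, :target] for s in sources], axis=1)
    raise ValueError(f"Unknown strategy: {strategy}")

nir, markers = np.ones((10, 100)), np.ones((10, 40))
print(align_sources([nir, markers], "pad").shape)       # (10, 2, 100)
print(align_sources([nir, markers], "truncate").shape)  # (10, 2, 40)
print(align_sources([nir, markers], "flatten").shape)   # (10, 140)
```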
class nirs4all.operators.data.SourceMergeConfig(strategy: str = 'concat', sources: str | List[int | str] = 'all', on_incompatible: str = 'error', output_name: str = 'merged', preserve_source_info: bool = True)

Bases: object

Configuration for merging multi-source dataset features.

This dataclass provides configuration for the merge_sources keyword, which combines features from multiple data sources (e.g., NIR, markers, Raman) into a unified feature space.

Unlike branch merging (merge), source merging operates along the data-provenance dimension: it combines features that originated from different sensors, instruments, or data modalities.

strategy

How to combine source features:
  • “concat” (default): horizontal concatenation (2D result)
  • “stack”: stack along a new axis (3D result; requires uniform shapes)
  • “dict”: keep as a structured dictionary (for multi-input models)

Type:

str

sources

Which sources to include:
  • “all” (default): include all available sources
  • a list of source indices, e.g. [0, 1] for specific sources
  • a list of source names, e.g. [“NIR”, “markers”] for named sources

Type:

str | List[int | str]

on_incompatible

How to handle incompatible shapes (stack strategy only):
  • “error” (default): raise an error if shapes don’t match
  • “flatten”: fall back to 2D concatenation
  • “pad”: pad shorter sources with zeros
  • “truncate”: truncate longer sources to match the shortest

Type:

str

output_name

Name for the merged output source (default: “merged”).

Type:

str

preserve_source_info

Whether to store source metadata for debugging.

Type:

bool

Example

>>> # Simple concatenation (default)
>>> {"merge_sources": "concat"}
>>>
>>> # Stack for 3D models (requires same feature count per source)
>>> {"merge_sources": {"strategy": "stack"}}
>>>
>>> # Selective sources with fallback on shape mismatch
>>> {"merge_sources": {
...     "strategy": "stack",
...     "sources": ["NIR", "MIR"],
...     "on_incompatible": "flatten"
... }}
>>>
>>> # Dict output for multi-head models
>>> {"merge_sources": {"strategy": "dict"}}
__post_init__()

Validate configuration after initialization.

classmethod from_dict(data: Dict[str, Any]) → SourceMergeConfig

Create config from dictionary.

Parameters:

data – Dictionary representation.

Returns:

SourceMergeConfig instance.

get_incompatible_strategy() → SourceIncompatibleStrategy

Get the incompatible handling strategy as an enum.

Returns:

SourceIncompatibleStrategy enum value.

get_source_indices(available_sources: List[str]) → List[int]

Resolve source specification to indices.

Parameters:

available_sources – List of available source names.

Returns:

List of source indices to include.

Raises:

ValueError – If a specified source is not found.
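The sketch below resolves a `sources` specification (“all”, indices, or names) into indices and raises ValueError for unknown entries, as documented. `resolve_source_indices` is hypothetical, and the range check on integer indices is an assumption.

```python
def resolve_source_indices(sources, available_sources):
    """Resolve 'all', indices, or names to a list of source indices."""
    if sources == "all":
        return list(range(len(available_sources)))
    indices = []
    for item in sources:
        if isinstance(item, int):
            # Assumption: out-of-range indices are treated as "not found".
            if not 0 <= item < len(available_sources):
                raise ValueError(f"Source index out of range: {item}")
            indices.append(item)
        elif item in available_sources:
            indices.append(available_sources.index(item))
        else:
            raise ValueError(f"Source not found: {item!r}")
    return indices

names = ["NIR", "MIR", "markers"]
print(resolve_source_indices("all", names))           # [0, 1, 2]
print(resolve_source_indices(["markers", 0], names))  # [2, 0]
```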

get_strategy() → SourceMergeStrategy

Get the merge strategy as an enum.

Returns:

SourceMergeStrategy enum value.

on_incompatible: str = 'error'
output_name: str = 'merged'
preserve_source_info: bool = True
sources: str | List[int | str] = 'all'
strategy: str = 'concat'
to_dict() → Dict[str, Any]

Serialize configuration to dictionary.

Returns:

Dictionary representation for manifest storage.

class nirs4all.operators.data.SourceMergeStrategy(value)

Bases: Enum

How to combine features from multiple data sources.

Used by the merge_sources keyword to control how multi-source datasets are unified into a single feature space.

CONCAT

Horizontal concatenation of all source features (default). Results in a 2D array: (samples, sum_of_all_source_features). Differing feature dimensions per source are expected.

STACK

Stack sources along a new axis to create a 3D tensor. Results in a 3D array: (samples, n_sources, n_features). Requires all sources to have the same feature dimension.

DICT

Keep sources as a structured dictionary. Results in Dict[str, ndarray] for multi-input models. Each source is accessible by name.

CONCAT = 'concat'
DICT = 'dict'
STACK = 'stack'
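The three result shapes can be illustrated with two equally sized NumPy arrays. The source names and sizes are made up. With unequal feature counts, STACK would additionally need an on_incompatible strategy.

```python
import numpy as np

sources = {"NIR": np.ones((8, 50)), "MIR": np.ones((8, 50))}

concat = np.hstack(list(sources.values()))        # CONCAT: (samples, sum of features)
stack = np.stack(list(sources.values()), axis=1)  # STACK: (samples, n_sources, n_features)
as_dict = dict(sources)                           # DICT: name -> array

print(concat.shape)          # (8, 100)
print(stack.shape)           # (8, 2, 50)
print(as_dict["MIR"].shape)  # (8, 50)
```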
class nirs4all.operators.data.UnequelRepsStrategy(value)

Bases: Enum

Strategy for handling samples with unequal repetition counts.

When samples have different numbers of repetitions, this controls how the transformation handles the mismatch.

ERROR

Raise an error if repetition counts differ (default, strictest).

PAD

Pad shorter groups with NaN/zeros to match the longest.

DROP

Drop samples that don’t have the expected repetition count.

TRUNCATE

Truncate all groups to the minimum repetition count.

DROP = 'drop'
ERROR = 'error'
PAD = 'pad'
TRUNCATE = 'truncate'
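The sketch below applies the four strategies to a {sample_id: [spectra]} mapping, following the rules documented for on_unequal and expected_reps:

  • ERROR rejects any group whose count differs from the expected count.
  • DROP keeps only samples whose count equals the expected count. When expected_reps is None, that count is the mode of the group sizes.
  • TRUNCATE cuts every group to the minimum count.
  • PAD fills shorter groups with NaN spectra up to the maximum count.

`apply_unequal_strategy` is hypothetical, and padding with NaN (rather than zeros) follows the on_unequal description.

```python
from collections import Counter

import numpy as np

def apply_unequal_strategy(groups, strategy, expected=None):
    """Reconcile per-sample repetition lists (hypothetical sketch).

    groups:   {sample_id: list of 1D spectra} with possibly varying lengths
    expected: expected repetition count; inferred as the mode if None
    """
    sizes = {sid: len(reps) for sid, reps in groups.items()}
    if expected is None:
        # Documented default: mode of the group sizes.
        expected = Counter(sizes.values()).most_common(1)[0][0]
    if strategy == "error":
        if any(n != expected for n in sizes.values()):
            raise ValueError(f"Unequal repetition counts: {sizes}")
        return groups
    if strategy == "drop":
        return {sid: reps for sid, reps in groups.items() if sizes[sid] == expected}
    if strategy == "truncate":
        n = min(sizes.values())
        return {sid: reps[:n] for sid, reps in groups.items()}
    if strategy == "pad":
        n = max(sizes.values())
        n_features = len(next(iter(groups.values()))[0])
        return {
            sid: reps + [np.full(n_features, np.nan) for _ in range(n - len(reps))]
            for sid, reps in groups.items()
        }
    raise ValueError(f"Unknown strategy: {strategy}")

rng = np.random.default_rng(0)
groups = {
    "s1": [rng.random(3) for _ in range(4)],
    "s2": [rng.random(3) for _ in range(4)],
    "s3": [rng.random(3) for _ in range(3)],
}
print({k: len(v) for k, v in apply_unequal_strategy(groups, "drop").items()})      # {'s1': 4, 's2': 4}
print({k: len(v) for k, v in apply_unequal_strategy(groups, "truncate").items()})  # {'s1': 3, 's2': 3, 's3': 3}
```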