"""Merge operator configuration for branch and source merging.

This module provides configuration dataclasses and enums for the MergeController,
which handles combining branch outputs (features and/or predictions) and
exiting branch mode.

The merge operator is the core primitive for all branch combination operations.
It provides:

- Feature merging from branches (horizontal concatenation)
- Prediction merging with OOF reconstruction (data leakage prevention)
- Per-branch model selection and aggregation strategies
- Mixed merging (features from some branches, predictions from others)

Example:
    >>> # Simple feature merge
    >>> {"merge": "features"}
    >>>
    >>> # Prediction merge with OOF safety
    >>> {"merge": "predictions"}
    >>>
    >>> # Mixed merge with per-branch control
    >>> {"merge": {
    ...     "predictions": [
    ...         {"branch": 0, "select": "best", "metric": "rmse"},
    ...         {"branch": 1, "aggregate": "mean"}
    ...     ],
    ...     "features": [2]
    ... }}
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union
import warnings
class MergeMode(Enum):
    """Which kind of branch output a merge step collects.

    Attributes:
        FEATURES: Feature matrices produced by the branches.
        PREDICTIONS: Model predictions produced by the branches
            (combined with OOF reconstruction).
        ALL: Both features and predictions, from all branches.
    """

    FEATURES = "features"  # feature matrices only
    PREDICTIONS = "predictions"  # model predictions only
    ALL = "all"  # features and predictions together
class BranchType(Enum):
    """Classification of a branch by how it distributes samples.

    Attributes:
        COPY: Every branch receives every sample (the default branching mode).
        METADATA_PARTITIONER: Samples are split across branches by a metadata column.
        SAMPLE_PARTITIONER: Samples are split across branches by a filter
            (for example, an outlier detector).
    """

    COPY = "copy"
    METADATA_PARTITIONER = "metadata_partitioner"
    SAMPLE_PARTITIONER = "sample_partitioner"
class DisjointSelectionCriterion(Enum):
    """Ranking rule used to keep the top-N models per branch in a disjoint merge.

    Disjoint branches may hold different numbers of models; this criterion
    decides which N models survive from each branch.

    Attributes:
        MSE: Rank by lowest Mean Squared Error (the regression default).
        RMSE: Rank by lowest Root Mean Squared Error.
        MAE: Rank by lowest Mean Absolute Error.
        R2: Rank by highest R² score.
        ORDER: No ranking; keep the first N in definition order.
    """

    MSE = "mse"
    RMSE = "rmse"
    MAE = "mae"
    R2 = "r2"
    ORDER = "order"
@dataclass
class DisjointBranchInfo:
    """Per-branch record captured while performing a disjoint merge.

    Stores the sample partition and the model-selection outcome of one
    branch so that the merge metadata can report them.

    Attributes:
        n_samples: Count of samples in this branch's partition.
        sample_ids: Indices of the samples that belong to this branch.
        n_models_original: Number of models the branch had before selection.
        n_models_selected: Number of models kept for the merge.
        selected_models: Details (name, score, column) of each kept model.
        dropped_models: Details (name, score) of each discarded model.
    """

    n_samples: int
    sample_ids: List[int]
    n_models_original: int = 0
    n_models_selected: int = 0
    selected_models: List[Dict[str, Any]] = field(default_factory=list)
    dropped_models: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return this record as a plain dictionary for serialization.

        Keys follow field declaration order; values are the instance's own
        attribute objects (they are not copied).
        """
        field_names = (
            "n_samples",
            "sample_ids",
            "n_models_original",
            "n_models_selected",
            "selected_models",
            "dropped_models",
        )
        return {name: getattr(self, name) for name in field_names}
class SelectionStrategy(Enum):
    """Rule for picking which models of a branch feed a prediction merge.

    Only relevant when a branch holds more than one model.

    Attributes:
        ALL: Keep every model in the branch (default).
        BEST: Keep the single best model according to the chosen metric.
        TOP_K: Keep the K best models according to the chosen metric.
        EXPLICIT: Keep the models named in an explicit list.
    """

    ALL = "all"
    BEST = "best"
    TOP_K = "top_k"
    EXPLICIT = "explicit"
class AggregationStrategy(Enum):
    """Rule for combining the predictions of the models selected in a branch.

    Applied after model selection; determines how many output features the
    branch contributes to the merged result.

    Attributes:
        SEPARATE: One feature per selected model (default) -> N features.
        MEAN: Plain average of the selected predictions -> 1 feature.
        WEIGHTED_MEAN: Average weighted by validation score -> 1 feature.
        PROBA_MEAN: Average of class probabilities, classification only
            -> K features (one per class).
    """

    SEPARATE = "separate"
    MEAN = "mean"
    WEIGHTED_MEAN = "weighted_mean"
    PROBA_MEAN = "proba_mean"
class ShapeMismatchStrategy(Enum):
    """Policy for shape mismatches when merging features in 3D layout.

    The policy only matters for the 3D layout, where the number of
    processings has to line up across branches. In the default 2D layout,
    features are flattened and concatenated horizontally, so branches with
    different feature dimensions are expected and normal.

    Example:
        - Branch 0: (200 samples, 500 features) from MinMaxScaler
        - Branch 1: (200 samples, 4 processings, 20 features) from multi-processing

        2D layout: concatenated to (200, 500 + 4*20 = 580), no error.
        3D layout: an alignment policy is needed because processings differ.

    Attributes:
        ERROR: Raise on any mismatch (default, strictest).
        ALLOW: Flatten to 2D and concatenate regardless of differences.
        PAD: Zero-pad shorter branches up to the longest processing count.
        TRUNCATE: Cut longer branches down to the shortest processing count.
    """

    ERROR = "error"
    ALLOW = "allow"
    PAD = "pad"
    TRUNCATE = "truncate"
class SourceMergeStrategy(Enum):
    """Way to unify the features of a multi-source dataset.

    Selected through the `merge_sources` keyword.

    Attributes:
        CONCAT: Concatenate all sources horizontally (default).
            Produces a 2D array (samples, sum_of_all_source_features);
            sources may have different feature dimensions.
        STACK: Stack sources on a new axis.
            Produces a 3D array (samples, n_sources, n_features);
            every source must share the same feature dimension.
        DICT: Keep the sources in a dictionary keyed by name.
            Produces Dict[str, ndarray], intended for multi-input models.
    """

    CONCAT = "concat"
    STACK = "stack"
    DICT = "dict"
class SourceIncompatibleStrategy(Enum):
    """Policy for sources whose shapes cannot be stacked directly.

    Consulted by the `stack` source-merge strategy when sources differ in
    feature dimension or processing count.

    Attributes:
        ERROR: Raise on incompatible shapes (default, strictest).
        FLATTEN: Fall back to 2D concatenation instead of stacking.
        PAD: Zero-pad shorter sources up to the longest.
        TRUNCATE: Cut longer sources down to the shortest.
    """

    ERROR = "error"
    FLATTEN = "flatten"
    PAD = "pad"
    TRUNCATE = "truncate"
@dataclass
class BranchPredictionConfig:
    """Configuration for prediction collection from a single branch.

    This dataclass specifies how to collect and process predictions
    from a specific branch during merge operations.

    Attributes:
        branch: Branch index or name to collect from.
        select: Model selection strategy.
            - "all" (default): All models in branch
            - "best": Single best model by metric
            - {"top_k": N}: Top N models by metric (N must be a positive int)
            - ["ModelA", "ModelB"]: Explicit, non-empty list of model names
        metric: Metric for selection. One of rmse, mae, r2, mse, accuracy,
            f1, auc, log_loss. ``None`` (default) leaves the choice to the
            consumer (intended to be task-appropriate: rmse for regression,
            accuracy for classification).
        aggregate: How to combine predictions from selected models.
            - "separate" (default): Each model is a separate feature
            - "mean": Simple average of predictions
            - "weighted_mean": Weight by validation score
            - "proba_mean": Average class probabilities (classification)
        weight_metric: Metric for weighted aggregation (default: same as
            `metric`). Accepts the same values as `metric`.
        proba: Use class probabilities instead of predictions (classification
            only). Forced to True, with a UserWarning, when
            aggregate="proba_mean".
        sources: Source filter for multi-source datasets.
            - "all" (default): Include all sources
            - List of source indices or names

    Raises:
        ValueError: If `aggregate`, `select`, `metric` or `weight_metric`
            holds an unsupported value.

    Example:
        >>> # Best model from branch 0 by RMSE
        >>> BranchPredictionConfig(branch=0, select="best", metric="rmse")
        >>>
        >>> # Top 3 models from branch 1, averaged
        >>> BranchPredictionConfig(
        ...     branch=1,
        ...     select={"top_k": 3},
        ...     metric="r2",
        ...     aggregate="mean"
        ... )
        >>>
        >>> # Explicit models with weighted average
        >>> BranchPredictionConfig(
        ...     branch="spectral_path",
        ...     select=["PLS", "RF"],
        ...     aggregate="weighted_mean",
        ...     weight_metric="r2"
        ... )
    """

    branch: Union[int, str]
    select: Union[str, Dict[str, Any], List[str]] = "all"
    metric: Optional[str] = None
    aggregate: str = "separate"
    weight_metric: Optional[str] = None
    proba: bool = False
    sources: Union[str, List[Union[int, str]]] = "all"

    def __post_init__(self):
        """Validate configuration after initialization.

        Raises:
            ValueError: On an unsupported `aggregate`, `select`, `metric`
                or `weight_metric` value.
        """
        # Validate aggregate
        valid_aggregates = ("separate", "mean", "weighted_mean", "proba_mean")
        if self.aggregate not in valid_aggregates:
            raise ValueError(
                f"aggregate must be one of {valid_aggregates}, got '{self.aggregate}'"
            )

        # Validate select format
        if isinstance(self.select, dict):
            if "top_k" not in self.select:
                raise ValueError(
                    "dict select must contain 'top_k' key, "
                    f"got keys: {list(self.select.keys())}"
                )
            top_k = self.select["top_k"]
            # bool is a subclass of int, so reject it explicitly: {"top_k": True}
            # would otherwise silently mean "top 1".
            if isinstance(top_k, bool) or not isinstance(top_k, int) or top_k < 1:
                raise ValueError(
                    f"top_k must be a positive integer, got {top_k}"
                )
        elif isinstance(self.select, str):
            if self.select not in ("all", "best"):
                raise ValueError(
                    f"string select must be 'all' or 'best', got '{self.select}'"
                )
        elif isinstance(self.select, list):
            if not all(isinstance(s, str) for s in self.select):
                raise ValueError(
                    "list select must contain only string model names"
                )
            if len(self.select) == 0:
                raise ValueError("list select cannot be empty")

        # Validate metric names if provided
        valid_metrics = ("rmse", "mae", "r2", "mse", "accuracy", "f1", "auc", "log_loss")
        if self.metric is not None and self.metric not in valid_metrics:
            raise ValueError(
                f"metric must be one of {valid_metrics}, got '{self.metric}'"
            )
        if self.weight_metric is not None and self.weight_metric not in valid_metrics:
            raise ValueError(
                f"weight_metric must be one of {valid_metrics}, got '{self.weight_metric}'"
            )

        # proba_mean only makes sense on probabilities, so force proba=True.
        if self.aggregate == "proba_mean" and not self.proba:
            warnings.warn(
                "aggregate='proba_mean' requires proba=True. Setting proba=True automatically.",
                UserWarning,
                # 1 = __post_init__, 2 = dataclass-generated __init__,
                # 3 = the user's constructor call (the useful location).
                stacklevel=3
            )
            # The dataclass is not frozen, so plain assignment is sufficient.
            self.proba = True

    def get_selection_strategy(self) -> SelectionStrategy:
        """Get the selection strategy enum for this configuration.

        Returns:
            SelectionStrategy enum value based on the `select` field; falls
            back to SelectionStrategy.ALL for any unrecognized form.
        """
        if isinstance(self.select, str):
            if self.select == "all":
                return SelectionStrategy.ALL
            elif self.select == "best":
                return SelectionStrategy.BEST
        elif isinstance(self.select, dict):
            return SelectionStrategy.TOP_K
        elif isinstance(self.select, list):
            return SelectionStrategy.EXPLICIT
        return SelectionStrategy.ALL

    def get_aggregation_strategy(self) -> AggregationStrategy:
        """Get the aggregation strategy enum for this configuration.

        Returns:
            AggregationStrategy enum value based on the `aggregate` field.
        """
        return AggregationStrategy(self.aggregate)
@dataclass
class MergeConfig:
    """Configuration for branch merging operations.

    This dataclass provides complete configuration for the MergeController,
    controlling what data is collected from branches and how it is combined.

    Attributes:
        collect_features: Whether to collect features from branches.
        feature_branches: Which branches to collect features from.
            - "all" (default): All branches
            - List of branch indices: [0, 2] for specific branches
        collect_predictions: Whether to collect predictions from branches.
        prediction_branches: Legacy simple mode: which branches for predictions.
            Use `prediction_configs` for advanced per-branch control.
        prediction_configs: Advanced per-branch prediction configuration.
            Takes precedence over prediction_branches when set.
        model_filter: Legacy: global model filter (simple mode).
            List of model names to include.
        use_proba: Legacy: global proba setting for classification.
        include_original: Include pre-branch features in merged output.
            When True, original features are prepended to merged features.
        on_missing: How to handle missing branches or predictions.
            - "error" (default): Raise an error
            - "warn": Log warning and skip
            - "skip": Silent skip
        on_shape_mismatch: Reserved for 3D layout feature merging.
            In 2D layout (default), features are flattened and concatenated
            horizontally, so different feature dimensions are normal and this
            parameter has no effect. For future 3D layout support:
            - "error": Raise error if processings differ
            - "allow": Flatten to 2D and concatenate
            - "pad": Pad shorter processings with zeros
            - "truncate": Truncate longer to match shortest
        unsafe: If True, DISABLE OOF reconstruction for predictions.
            ⚠️ CAUSES DATA LEAKAGE - only for rapid prototyping.
        output_as: Where to put merged output.
            - "features" (default): Single concatenated feature matrix
            - "sources": Each branch becomes a separate source
            - "dict": Keep as structured dict for multi-head models
        source_names: Custom names for output sources (when output_as="sources").
            If not provided, uses "branch_0", "branch_1", etc.
        n_columns: Disjoint branch merge: force the number of output columns
            for the prediction merge. Must be >= 1 when set (default: None).
        select_by: Disjoint branch merge: criterion for picking the top-N
            models per branch. One of "mse" (default), "rmse", "mae", "r2",
            "order".

    Raises:
        ValueError: From __post_init__ when `on_missing`, `on_shape_mismatch`,
            `output_as`, `n_columns` or `select_by` holds an unsupported value.

    Example:
        >>> # Simple feature merge
        >>> MergeConfig(collect_features=True)
        >>>
        >>> # Prediction merge with OOF
        >>> MergeConfig(collect_predictions=True)
        >>>
        >>> # Mixed merge with per-branch control
        >>> MergeConfig(
        ...     collect_predictions=True,
        ...     prediction_configs=[
        ...         BranchPredictionConfig(branch=0, select="best"),
        ...         BranchPredictionConfig(branch=1, aggregate="mean")
        ...     ],
        ...     collect_features=True,
        ...     feature_branches=[2]
        ... )
        >>>
        >>> # Unsafe mode (with warning)
        >>> MergeConfig(collect_predictions=True, unsafe=True)
        >>>
        >>> # Disjoint branch merge with n_columns override
        >>> MergeConfig(
        ...     collect_predictions=True,
        ...     n_columns=2,
        ...     select_by="mse"
        ... )
    """

    collect_features: bool = False
    feature_branches: Union[str, List[int]] = "all"
    collect_predictions: bool = False
    prediction_branches: Union[str, List[int]] = "all"
    prediction_configs: Optional[List[BranchPredictionConfig]] = None
    model_filter: Optional[List[str]] = None
    use_proba: bool = False
    include_original: bool = False
    on_missing: str = "error"
    on_shape_mismatch: str = "error"
    unsafe: bool = False
    output_as: str = "features"  # Default to "features" for backward compatibility
    source_names: Optional[List[str]] = None
    # Disjoint sample branch merge options (Phase 2)
    n_columns: Optional[int] = None  # Force output column count for disjoint prediction merge
    select_by: str = "mse"  # Criterion for selecting top-N models (mse, rmse, mae, r2, order)

    def __post_init__(self):
        """Validate configuration after initialization.

        Emits UserWarnings for unsafe prediction merging and for
        `source_names` that will be ignored.

        Raises:
            ValueError: On any unsupported option value.
        """
        # Validate on_missing
        valid_on_missing = ("error", "warn", "skip")
        if self.on_missing not in valid_on_missing:
            raise ValueError(
                f"on_missing must be one of {valid_on_missing}, got '{self.on_missing}'"
            )

        # Validate on_shape_mismatch
        valid_shape_strategies = ("error", "allow", "pad", "truncate")
        if self.on_shape_mismatch not in valid_shape_strategies:
            raise ValueError(
                f"on_shape_mismatch must be one of {valid_shape_strategies}, "
                f"got '{self.on_shape_mismatch}'"
            )

        # Validate output_as
        valid_output_as = ("features", "sources", "dict")
        if self.output_as not in valid_output_as:
            raise ValueError(
                f"output_as must be one of {valid_output_as}, got '{self.output_as}'"
            )

        # Warnings below use stacklevel=3: 1 = __post_init__,
        # 2 = dataclass-generated __init__, 3 = the caller constructing the config.

        # Warn about unsafe usage
        if self.unsafe and self.collect_predictions:
            warnings.warn(
                "⚠️ MergeConfig: unsafe=True disables OOF reconstruction. "
                "Training predictions will be used directly, causing DATA LEAKAGE. "
                "Do NOT use for final model evaluation.",
                UserWarning,
                stacklevel=3
            )

        # source_names only has meaning when branches become sources
        if self.source_names is not None and self.output_as != "sources":
            warnings.warn(
                "source_names is only used when output_as='sources'. "
                "It will be ignored with current output_as setting.",
                UserWarning,
                stacklevel=3
            )

        # Validate n_columns
        if self.n_columns is not None and self.n_columns < 1:
            raise ValueError(
                f"n_columns must be >= 1, got {self.n_columns}"
            )

        # Validate select_by
        valid_select_by = ("mse", "rmse", "mae", "r2", "order")
        if self.select_by not in valid_select_by:
            raise ValueError(
                f"select_by must be one of {valid_select_by}, got '{self.select_by}'"
            )

    def get_selection_criterion(self) -> "DisjointSelectionCriterion":
        """Get the selection criterion enum for disjoint branch merging.

        Returns:
            DisjointSelectionCriterion enum value.
        """
        return DisjointSelectionCriterion(self.select_by)

    def has_per_branch_config(self) -> bool:
        """Check if using advanced per-branch prediction configuration.

        Returns:
            True if prediction_configs is set and non-empty.
        """
        return self.prediction_configs is not None and len(self.prediction_configs) > 0

    def get_prediction_configs(
        self,
        n_branches: int
    ) -> List[BranchPredictionConfig]:
        """Get prediction configurations, normalizing legacy format if needed.

        Converts legacy simple mode (prediction_branches + model_filter + use_proba)
        to per-branch configurations for uniform processing.

        Args:
            n_branches: Total number of branches available.

        Returns:
            List of BranchPredictionConfig for each branch to collect from.
            When `prediction_configs` is set, that list itself is returned.
        """
        # If advanced config is set, use it directly
        if self.has_per_branch_config():
            return self.prediction_configs

        # Convert legacy format: resolve branch indices first
        if self.prediction_branches == "all":
            branch_indices = list(range(n_branches))
        else:
            branch_indices = self.prediction_branches

        # Each config gets its own copy of model_filter so they never share
        # (and cannot mutate) one list.
        return [
            BranchPredictionConfig(
                branch=branch_idx,
                select=list(self.model_filter) if self.model_filter else "all",
                proba=self.use_proba,
                aggregate="separate",
            )
            for branch_idx in branch_indices
        ]

    def get_feature_branches(self, n_branches: int) -> List[int]:
        """Get list of branch indices to collect features from.

        Args:
            n_branches: Total number of branches available.

        Returns:
            List of branch indices.
        """
        if self.feature_branches == "all":
            return list(range(n_branches))
        return list(self.feature_branches)

    def get_merge_mode(self) -> MergeMode:
        """Determine the merge mode based on configuration.

        Returns:
            MergeMode enum value.

        Raises:
            ValueError: If neither features nor predictions are collected.
        """
        if self.collect_features and self.collect_predictions:
            return MergeMode.ALL
        elif self.collect_features:
            return MergeMode.FEATURES
        elif self.collect_predictions:
            return MergeMode.PREDICTIONS
        else:
            raise ValueError(
                "Invalid MergeConfig: neither collect_features nor collect_predictions is True"
            )

    def get_shape_mismatch_strategy(self) -> ShapeMismatchStrategy:
        """Get the shape mismatch strategy enum.

        Returns:
            ShapeMismatchStrategy enum value.
        """
        return ShapeMismatchStrategy(self.on_shape_mismatch)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize merge configuration to a dictionary.

        Used for saving merge configuration to manifest for reproducibility
        in prediction mode and bundle export. Optional fields are only
        written when they differ from their defaults / are non-empty.

        Returns:
            Dictionary representation suitable for YAML/JSON serialization.
        """
        result = {
            "collect_features": self.collect_features,
            "collect_predictions": self.collect_predictions,
            "include_original": self.include_original,
            "on_missing": self.on_missing,
            "on_shape_mismatch": self.on_shape_mismatch,
            "unsafe": self.unsafe,
            "output_as": self.output_as,
        }
        if self.feature_branches != "all":
            result["feature_branches"] = self.feature_branches
        if self.prediction_branches != "all":
            result["prediction_branches"] = self.prediction_branches
        if self.prediction_configs:
            result["prediction_configs"] = [
                {
                    "branch": pc.branch,
                    "select": pc.select,
                    "metric": pc.metric,
                    "aggregate": pc.aggregate,
                    "weight_metric": pc.weight_metric,
                    "proba": pc.proba,
                    "sources": pc.sources,
                }
                for pc in self.prediction_configs
            ]
        if self.model_filter:
            result["model_filter"] = self.model_filter
        if self.use_proba:
            result["use_proba"] = self.use_proba
        if self.source_names:
            result["source_names"] = self.source_names
        # Disjoint branch merge options
        if self.n_columns is not None:
            result["n_columns"] = self.n_columns
        if self.select_by != "mse":  # Only serialize non-default
            result["select_by"] = self.select_by
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MergeConfig":
        """Create MergeConfig from a dictionary.

        Used for loading merge configuration from manifest in prediction mode.
        Missing keys fall back to the same defaults as the constructor.

        Args:
            data: Dictionary representation of merge configuration.

        Returns:
            MergeConfig instance.
        """
        # Tolerate an explicit None as well as a missing key.
        raw_configs = data.get("prediction_configs")
        prediction_configs = None
        if raw_configs is not None:
            prediction_configs = [
                BranchPredictionConfig(
                    branch=pc["branch"],
                    select=pc.get("select", "all"),
                    metric=pc.get("metric"),
                    aggregate=pc.get("aggregate", "separate"),
                    weight_metric=pc.get("weight_metric"),
                    proba=pc.get("proba", False),
                    sources=pc.get("sources", "all"),
                )
                for pc in raw_configs
            ]

        return cls(
            collect_features=data.get("collect_features", False),
            feature_branches=data.get("feature_branches", "all"),
            collect_predictions=data.get("collect_predictions", False),
            prediction_branches=data.get("prediction_branches", "all"),
            prediction_configs=prediction_configs,
            model_filter=data.get("model_filter"),
            use_proba=data.get("use_proba", False),
            include_original=data.get("include_original", False),
            on_missing=data.get("on_missing", "error"),
            on_shape_mismatch=data.get("on_shape_mismatch", "error"),
            unsafe=data.get("unsafe", False),
            # Must match the field default ("features"); "sources" here made a
            # dict without output_as load differently from MergeConfig().
            output_as=data.get("output_as", "features"),
            source_names=data.get("source_names"),
            n_columns=data.get("n_columns"),
            select_by=data.get("select_by", "mse"),
        )
@dataclass
class SourceMergeConfig:
    """Configuration for merging multi-source dataset features.

    This dataclass provides configuration for the `merge_sources` keyword,
    which combines features from multiple data sources (e.g., NIR, markers,
    Raman) into a unified feature space.

    Unlike branch merging (`merge`), source merging operates on the data
    provenance dimension—combining features that originated from different
    sensors, instruments, or data modalities.

    Attributes:
        strategy: How to combine source features.
            - "concat" (default): Horizontal concatenation (2D result)
            - "stack": Stack along new axis (3D result, requires uniform shapes)
            - "dict": Keep as structured dictionary (for multi-input models)
        sources: Which sources to include.
            - "all" (default): Include all available sources
            - List of source indices: [0, 1] for specific sources
            - List of source names: ["NIR", "markers"] for named sources
            Any other string is rejected; wrap a single name in a list.
        on_incompatible: How to handle incompatible shapes (for stack strategy).
            - "error" (default): Raise error if shapes don't match
            - "flatten": Fall back to 2D concat
            - "pad": Pad shorter with zeros
            - "truncate": Truncate longer to match shortest
        output_name: Name for the merged output source (default: "merged").
        preserve_source_info: Whether to store source metadata for debugging.

    Raises:
        ValueError: From __post_init__ on an unsupported `strategy`,
            `on_incompatible` or `sources` value.

    Example:
        >>> # Simple concatenation (default)
        >>> {"merge_sources": "concat"}
        >>>
        >>> # Stack for 3D models (requires same feature count per source)
        >>> {"merge_sources": {"strategy": "stack"}}
        >>>
        >>> # Selective sources with fallback on shape mismatch
        >>> {"merge_sources": {
        ...     "strategy": "stack",
        ...     "sources": ["NIR", "MIR"],
        ...     "on_incompatible": "flatten"
        ... }}
        >>>
        >>> # Dict output for multi-head models
        >>> {"merge_sources": {"strategy": "dict"}}
    """

    strategy: str = "concat"
    sources: Union[str, List[Union[int, str]]] = "all"
    on_incompatible: str = "error"
    output_name: str = "merged"
    preserve_source_info: bool = True

    def __post_init__(self):
        """Validate configuration after initialization.

        Raises:
            ValueError: On an unsupported strategy, incompatibility policy,
                string `sources` other than "all", or an empty sources list.
        """
        # Validate strategy
        valid_strategies = ("concat", "stack", "dict")
        if self.strategy not in valid_strategies:
            raise ValueError(
                f"strategy must be one of {valid_strategies}, got '{self.strategy}'"
            )

        # Validate on_incompatible
        valid_incompatible = ("error", "flatten", "pad", "truncate")
        if self.on_incompatible not in valid_incompatible:
            raise ValueError(
                f"on_incompatible must be one of {valid_incompatible}, "
                f"got '{self.on_incompatible}'"
            )

        # Validate sources. A bare string other than "all" would otherwise be
        # iterated character by character in get_source_indices.
        if isinstance(self.sources, str):
            if self.sources != "all":
                raise ValueError(
                    f"string sources must be 'all', got '{self.sources}'. "
                    "Use a list (e.g. ['NIR']) to select specific sources."
                )
        elif isinstance(self.sources, list):
            if len(self.sources) == 0:
                raise ValueError("sources list cannot be empty")

    def get_strategy(self) -> SourceMergeStrategy:
        """Get the merge strategy as an enum.

        Returns:
            SourceMergeStrategy enum value.
        """
        return SourceMergeStrategy(self.strategy)

    def get_incompatible_strategy(self) -> SourceIncompatibleStrategy:
        """Get the incompatible handling strategy as an enum.

        Returns:
            SourceIncompatibleStrategy enum value.
        """
        return SourceIncompatibleStrategy(self.on_incompatible)

    def get_source_indices(self, available_sources: List[str]) -> List[int]:
        """Resolve source specification to indices.

        Args:
            available_sources: List of available source names.

        Returns:
            List of source indices to include, in the order specified.

        Raises:
            ValueError: If a specified source is not found, an index is out of
                range, or an entry is neither an int nor a str (bools are
                rejected even though bool subclasses int).
        """
        if self.sources == "all":
            return list(range(len(available_sources)))

        n_available = len(available_sources)
        indices = []
        for source in self.sources:
            # Exclude bool: True/False must not silently act as indices 1/0.
            if isinstance(source, int) and not isinstance(source, bool):
                if not 0 <= source < n_available:
                    # Avoid the misleading "0--1" range when nothing is available.
                    available = f"0-{n_available - 1}" if n_available else "none"
                    raise ValueError(
                        f"Source index {source} out of range. "
                        f"Available: {available}. "
                        f"[Error: MERGE-E031]"
                    )
                indices.append(source)
            elif isinstance(source, str):
                if source not in available_sources:
                    raise ValueError(
                        f"Source name '{source}' not found. "
                        f"Available: {available_sources}. "
                        f"[Error: MERGE-E031]"
                    )
                indices.append(available_sources.index(source))
            else:
                raise ValueError(
                    f"Source must be int or str, got {type(source).__name__}"
                )
        return indices

    def to_dict(self) -> Dict[str, Any]:
        """Serialize configuration to dictionary.

        Returns:
            Dictionary representation for manifest storage; `sources` is only
            included when it is not "all".
        """
        result = {
            "strategy": self.strategy,
            "on_incompatible": self.on_incompatible,
            "output_name": self.output_name,
            "preserve_source_info": self.preserve_source_info,
        }
        if self.sources != "all":
            result["sources"] = self.sources
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SourceMergeConfig":
        """Create config from dictionary.

        Args:
            data: Dictionary representation.

        Returns:
            SourceMergeConfig instance (missing keys use constructor defaults).
        """
        return cls(
            strategy=data.get("strategy", "concat"),
            sources=data.get("sources", "all"),
            on_incompatible=data.get("on_incompatible", "error"),
            output_name=data.get("output_name", "merged"),
            preserve_source_info=data.get("preserve_source_info", True),
        )
@dataclass
class SourceBranchConfig:
    """Configuration for source branching operations.

    This dataclass provides configuration for the `source_branch` keyword,
    which creates per-source pipeline execution paths. Each source in a
    multi-source dataset gets its own independent processing pipeline.

    Unlike regular branching (`branch`), which creates parallel paths that
    all process the same data, source branching assigns each source to a
    specific processing pipeline based on its name or index.

    Attributes:
        source_pipelines: Mapping of source names/indices to their pipeline steps.
            - Dict[str, List]: Named sources to steps mapping
            - Dict[int, List]: Source indices to steps mapping
            - "auto": Apply same steps to all sources independently
            The special dict keys "_default_", "_merge_after_" and
            "_merge_strategy_" are moved into the corresponding attributes;
            the caller's dict is left untouched.
        default_pipeline: Default pipeline for sources not explicitly specified.
            Applied when a source is not listed in source_pipelines.
            If None, unspecified sources are passed through unchanged.
        merge_after: Whether to automatically merge sources after branching.
            - True (default): Automatically call merge_sources after
            - False: Keep sources separate (user must merge manually)
        merge_strategy: Strategy for auto-merge (when merge_after=True).
            - "concat" (default): Horizontal concatenation
            - "stack": Stack along source axis
            - "dict": Keep as dictionary

    Raises:
        ValueError: From __post_init__ on an invalid merge strategy (including
            one given via "_merge_strategy_"), a string other than "auto", or
            a dict key that is neither str nor int.

    Example:
        >>> # Different preprocessing per source
        >>> {"source_branch": {
        ...     "NIR": [SNV(), SavitzkyGolay()],
        ...     "markers": [VarianceThreshold(), MinMaxScaler()],
        ...     "Raman": [BaselineCorrection(), StandardScaler()]
        ... }}
        >>>
        >>> # Source branching with default fallback
        >>> {"source_branch": {
        ...     "NIR": [SNV()],
        ...     "_default_": [MinMaxScaler()]  # Applied to other sources
        ... }}
        >>>
        >>> # Automatic same-preprocessing per source (isolates sources)
        >>> {"source_branch": "auto"}
        >>>
        >>> # Source branching without auto-merge
        >>> {"source_branch": {
        ...     "NIR": [SNV()],
        ...     "markers": [StandardScaler()],
        ...     "_merge_after_": False  # Disable auto-merge
        ... }}
    """

    source_pipelines: Union[str, Dict[Union[str, int], List[Any]]] = field(
        default_factory=dict
    )
    default_pipeline: Optional[List[Any]] = None
    merge_after: bool = True
    merge_strategy: str = "concat"

    def __post_init__(self):
        """Extract special keys and validate the configuration.

        Raises:
            ValueError: On an invalid source_pipelines string, a non str/int
                dict key, or an unsupported merge strategy.
        """
        if isinstance(self.source_pipelines, str):
            if self.source_pipelines != "auto":
                raise ValueError(
                    f"string source_pipelines must be 'auto', got '{self.source_pipelines}'"
                )
        elif isinstance(self.source_pipelines, dict):
            # Work on a shallow copy: popping the special keys from the
            # caller's dict would corrupt it for any later reuse.
            pipelines = dict(self.source_pipelines)
            if "_default_" in pipelines:
                self.default_pipeline = pipelines.pop("_default_")
            if "_merge_after_" in pipelines:
                self.merge_after = pipelines.pop("_merge_after_")
            if "_merge_strategy_" in pipelines:
                self.merge_strategy = pipelines.pop("_merge_strategy_")

            # Remaining keys must be valid source references
            for key in pipelines:
                if not isinstance(key, (str, int)):
                    raise ValueError(
                        f"source_pipelines keys must be str or int, got {type(key).__name__}"
                    )
            self.source_pipelines = pipelines

        # Validate merge_strategy only after special keys are extracted, so a
        # value supplied through "_merge_strategy_" is validated too.
        valid_strategies = ("concat", "stack", "dict")
        if self.merge_strategy not in valid_strategies:
            raise ValueError(
                f"merge_strategy must be one of {valid_strategies}, "
                f"got '{self.merge_strategy}'"
            )

    def is_auto_mode(self) -> bool:
        """Check if using automatic source branching.

        Returns:
            True if source_pipelines is "auto".
        """
        return self.source_pipelines == "auto"

    def get_pipeline_for_source(
        self,
        source_name: str,
        source_index: int
    ) -> Optional[List[Any]]:
        """Get pipeline steps for a specific source.

        Lookup order: source name, integer index, string form of the index,
        then `default_pipeline`.

        Args:
            source_name: Name of the source.
            source_index: Index of the source.

        Returns:
            List of pipeline steps for this source, or None if passthrough.
            In auto mode an empty list is returned for every source.
        """
        if self.is_auto_mode():
            # Auto mode: return empty list (passthrough with isolation)
            return []

        if isinstance(self.source_pipelines, dict):
            # Check by source name (e.g., "NIR", "markers")
            if source_name in self.source_pipelines:
                return self.source_pipelines[source_name]
            # Check by integer index (e.g., 0, 1, 2)
            if source_index in self.source_pipelines:
                return self.source_pipelines[source_index]
            # Check by string index (e.g., "0", "1", "2") - for list-indexed syntax
            str_index = str(source_index)
            if str_index in self.source_pipelines:
                return self.source_pipelines[str_index]

        # Fall back to default
        return self.default_pipeline

    def get_all_source_mappings(
        self,
        available_sources: List[str]
    ) -> Dict[str, List[Any]]:
        """Get pipeline mapping for all available sources.

        Args:
            available_sources: List of available source names.

        Returns:
            Dict mapping source names to their pipeline steps. Sources with no
            pipeline (an explicit None mapping with no default, or no entry at
            all) map to an empty list (passthrough).
        """
        result = {}
        if self.is_auto_mode():
            # Auto mode: each source gets empty pipeline (isolation only)
            for source in available_sources:
                result[source] = []
        elif isinstance(self.source_pipelines, dict):
            for idx, source in enumerate(available_sources):
                pipeline = self.get_pipeline_for_source(source, idx)
                if pipeline is not None:
                    result[source] = pipeline
                elif self.default_pipeline is not None:
                    # Reached when a source is explicitly mapped to None.
                    result[source] = self.default_pipeline
                else:
                    # No pipeline specified and no default: passthrough
                    result[source] = []
        return result

    def to_dict(self) -> Dict[str, Any]:
        """Serialize configuration to dictionary.

        Pipeline steps are not serialized; only "..." placeholders are
        written for each configured key and for the default pipeline.

        Returns:
            Dictionary representation for manifest storage.
        """
        result = {
            "merge_after": self.merge_after,
            "merge_strategy": self.merge_strategy,
        }
        if self.is_auto_mode():
            result["source_pipelines"] = "auto"
        else:
            # Serialize pipeline references (not the actual objects)
            result["source_pipelines"] = {
                str(k): "..." for k in self.source_pipelines.keys()
            }
        if self.default_pipeline is not None:
            result["default_pipeline"] = "..."
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SourceBranchConfig":
        """Create config from dictionary.

        Note: This is primarily for metadata reconstruction. The actual
        pipeline steps must be restored from the manifest/artifacts.

        Args:
            data: Dictionary representation.

        Returns:
            SourceBranchConfig instance (with placeholder pipelines).
        """
        # Only "auto" survives; per-source pipelines are placeholders to be
        # reconstructed elsewhere.
        if data.get("source_pipelines", {}) == "auto":
            source_pipelines = "auto"
        else:
            source_pipelines = {}

        return cls(
            source_pipelines=source_pipelines,
            default_pipeline=None,  # Must be reconstructed
            merge_after=data.get("merge_after", True),
            merge_strategy=data.get("merge_strategy", "concat"),
        )