"""
Branch Controller for pipeline branching.
This controller enables splitting a pipeline into multiple parallel sub-pipelines
("branches"), each with its own preprocessing context (X transformations, Y processing),
while sharing common upstream state (splits, initial preprocessing).
Steps declared after a branch block execute on each branch independently.
V3 improvements:
- Uses trace_recorder.enter_branch() / exit_branch() for automatic branch path tracking
- Records each branch substep individually in the execution trace
- Builds proper operator chains for artifact identification
Example:
    >>> pipeline = [
    ...     ShuffleSplit(n_splits=5),
    ...     {"branch": [
    ...         [SNV(), PCA(n_components=10)],
    ...         [MSC(), FirstDerivative()],
    ...     ]},
    ...     PLSRegression(n_components=5),  # Runs on BOTH branches
    ... ]

Generator syntax is also supported:
    >>> pipeline = [
    ...     ShuffleSplit(n_splits=3),
    ...     {"branch": {"_or_": [SNV(), MSC(), FirstDerivative()]}},  # 3 branches
    ...     PLSRegression(n_components=5),
    ... ]
"""
import copy
from typing import Any, Dict, List, Tuple, Optional, TYPE_CHECKING
from nirs4all.controllers.controller import OperatorController
from nirs4all.controllers.registry import register_controller
from nirs4all.core.logging import get_logger
from nirs4all.pipeline.config.generator import (
expand_spec,
is_generator_node,
)
from nirs4all.pipeline.execution.result import StepOutput
logger = get_logger(__name__)
if TYPE_CHECKING:
from nirs4all.data.dataset import SpectroDataset
from nirs4all.pipeline.config.context import ExecutionContext, RuntimeContext
from nirs4all.pipeline.steps.parser import ParsedStep
[docs]
@register_controller
class BranchController(OperatorController):
    """Controller for pipeline branching.

    Implements the branching mechanism that allows multiple preprocessing
    chains to be evaluated independently within a single pipeline execution.

    Key behaviors:
        - Creates independent context copies for each branch
        - Restores the dataset feature sources to the pre-branch snapshot
          before each branch starts
        - Executes branch steps sequentially within each branch
        - Stores branch contexts in context.custom["branch_contexts"] and
          sets context.custom["in_branch_mode"] to True
        - Post-branch steps iterate over all branch contexts
          (NOTE(review): that iteration lives outside this class - confirm)

    Attributes:
        priority: Controller priority (lower = higher priority). Set to 5
            to execute before most other controllers.
    """

    priority = 5 # High priority to catch branch keyword early
[docs]
@classmethod
def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
"""Check if the step matches the branch controller.
Args:
step: Original step configuration
operator: Deserialized operator (may be list of branch definitions)
keyword: Step keyword
Returns:
True if keyword is "branch"
"""
return keyword == "branch"
[docs]
@classmethod
def use_multi_source(cls) -> bool:
"""Branch controller supports multi-source datasets."""
return True
[docs]
@classmethod
def supports_prediction_mode(cls) -> bool:
"""Branch controller should execute in prediction mode to reconstruct branches."""
return True
[docs]
def execute(
    self,
    step_info: "ParsedStep",
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    runtime_context: "RuntimeContext",
    source: int = -1,
    mode: str = "train",
    loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
    prediction_store: Optional[Any] = None
) -> Tuple["ExecutionContext", StepOutput]:
    """Execute the branch step with V3 chain tracking.

    Every branch starts from the same snapshot of the incoming context, of
    the selector processing and of the dataset feature sources. Its substeps
    are run one after another through ``runtime_context.step_runner``; the
    resulting context, feature snapshot and operator-chain snapshot are
    stored so that steps following the branch block can use them.

    In predict/explain mode, when ``runtime_context.target_model`` carries a
    ``branch_id``, only that branch is executed and it keeps its original id.

    V3 improvements:
        - Uses trace_recorder.enter_branch() / exit_branch() for branch path tracking
        - Records each substep individually for complete trace fidelity
        - Builds proper operator chains for artifact identification

    Args:
        step_info: Parsed step containing branch definitions
        dataset: Dataset to operate on (its feature sources are mutated)
        context: Pipeline execution context
        runtime_context: Runtime infrastructure context
        source: Data source index (not used by this controller)
        mode: Execution mode ("train", "predict" or "explain")
        loaded_binaries: Pre-loaded binary objects for prediction mode
        prediction_store: External prediction store for model predictions

    Returns:
        Tuple of (updated_context, StepOutput with collected artifacts).
        The updated context is a copy of ``context`` whose ``custom`` dict
        holds "branch_contexts" and "in_branch_mode".

    Raises:
        ValueError: In predict/explain mode, when the target ``branch_id``
            is negative or not smaller than the number of declared branches.
    """
    # Get branch definitions from step
    branch_defs = self._parse_branch_definitions(step_info)
    if not branch_defs:
        logger.warning("No branch definitions found, skipping branch step")
        return context, StepOutput()

    n_branches = len(branch_defs)

    # V3: Start branch step recording in trace
    recorder = runtime_context.trace_recorder
    if recorder is not None:
        recorder.start_branch_step(
            step_index=runtime_context.step_number,
            branch_count=n_branches,
            operator_config={"branch_definitions": len(branch_defs)},
        )

    # In predict/explain mode, filter to only the target branch
    replaying = mode in ("predict", "explain")
    target_branch_id = None
    target_branch_name = None
    if replaying and hasattr(runtime_context, 'target_model') and runtime_context.target_model:
        target_branch_id = runtime_context.target_model.get("branch_id")
        target_branch_name = runtime_context.target_model.get("branch_name")
        if target_branch_id is None:
            logger.info(f"Creating {n_branches} branches (predict mode, no target branch specified)")
        elif 0 <= target_branch_id < n_branches:
            # A negative id must not be accepted: it would silently pick a
            # branch counted from the end and label it with the negative id.
            branch_defs = [branch_defs[target_branch_id]]
            logger.info(f"Predict mode: executing only branch {target_branch_id} ({target_branch_name or 'unnamed'})")
        else:
            raise ValueError(
                f"Target branch_id={target_branch_id} not found in pipeline. "
                f"Pipeline has {n_branches} branches (0-{n_branches-1}). "
                f"The model may have been trained with a different branch configuration."
            )
    else:
        logger.info(f"Creating {n_branches} branches")

    # Store the initial context as a snapshot point
    initial_context = context.copy()
    initial_processing = copy.deepcopy(context.selector.processing)

    # Snapshot the dataset's feature state before branching.
    # This is necessary because branches modify the shared dataset.
    initial_features_snapshot = self._snapshot_features(dataset)

    # V3: Snapshot the chain state before branching
    initial_chain = recorder.current_chain() if recorder else None

    branch_contexts: List[Dict[str, Any]] = []
    all_artifacts = []

    for idx, branch_def in enumerate(branch_defs):
        # When the definitions were filtered down to the target branch,
        # keep that branch's original id instead of the enumeration index.
        if target_branch_id is not None:
            branch_id = target_branch_id
        else:
            branch_id = idx
        branch_name = branch_def.get("name", f"branch_{branch_id}")
        branch_steps = branch_def.get("steps", [])
        logger.info(f" Branch {branch_id}: {branch_name}")

        # V3: Enter branch context in trace recorder
        if recorder is not None:
            recorder.enter_branch(branch_id)
            # Reset chain to initial state for this branch
            if initial_chain is not None:
                recorder.reset_chain_to(initial_chain)

        # Create isolated context for this branch
        branch_context = initial_context.copy()

        # Build branch_path by appending to parent's branch_path
        parent_branch_path = context.selector.branch_path or []
        new_branch_path = parent_branch_path + [branch_id]
        branch_context.selector = branch_context.selector.with_branch(
            branch_id=branch_id,
            branch_name=branch_name,
            branch_path=new_branch_path
        )

        # Reset processing to initial state for this branch
        branch_context.selector.processing = copy.deepcopy(initial_processing)

        # Restore dataset features so each branch starts from the same state
        self._restore_features(dataset, initial_features_snapshot)

        # Each branch has its own set of artifacts, so the positional
        # load counter must restart from scratch for each branch.
        if runtime_context:
            runtime_context.artifact_load_counter = {}

        # In predict/explain mode, load branch-specific binaries
        branch_binaries = loaded_binaries
        if replaying and runtime_context.artifact_loader:
            branch_binaries = runtime_context.artifact_loader.get_step_binaries(
                runtime_context.step_number, branch_id=branch_id
            )
            if not branch_binaries:
                # Fall back to non-branch binaries if no branch-specific ones exist
                branch_binaries = loaded_binaries

        # Execute branch steps sequentially.
        # V3: Each substep is recorded individually in the execution trace.
        for substep_idx, substep in enumerate(branch_steps):
            if runtime_context.step_runner:
                runtime_context.substep_number = substep_idx

                # Record substep in trace before execution
                if recorder is not None:
                    op_type, op_class = self._extract_substep_info(substep)
                    recorder.start_branch_substep(
                        parent_step_index=runtime_context.step_number,
                        branch_id=branch_id,
                        operator_type=op_type,
                        operator_class=op_class,
                        substep_index=substep_idx,
                        branch_name=branch_name,
                    )
                    # Record input shapes before substep execution
                    self._record_dataset_shapes(
                        dataset, branch_context, runtime_context, is_input=True
                    )

                result = runtime_context.step_runner.execute(
                    step=substep,
                    dataset=dataset,
                    context=branch_context,
                    runtime_context=runtime_context,
                    loaded_binaries=branch_binaries,
                    prediction_store=prediction_store
                )

                # End substep recording with output shapes
                if recorder is not None:
                    self._record_dataset_shapes(
                        dataset, result.updated_context, runtime_context, is_input=False
                    )
                    # Flag model substeps so the recorder can set model_step_index
                    is_model = op_type in ("model", "meta_model")
                    recorder.end_step(is_model=is_model)

                branch_context = result.updated_context
                all_artifacts.extend(result.artifacts)

        # V3: Snapshot the chain state BEFORE exiting branch context.
        # This captures the operator chain of this branch, which post-branch
        # steps (e.g., MetaModel) need to build their artifact IDs correctly.
        branch_chain_snapshot = recorder.current_chain() if recorder else None

        # V3: Exit branch context in trace recorder
        if recorder is not None:
            recorder.exit_branch()

        # Snapshot features AFTER branch processing completes so post-branch
        # steps can pick up the features produced by this branch.
        branch_features_snapshot = self._snapshot_features(dataset)

        # Store the final context for this branch
        branch_contexts.append({
            "branch_id": branch_id,
            "name": branch_name,
            "context": branch_context,
            "generator_choice": branch_def.get("generator_choice"),
            "features_snapshot": branch_features_snapshot,
            "chain_snapshot": branch_chain_snapshot,  # V3: Chain for post-branch steps
        })
        logger.success(f" Branch {branch_id} ({branch_name}) completed")

    # V3: End branch step in trace
    if recorder is not None:
        recorder.end_step()

    # Merge with any existing branch contexts (for nested branches)
    existing_branches = context.custom.get("branch_contexts", [])
    if existing_branches:
        # Nested branching: multiply existing branches with new ones
        new_branch_contexts = self._multiply_branch_contexts(
            existing_branches, branch_contexts
        )
    else:
        new_branch_contexts = branch_contexts

    # The returned context is a copy of the incoming (pre-branch) context;
    # the per-branch contexts are only reachable through "branch_contexts".
    result_context = context.copy()
    result_context.custom["branch_contexts"] = new_branch_contexts
    # Mark that we are in branching mode
    result_context.custom["in_branch_mode"] = True

    # Collect generator choices from branches for serialization
    branch_generator_choices = [
        {"branch": branch_def.get("generator_choice")}
        for branch_def in branch_defs
        if branch_def.get("generator_choice") is not None
    ]

    logger.success(f"Branch step completed with {len(new_branch_contexts)} branch(es)")
    return result_context, StepOutput(
        artifacts=all_artifacts,
        metadata={
            "branch_count": len(new_branch_contexts),
            "branch_generator_choices": branch_generator_choices
        }
    )
def _parse_branch_definitions(
self,
step_info: "ParsedStep"
) -> List[Dict[str, Any]]:
"""Parse branch definitions from step configuration.
Supports multiple syntaxes:
- List of lists: [[step1, step2], [step3]]
- Dict with names: {"snv_pca": [SNV(), PCA()], "msc": [MSC()]}
- List of dicts: [{"name": "a", "steps": [...]}, ...]
- Generator syntax: {"_or_": [SNV(), MSC()]} or {"_range_": [...]}
Args:
step_info: Parsed step containing branch definitions
Returns:
Normalized list of branch definitions with 'name' and 'steps'
"""
# Get the raw branch definition from original step
raw_def = step_info.original_step.get("branch", [])
if not raw_def:
return []
# Case 0: Generator syntax - expand before processing
if isinstance(raw_def, dict) and is_generator_node(raw_def):
return self._expand_generator_branches(raw_def)
# Case 1: Dict with named branches {"name": [steps], ...}
if isinstance(raw_def, dict):
# Check if any value contains generator syntax
expanded_branches = []
for name, steps in raw_def.items():
if isinstance(steps, dict) and is_generator_node(steps):
# Expand generator within named branch
expanded_steps = expand_spec(steps)
for i, exp_step in enumerate(expanded_steps):
branch_name = f"{name}_{self._generate_step_name(exp_step, i)}"
expanded_branches.append({
"name": branch_name,
"steps": exp_step if isinstance(exp_step, list) else [exp_step],
"generator_choice": exp_step
})
else:
expanded_branches.append({
"name": name,
"steps": steps if isinstance(steps, list) else [steps]
})
return expanded_branches
# Case 2: List format
if isinstance(raw_def, list):
result = []
for i, item in enumerate(raw_def):
# Sub-case 2a: Dict with explicit name and steps
if isinstance(item, dict) and "steps" in item:
steps = item["steps"]
# Check for generator in steps
if isinstance(steps, dict) and is_generator_node(steps):
expanded_steps = expand_spec(steps)
for j, exp_step in enumerate(expanded_steps):
result.append({
"name": f"{item.get('name', f'branch_{i}')}_{j}",
"steps": exp_step if isinstance(exp_step, list) else [exp_step],
"generator_choice": exp_step
})
else:
result.append({
"name": item.get("name", f"branch_{i}"),
"steps": steps
})
# Sub-case 2b: Dict with generator syntax (inside list)
elif isinstance(item, dict) and is_generator_node(item):
expanded = expand_spec(item)
for j, exp_item in enumerate(expanded):
branch_name = self._generate_step_name(exp_item, i * 100 + j)
result.append({
"name": branch_name,
"steps": exp_item if isinstance(exp_item, list) else [exp_item],
"generator_choice": exp_item
})
# Sub-case 2c: List of steps (anonymous branch)
elif isinstance(item, list):
# Check if the list contains generator nodes
expanded_list = self._expand_list_with_generators(item)
if len(expanded_list) > 1:
for j, exp_item in enumerate(expanded_list):
result.append({
"name": f"branch_{i}_{j}",
"steps": exp_item,
"generator_choice": exp_item
})
else:
result.append({
"name": f"branch_{i}",
"steps": expanded_list[0] if expanded_list else item
})
# Sub-case 2d: Single step (wrap in list)
else:
result.append({
"name": f"branch_{i}",
"steps": [item]
})
return result
# Fallback: treat as single branch with single step
return [{"name": "branch_0", "steps": [raw_def]}]
def _expand_generator_branches(
    self,
    generator_node: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Turn one generator node into a list of branch definitions.

    The node is expanded with ``expand_spec``; every expanded item becomes
    one branch. Items that are lists are used as the branch's steps as-is,
    any other item is wrapped into a single-step list.

    Examples of accepted nodes:
        - {"_or_": [SNV(), MSC(), D1()]} -> 3 branches
        - {"_or_": [[SNV(), PCA()], [MSC()]]} -> 2 branches with multi-step
        - {"_range_": [5, 15, 5], "param": "n_components", "model": PLS} -> 3 branches

    Args:
        generator_node: Dict with generator keywords (_or_, _range_, etc.)

    Returns:
        One dict per expanded item with 'name', 'steps' and
        'generator_choice' (the expanded item itself).
    """
    return [
        {
            "name": self._generate_step_name(choice, position),
            "steps": choice if isinstance(choice, list) else [choice],
            "generator_choice": choice,
        }
        for position, choice in enumerate(expand_spec(generator_node))
    ]
def _expand_list_with_generators(
self,
items: List[Any]
) -> List[List[Any]]:
"""Expand a list that may contain generator nodes.
Each generator node in the list is expanded, and the Cartesian product
of all expansions is returned.
Args:
items: List of steps, some of which may be generator nodes
Returns:
List of expanded step lists
"""
from itertools import product
expanded_items = []
for item in items:
if isinstance(item, dict) and is_generator_node(item):
expanded_items.append(expand_spec(item))
else:
expanded_items.append([item])
# Compute Cartesian product
result = []
for combo in product(*expanded_items):
result.append(list(combo))
return result if result else [items]
def _generate_step_name(
self,
step: Any,
index: int
) -> str:
"""Generate a human-readable name for a step or list of steps.
Args:
step: A step dict, class instance, or list of steps
index: Fallback index if name cannot be extracted
Returns:
A descriptive branch name
"""
if isinstance(step, list):
# Multiple steps - combine names
names = [self._get_single_step_name(s) for s in step]
# Filter out None/empty names
names = [n for n in names if n]
if names:
return "_".join(names[:3]) # Limit to first 3 steps
return f"branch_{index}"
return self._get_single_step_name(step) or f"branch_{index}"
def _get_single_step_name(
self,
step: Any
) -> Optional[str]:
"""Extract a short name from a single step.
Args:
step: A step configuration (dict, string, etc.)
Returns:
Short name or None if not extractable
"""
if step is None:
return None
if isinstance(step, str):
return step
if isinstance(step, dict):
# Check for 'name' key first (explicit naming)
if "name" in step:
return step["name"]
# Check for 'class' key (serialized format)
if "class" in step:
class_name = step["class"]
# Extract short class name from full path
if isinstance(class_name, str) and "." in class_name:
return class_name.split(".")[-1]
return str(class_name).split(".")[-1].replace("'>", "")
# Check for model key
if "model" in step:
return self._get_single_step_name(step["model"])
# Check for preprocessing key
if "preprocessing" in step:
return self._get_single_step_name(step["preprocessing"])
# Use first key as hint
keys = [k for k in step.keys() if not k.startswith("_")]
if keys:
return keys[0]
# For class instances, try to get class name
if hasattr(step, "__class__"):
return step.__class__.__name__
return None
def _multiply_branch_contexts(
self,
existing: List[Dict[str, Any]],
new: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Multiply existing branch contexts with new branches (for nesting).
Creates Cartesian product of branch contexts for nested branching.
Uses branch_path for tracking the full hierarchy.
Args:
existing: List of existing branch context dicts
new: List of new branch context dicts
Returns:
Combined list of branch contexts with hierarchical branch_path
"""
result = []
flattened_id = 0
for parent in existing:
parent_id = parent["branch_id"]
parent_name = parent["name"]
parent_context = parent["context"]
parent_branch_path = parent_context.selector.branch_path or [parent_id]
for child in new:
child_id = child["branch_id"]
child_name = child["name"]
child_context = child["context"]
# Create combined context
combined_context = child_context.copy()
# Build nested branch_path: parent_path + child_id
combined_branch_path = parent_branch_path + [child_id]
combined_context.selector = combined_context.selector.with_branch(
branch_id=flattened_id, # Keep flattened ID for backward compat
branch_name=f"{parent_name}_{child_name}",
branch_path=combined_branch_path
)
result.append({
"branch_id": flattened_id,
"name": f"{parent_name}_{child_name}",
"context": combined_context,
"parent_branch_id": parent_id,
"child_branch_id": child_id,
"branch_path": combined_branch_path
})
flattened_id += 1
return result
def _snapshot_features(self, dataset: "SpectroDataset") -> List[Any]:
"""Create a deep copy of the dataset's feature sources.
This is used to restore features to their initial state before each branch,
ensuring branches operate on independent copies of the data.
Args:
dataset: The dataset to snapshot
Returns:
A deep copy of the feature sources list
"""
return copy.deepcopy(dataset._features.sources)
def _restore_features(
self,
dataset: "SpectroDataset",
snapshot: List[Any]
) -> None:
"""Restore the dataset's feature sources from a snapshot.
Args:
dataset: The dataset to restore
snapshot: The previously snapshotted feature sources
"""
dataset._features.sources = copy.deepcopy(snapshot)
def _record_dataset_shapes(
self,
dataset: "SpectroDataset",
context: "ExecutionContext",
runtime_context: "RuntimeContext",
is_input: bool = True
) -> None:
"""Record dataset shapes to the execution trace for branch substeps.
Captures both 2D layout shape and 3D per-source feature shapes.
Args:
dataset: The dataset to measure
context: Execution context with selector
runtime_context: Runtime context with trace recorder
is_input: True to record input shapes, False for output shapes
"""
try:
# Get 2D layout shape (samples × features)
X_2d = dataset.x(context.selector, layout="2d", include_excluded=False)
if isinstance(X_2d, list):
# Multi-source with concat
layout_shape = (X_2d[0].shape[0], sum(x.shape[1] for x in X_2d))
else:
layout_shape = X_2d.shape
# Get 3D per-source shapes (samples × processings × features)
X_3d = dataset.x(context.selector, layout="3d", concat_source=False, include_excluded=False)
if not isinstance(X_3d, list):
X_3d = [X_3d]
features_shapes = [x.shape for x in X_3d]
# Record to trace via runtime_context
if is_input:
runtime_context.record_input_shapes(
input_shape=layout_shape,
features_shape=features_shapes
)
else:
runtime_context.record_output_shapes(
output_shape=layout_shape,
features_shape=features_shapes
)
except Exception:
# Shape recording is non-critical, don't fail the step
pass
def _extract_substep_info(self, step: Any) -> tuple:
"""Extract operator type and class from a branch substep.
Args:
step: The substep configuration (dict, class, or instance)
Returns:
Tuple of (operator_type, operator_class)
"""
# Handle dict steps with keywords
if isinstance(step, dict):
type_keywords = [
'preprocessing', 'y_processing', 'feature_augmentation',
'sample_augmentation', 'concat_transform', 'model',
'meta_model', 'branch', 'merge', 'source_branch',
'merge_sources', 'name'
]
for kw in type_keywords:
if kw in step:
operator = step[kw]
if kw == 'name':
# For {'name': 'X', 'model': Y}, look for actual operator
if 'model' in step:
operator = step['model']
kw = 'model'
else:
continue
op_class = self._get_operator_class_name(operator)
return kw, op_class
# Check for 'class' key (serialized format)
if 'class' in step:
class_path = step['class']
if '.' in class_path:
op_class = class_path.rsplit('.', 1)[-1]
else:
op_class = class_path
return 'transform', op_class
return 'config', 'Config'
# Handle string (class path)
if isinstance(step, str):
if '.' in step:
op_class = step.rsplit('.', 1)[-1]
else:
op_class = step
return 'transform', op_class
# Handle class or instance
if isinstance(step, type):
return 'transform', step.__name__
elif hasattr(step, '__class__'):
return 'transform', type(step).__name__
return 'operator', str(type(step).__name__)
def _get_operator_class_name(self, operator: Any) -> str:
"""Get a human-readable class name from an operator.
Args:
operator: The operator (class, instance, string, or list)
Returns:
Human-readable class name string
"""
if operator is None:
return 'None'
if isinstance(operator, list):
if len(operator) == 0:
return 'Empty'
if len(operator) == 1:
return self._get_operator_class_name(operator[0])
# Multiple operators - join names
names = [self._get_operator_class_name(op) for op in operator[:3]]
suffix = f"... (+{len(operator)-3})" if len(operator) > 3 else ""
return ', '.join(names) + suffix
if isinstance(operator, str):
if '.' in operator:
return operator.rsplit('.', 1)[-1]
return operator
if isinstance(operator, dict):
if 'class' in operator:
class_path = operator['class']
if '.' in class_path:
return class_path.rsplit('.', 1)[-1]
return class_path
return 'Config'
if isinstance(operator, type):
return operator.__name__
return type(operator).__name__