# Source code for nirs4all.controllers.data.metadata_partitioner
"""
Metadata Partitioner Controller for metadata-based branching.
This controller partitions the dataset into multiple branches based on a
metadata column. Unlike copy branches (where all branches see all samples),
this controller creates non-overlapping sample sets - each sample exists in
exactly ONE branch.
For example, with column="site":
- Branch "site_A": Contains ONLY samples where metadata["site"] == "A"
- Branch "site_B": Contains ONLY samples where metadata["site"] == "B"
- Branch "site_C": Contains ONLY samples where metadata["site"] == "C"
This enables training separate models for different data subsets (e.g., per-site,
per-variety, per-instrument models) and combining their predictions via stacking.
Example:
>>> pipeline = [
... MinMaxScaler(),
... {
... "branch": [PLS(5), RF(100), XGB()],
... "by": "metadata_partitioner",
... "column": "site",
... "cv": ShuffleSplit(n_splits=3),
... "min_samples": 20, # Skip branches with < 20 samples
... },
... {"merge": "predictions"},
... Ridge(),
... ]
"""
import copy
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Optional, Union, TYPE_CHECKING
import numpy as np
from nirs4all.controllers.controller import OperatorController
from nirs4all.controllers.registry import register_controller
from nirs4all.core.logging import get_logger
from nirs4all.pipeline.execution.result import StepOutput
logger = get_logger(__name__)
if TYPE_CHECKING:
from nirs4all.data.dataset import SpectroDataset
from nirs4all.pipeline.config.context import ExecutionContext, RuntimeContext
from nirs4all.pipeline.steps.parser import ParsedStep
[docs]
@dataclass
class MetadataPartitionConfig:
    """Settings that drive a metadata-based partition of the dataset.

    Attributes:
        column: Name of the metadata column whose values define partitions.
        branch_steps: Pipeline steps run inside every partition branch.
        cv: Optional cross-validation splitter applied separately per branch.
        min_samples: Smallest number of samples a branch must have; branches
            below this threshold are skipped.
        group_values: Optional mapping from a branch name to the column
            values merged into that branch, e.g.
            ``{"others": ["C", "D", "E"]}`` puts C, D and E into one
            "others" branch.
    """

    column: str
    branch_steps: List[Any]
    cv: Optional[Any] = None
    min_samples: int = 1
    group_values: Optional[Dict[str, List[Any]]] = None

    def __post_init__(self):
        """Reject an empty column name or a ``min_samples`` below 1.

        Raises:
            ValueError: If ``column`` is falsy or ``min_samples`` is < 1.
        """
        column_given = bool(self.column)
        if not column_given:
            raise ValueError("column must be specified for metadata_partitioner")

        too_small = self.min_samples < 1
        if too_small:
            raise ValueError(f"min_samples must be >= 1, got {self.min_samples}")
def _parse_metadata_partition_config(step: Dict[str, Any]) -> MetadataPartitionConfig:
    """Turn a raw branch step dict into a ``MetadataPartitionConfig``.

    Two layouts are understood:

    * ``step["branch"]`` is a dict: options are read from that dict first and
      fall back to the same key on ``step``; the branch steps come from its
      ``"steps"`` entry.
    * ``step["branch"]`` is anything else (normally a list): it is used as the
      branch steps and every option is read from ``step`` itself.

    Args:
        step: Step configuration dict. Recognised keys are ``branch``,
            ``column``, ``cv``, ``min_samples`` and ``group_values``.

    Returns:
        MetadataPartitionConfig instance.

    Raises:
        ValueError: If no column is given, or if the resulting configuration
            fails ``MetadataPartitionConfig`` validation.
    """
    branch_def = step.get("branch", [])

    if isinstance(branch_def, dict):
        # Nested layout: the branch dict wins, the step dict is the fallback.
        settings = {
            "column": branch_def.get("column") or step.get("column"),
            "branch_steps": branch_def.get("steps", []),
            "cv": branch_def.get("cv") or step.get("cv"),
            "min_samples": branch_def.get("min_samples", step.get("min_samples", 1)),
            "group_values": branch_def.get("group_values") or step.get("group_values"),
        }
    else:
        # Flat layout: "branch" already is the list of steps.
        settings = {
            "column": step.get("column"),
            "branch_steps": branch_def,
            "cv": step.get("cv"),
            "min_samples": step.get("min_samples", 1),
            "group_values": step.get("group_values"),
        }

    if not settings["column"]:
        raise ValueError(
            "metadata_partitioner requires 'column' parameter. "
            "Specify the metadata column to partition by. "
            "Example: {'branch': [...], 'by': 'metadata_partitioner', 'column': 'site'}"
        )

    return MetadataPartitionConfig(**settings)
def _build_partition_groups(
unique_values: List[Any],
group_values: Optional[Dict[str, List[Any]]]
) -> Dict[str, List[Any]]:
"""Build partition groups from unique values and grouping config.
Args:
unique_values: List of unique values in the metadata column.
group_values: Optional dict mapping group names to value lists.
Returns:
Dict mapping partition names to lists of values in that partition.
"""
if group_values is None:
# Each unique value becomes its own partition
return {str(v): [v] for v in unique_values}
# Build grouped partitions
partitions = {}
grouped_values = set()
for group_name, values in group_values.items():
partitions[group_name] = values
grouped_values.update(values)
# Add remaining ungrouped values as individual partitions
for v in unique_values:
if v not in grouped_values:
partitions[str(v)] = [v]
return partitions
[docs]
@register_controller
class MetadataPartitionerController(OperatorController):
    """Controller for metadata-based branching via partitioning.

    This controller creates branches by partitioning samples based on a
    metadata column. Each branch contains a disjoint subset of samples
    where the metadata column equals specific value(s).

    Key behaviors:
        - Each branch contains a disjoint subset of samples
        - Per-branch cross-validation is supported
        - Branches with too few samples can be skipped (min_samples)
        - Values can be grouped into combined branches (group_values)
        - Models train and predict only on their partition

    Attributes:
        priority: Controller priority (set to 3 to run before other controllers).
    """

    # NOTE(review): how the registry orders priorities is not visible in this
    # file; presumably 3 places this controller ahead of the generic branch
    # handling - confirm against the controller registry.
    priority = 3  # High priority to catch this branch type early
[docs]
@classmethod
def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
    """Tell whether a step is a metadata_partitioner branch definition.

    A step qualifies when its keyword is ``"branch"``, it is a dict, and
    ``"by": "metadata_partitioner"`` appears either on the step itself or
    inside a dict stored under its ``"branch"`` key, e.g.::

        {"branch": [...], "by": "metadata_partitioner", "column": "..."}

    Args:
        step: Original step configuration
        operator: Deserialized operator (not inspected)
        keyword: Step keyword

    Returns:
        True if this is a metadata_partitioner branch definition.
    """
    marker = "metadata_partitioner"

    if keyword != "branch" or not isinstance(step, dict):
        return False

    # Marker declared directly on the step.
    if step.get("by") == marker:
        return True

    # Marker declared inside a dict-style branch definition.
    nested = step.get("branch", {})
    if isinstance(nested, dict) and nested.get("by") == marker:
        return True
    return False
[docs]
@classmethod
def use_multi_source(cls) -> bool:
    """Report that this controller works at dataset level (always True)."""
    dataset_level = True
    return dataset_level
[docs]
@classmethod
def supports_prediction_mode(cls) -> bool:
    """Report that this controller must also run in prediction mode.

    Prediction needs this step so that every sample can be routed to the
    branch matching its metadata value.

    Returns:
        Always True.
    """
    runs_at_predict_time = True
    return runs_at_predict_time
[docs]
def execute(
    self,
    step_info: "ParsedStep",
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    runtime_context: "RuntimeContext",
    source: int = -1,
    mode: str = "train",
    loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
    prediction_store: Optional[Any] = None
) -> Tuple["ExecutionContext", StepOutput]:
    """Execute the metadata partitioner branch step.

    Creates one branch per partition of the metadata column, runs the
    optional per-branch CV step and every branch step inside each branch,
    and collects the resulting branch contexts.

    For ``mode`` "predict" or "explain" the call is delegated entirely to
    ``_execute_prediction_mode``; everything below the delegation only runs
    for the other modes.

    Side effects visible in this method: the dataset features are reset
    from a snapshot before each branch is processed and are NOT reset after
    the last branch, so the dataset is left in the state produced by the
    last processed branch (per-branch states are kept in
    ``features_snapshot``). ``runtime_context.artifact_load_counter`` and
    ``runtime_context.substep_number`` are overwritten while branches run.

    Args:
        step_info: Parsed step containing branch definitions
        dataset: Dataset to operate on
        context: Pipeline execution context
        runtime_context: Runtime infrastructure context
        source: Data source index
        mode: Execution mode ("train", "predict" or "explain")
        loaded_binaries: Pre-loaded binary objects, forwarded to sub-steps
        prediction_store: External prediction store for model predictions

    Returns:
        Tuple of (updated_context, StepOutput with collected artifacts).
        When no training samples exist the input context and an empty
        StepOutput are returned.

    Raises:
        ValueError: If the configuration is invalid or the metadata column
            is not present in the dataset metadata.
    """
    config = _parse_metadata_partition_config(step_info.original_step)

    # Prediction and explanation only route samples; hand over and return.
    if mode in ("predict", "explain"):
        return self._execute_prediction_mode(
            step_info=step_info,
            dataset=dataset,
            context=context,
            runtime_context=runtime_context,
            config=config,
            source=source,
            mode=mode,
            loaded_binaries=loaded_binaries,
            prediction_store=prediction_store,
        )

    logger.info(f"Creating metadata partitioner branches by column '{config.column}'")

    # Snapshots every branch starts from, so branches cannot affect each other.
    initial_context = context.copy()
    initial_processing = copy.deepcopy(context.selector.processing)
    initial_features = self._snapshot_features(dataset)

    metadata = dataset.metadata
    if metadata is None or config.column not in metadata.columns:
        available_cols = list(metadata.columns) if metadata is not None else []
        raise ValueError(
            f"Metadata column '{config.column}' not found. "
            f"Available columns: {available_cols}"
        )
    column_values = metadata[config.column].values

    # Training sample indices (augmented and excluded samples left out).
    train_context = context.with_partition("train")
    train_selector = train_context.selector.copy()
    train_selector.include_augmented = False
    train_sample_indices = dataset._indexer.x_indices(
        train_selector, include_augmented=False, include_excluded=False
    )
    if len(train_sample_indices) == 0:
        logger.warning("No training samples found, skipping metadata partitioner")
        return context, StepOutput()

    # Unique values are taken from all samples, not just the training ones.
    distinct_values = set(column_values)
    try:
        unique_values = sorted(distinct_values)
    except TypeError:
        # Mixed-type columns (e.g. strings plus NaN/None) are not mutually
        # orderable; fall back to a deterministic (type name, text) order.
        unique_values = sorted(
            distinct_values, key=lambda v: (type(v).__name__, str(v))
        )
    partition_groups = _build_partition_groups(unique_values, config.group_values)
    logger.info(f" Found {len(unique_values)} unique values in '{config.column}'")
    logger.info(f" Creating {len(partition_groups)} partition(s)")

    # V3: Start branch step recording in trace
    recorder = runtime_context.trace_recorder
    if recorder is not None:
        recorder.start_branch_step(
            step_index=runtime_context.step_number,
            branch_count=len(partition_groups),
            operator_config={
                "by": "metadata_partitioner",
                "column": config.column,
                "partitions": list(partition_groups.keys()),
            },
        )

    branch_contexts: List[Dict[str, Any]] = []
    all_artifacts = []
    skipped_branches = []

    # The position in partition_groups is the branch id; a skipped partition
    # still consumes its id so ids stay tied to the partition order.
    for branch_id, (partition_name, partition_values) in enumerate(
        partition_groups.items()
    ):
        partition_mask = np.isin(column_values, partition_values)
        partition_indices = np.where(partition_mask)[0]

        # Only training samples count towards the min_samples threshold.
        train_partition_indices = np.intersect1d(partition_indices, train_sample_indices)
        n_train_samples = len(train_partition_indices)

        if n_train_samples < config.min_samples:
            logger.warning(
                f" Skipping partition '{partition_name}': {n_train_samples} samples "
                f"< min_samples={config.min_samples}"
            )
            skipped_branches.append({
                "name": partition_name,
                "values": partition_values,
                "n_samples": n_train_samples,
                "reason": "min_samples",
            })
            continue

        logger.info(
            f" Partition '{partition_name}': {n_train_samples} train samples "
            f"(values: {partition_values})"
        )

        # V3: Enter branch context in trace recorder
        if recorder is not None:
            recorder.enter_branch(branch_id)

        # Every branch starts from the same initial feature state.
        self._restore_features(dataset, initial_features)

        # Isolated context for this branch, nested under the parent's path.
        branch_context = initial_context.copy()
        parent_branch_path = context.selector.branch_path or []
        new_branch_path = parent_branch_path + [branch_id]
        branch_context.selector = branch_context.selector.with_branch(
            branch_id=branch_id,
            branch_name=partition_name,
            branch_path=new_branch_path
        )
        branch_context.selector.processing = copy.deepcopy(initial_processing)

        branch_context.custom["metadata_partition"] = {
            "sample_indices": partition_indices.tolist(),
            "train_sample_indices": train_partition_indices.tolist(),
            "partition_value": partition_name,
            "partition_values": partition_values,
            "column": config.column,
            "n_samples": len(partition_indices),
            "n_train_samples": n_train_samples,
        }

        # Store CV splitter if provided (for per-branch CV)
        if config.cv is not None:
            branch_context.custom["per_branch_cv"] = config.cv

        # Reset artifact load counter for this branch
        runtime_context.artifact_load_counter = {}

        # Predict/explain never reaches this point, so no branch-specific
        # artifacts are looked up here; sub-steps get the caller's binaries.
        branch_binaries = loaded_binaries

        # Per-branch CV runs before the branch steps and creates independent
        # folds within this partition.
        if config.cv is not None and mode == "train":
            cv_step = {"split": config.cv}
            if runtime_context.step_runner:
                logger.debug(f" Applying per-branch CV for partition '{partition_name}'")
                runtime_context.substep_number = -1  # Mark as CV step

                if recorder is not None:
                    recorder.start_branch_substep(
                        parent_step_index=runtime_context.step_number,
                        branch_id=branch_id,
                        operator_type="split",
                        operator_class=config.cv.__class__.__name__,
                        substep_index=-1,  # Special index for CV
                        branch_name=partition_name,
                    )

                cv_result = runtime_context.step_runner.execute(
                    step=cv_step,
                    dataset=dataset,
                    context=branch_context,
                    runtime_context=runtime_context,
                    loaded_binaries=branch_binaries,
                    prediction_store=prediction_store
                )

                if recorder is not None:
                    recorder.end_step(is_model=False)

                branch_context = cv_result.updated_context
                all_artifacts.extend(cv_result.artifacts)

        # Execute branch steps sequentially, threading the context through.
        for substep_idx, substep in enumerate(config.branch_steps):
            if runtime_context.step_runner:
                runtime_context.substep_number = substep_idx

                # Record substep in trace before execution
                if recorder is not None:
                    op_type, op_class = self._extract_substep_info(substep)
                    recorder.start_branch_substep(
                        parent_step_index=runtime_context.step_number,
                        branch_id=branch_id,
                        operator_type=op_type,
                        operator_class=op_class,
                        substep_index=substep_idx,
                        branch_name=partition_name,
                    )

                result = runtime_context.step_runner.execute(
                    step=substep,
                    dataset=dataset,
                    context=branch_context,
                    runtime_context=runtime_context,
                    loaded_binaries=branch_binaries,
                    prediction_store=prediction_store
                )

                # End substep recording
                if recorder is not None:
                    is_model = op_type in ("model", "meta_model")
                    recorder.end_step(is_model=is_model)

                branch_context = result.updated_context
                all_artifacts.extend(result.artifacts)

        # Snapshot features AFTER branch processing completes
        branch_features_snapshot = self._snapshot_features(dataset)

        # V3: Snapshot chain state before exiting branch
        branch_chain_snapshot = recorder.current_chain() if recorder else None

        # V3: Exit branch context in trace recorder
        if recorder is not None:
            recorder.exit_branch()

        branch_contexts.append({
            "branch_id": branch_id,
            "name": partition_name,
            "context": branch_context,
            "partition_info": {
                "values": partition_values,
                "n_samples": len(partition_indices),
                "n_train_samples": n_train_samples,
                "sample_indices": partition_indices.tolist(),
                "train_sample_indices": train_partition_indices.tolist(),
            },
            "features_snapshot": branch_features_snapshot,
            "chain_snapshot": branch_chain_snapshot,
        })
        logger.success(f" Partition '{partition_name}' (branch {branch_id}) completed")

    # V3: End branch step in trace
    if recorder is not None:
        recorder.end_step()

    # Nested branching: combine with branches created by an earlier step.
    existing_branches = context.custom.get("branch_contexts", [])
    if existing_branches:
        new_branch_contexts = self._multiply_branch_contexts(
            existing_branches, branch_contexts
        )
    else:
        new_branch_contexts = branch_contexts

    result_context = context.copy()
    result_context.custom["branch_contexts"] = new_branch_contexts
    result_context.custom["in_branch_mode"] = True
    result_context.custom["metadata_partitioner_active"] = True
    result_context.custom["metadata_partitioner_config"] = {
        "column": config.column,
        "group_values": config.group_values,
        "min_samples": config.min_samples,
    }

    metadata_info = {
        "branch_count": len(new_branch_contexts),
        "metadata_partitioner": True,
        "column": config.column,
        "partitions": [bc["name"] for bc in branch_contexts],
        "skipped_branches": skipped_branches,
    }
    logger.success(
        f"Metadata partitioner completed with {len(new_branch_contexts)} branch(es)"
        + (f" ({len(skipped_branches)} skipped)" if skipped_branches else "")
    )
    return result_context, StepOutput(
        artifacts=all_artifacts,
        metadata=metadata_info
    )
def _snapshot_features(self, dataset: "SpectroDataset") -> List[Any]:
"""Create a deep copy of dataset features for branch isolation."""
return copy.deepcopy(dataset._features.sources)
def _restore_features(
self,
dataset: "SpectroDataset",
snapshot: List[Any]
) -> None:
"""Restore dataset features from snapshot."""
dataset._features.sources = copy.deepcopy(snapshot)
def _multiply_branch_contexts(
self,
existing: List[Dict[str, Any]],
new: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Multiply existing branch contexts with new ones for nested branching.
Creates Cartesian product: each existing branch × each new branch.
Args:
existing: List of existing branch context dicts
new: List of new branch context dicts
Returns:
Combined list of branch contexts
"""
result = []
flattened_id = 0
for parent in existing:
parent_id = parent["branch_id"]
parent_name = parent["name"]
parent_context = parent["context"]
parent_branch_path = parent_context.selector.branch_path or [parent_id]
for child in new:
child_id = child["branch_id"]
child_name = child["name"]
child_context = child["context"]
# Create combined context
combined_context = child_context.copy()
# Build nested branch_path: parent_path + child_id
combined_branch_path = parent_branch_path + [child_id]
combined_context.selector = combined_context.selector.with_branch(
branch_id=flattened_id,
branch_name=f"{parent_name}_{child_name}",
branch_path=combined_branch_path
)
result.append({
"branch_id": flattened_id,
"name": f"{parent_name}_{child_name}",
"context": combined_context,
"parent_branch_id": parent_id,
"child_branch_id": child_id,
"branch_path": combined_branch_path,
"partition_info": child.get("partition_info", {}),
"features_snapshot": child.get("features_snapshot"),
"chain_snapshot": child.get("chain_snapshot"),
})
flattened_id += 1
return result
def _extract_substep_info(self, step: Any) -> Tuple[str, str]:
"""Extract operator type and class from a branch substep.
Args:
step: The substep configuration (dict, class, or instance)
Returns:
Tuple of (operator_type, operator_class)
"""
# Handle dict steps with keywords
if isinstance(step, dict):
type_keywords = [
'preprocessing', 'y_processing', 'feature_augmentation',
'sample_augmentation', 'concat_transform', 'model',
'meta_model', 'branch', 'merge', 'source_branch',
'merge_sources', 'name'
]
for kw in type_keywords:
if kw in step:
operator = step[kw]
if kw == 'name':
# For {'name': 'X', 'model': Y}, look for actual operator
if 'model' in step:
operator = step['model']
kw = 'model'
else:
continue
op_class = self._get_operator_class_name(operator)
return kw, op_class
# Check for 'class' key (serialized format)
if 'class' in step:
class_path = step['class']
if '.' in class_path:
op_class = class_path.rsplit('.', 1)[-1]
else:
op_class = class_path
return 'transform', op_class
return 'config', 'Config'
# Handle string (class path)
if isinstance(step, str):
if '.' in step:
op_class = step.rsplit('.', 1)[-1]
else:
op_class = step
return 'transform', op_class
# Handle class or instance
if isinstance(step, type):
return 'transform', step.__name__
elif hasattr(step, '__class__'):
return 'transform', type(step).__name__
return 'operator', str(type(step).__name__)
def _get_operator_class_name(self, operator: Any) -> str:
"""Get a human-readable class name from an operator.
Args:
operator: The operator (class, instance, string, or list)
Returns:
Human-readable class name string
"""
if operator is None:
return 'None'
if isinstance(operator, list):
if len(operator) == 0:
return 'Empty'
if len(operator) == 1:
return self._get_operator_class_name(operator[0])
# Multiple operators - join names
names = [self._get_operator_class_name(op) for op in operator[:3]]
suffix = f"... (+{len(operator)-3})" if len(operator) > 3 else ""
return ', '.join(names) + suffix
if isinstance(operator, str):
if '.' in operator:
return operator.rsplit('.', 1)[-1]
return operator
if isinstance(operator, dict):
if 'class' in operator:
class_path = operator['class']
if '.' in class_path:
return class_path.rsplit('.', 1)[-1]
return class_path
return 'Config'
if isinstance(operator, type):
return operator.__name__
return type(operator).__name__
# =========================================================================
# Phase 4: Prediction Mode Sample Routing
# =========================================================================
def _execute_prediction_mode(
    self,
    step_info: "ParsedStep",
    dataset: "SpectroDataset",
    context: "ExecutionContext",
    runtime_context: "RuntimeContext",
    config: MetadataPartitionConfig,
    source: int = -1,
    mode: str = "predict",
    loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
    prediction_store: Optional[Any] = None,
) -> Tuple["ExecutionContext", StepOutput]:
    """Execute metadata partitioner in prediction mode with sample routing.

    Routing algorithm:
        1. Read the metadata column values of all prediction samples
        2. Group the values into partitions (same grouping code as training)
        3. For each non-empty partition, run the branch steps with the
           artifacts looked up for that branch id
        4. Record which sample went to which branch in
           ``custom["sample_routing"]`` of the returned context

    Unlike the training path, this method does not apply ``min_samples``,
    does not run the per-branch CV step and does not combine the result
    with branch contexts that already exist on the incoming context.

    Args:
        step_info: Parsed step containing branch definitions
        dataset: Dataset to operate on
        context: Pipeline execution context
        runtime_context: Runtime infrastructure context
        config: Parsed partition configuration
        source: Data source index
        mode: Execution mode ("predict" or "explain")
        loaded_binaries: Pre-loaded binary objects for this step; used when
            no branch-specific binaries are found
        prediction_store: External prediction store

    Returns:
        Tuple of (updated_context, StepOutput)

    Raises:
        ValueError: If the dataset has no metadata or lacks the column.
    """
    logger.info(
        f"Metadata partitioner (predict mode): routing by column '{config.column}'"
    )

    metadata = dataset.metadata
    if metadata is None:
        raise ValueError(
            f"Dataset has no metadata. Cannot route samples by column '{config.column}' "
            f"in prediction mode. Ensure prediction data includes the same metadata "
            f"column used during training."
        )
    if config.column not in metadata.columns:
        available_cols = list(metadata.columns)
        raise ValueError(
            f"Metadata column '{config.column}' not found in prediction data. "
            f"Available columns: {available_cols}. "
            f"Ensure prediction data includes the same metadata column used during training."
        )
    column_values = metadata[config.column].values
    n_samples = len(column_values)

    # Same ordering rule as the training path.
    distinct_values = set(column_values)
    try:
        unique_values = sorted(distinct_values)
    except TypeError:
        # Mixed-type columns (e.g. strings plus NaN/None) are not mutually
        # orderable; fall back to a deterministic (type name, text) order.
        unique_values = sorted(
            distinct_values, key=lambda v: (type(v).__name__, str(v))
        )
    partition_groups = _build_partition_groups(unique_values, config.group_values)
    logger.info(
        f" Found {len(unique_values)} unique values, "
        f"mapping to {len(partition_groups)} partition(s)"
    )

    # Partition names recorded during training, when available. They are
    # only used to warn about values that training never saw.
    stored_partitions = self._load_partition_routing_info(runtime_context)
    if stored_partitions:
        training_partitions = set(stored_partitions)
        prediction_partitions = set(partition_groups)
        unknown_partitions = prediction_partitions - training_partitions
        if unknown_partitions:
            unknown_samples = []
            for partition_name in unknown_partitions:
                partition_values = partition_groups[partition_name]
                mask = np.isin(column_values, partition_values)
                unknown_samples.extend(np.where(mask)[0].tolist())
            # No separate fallback path exists in this method: unknown
            # partitions are processed like any other partition below.
            logger.warning(
                f" {len(unknown_samples)} samples have metadata values "
                f"not seen during training: {unknown_partitions}. "
                f"These samples will use fallback routing."
            )

    # Snapshots every branch starts from.
    initial_context = context.copy()
    initial_processing = copy.deepcopy(context.selector.processing)
    initial_features = self._snapshot_features(dataset)

    # V3: Start branch step recording in trace
    recorder = runtime_context.trace_recorder
    if recorder is not None:
        recorder.start_branch_step(
            step_index=runtime_context.step_number,
            branch_count=len(partition_groups),
            operator_config={
                "by": "metadata_partitioner",
                "column": config.column,
                "partitions": list(partition_groups.keys()),
                "prediction_mode": True,
            },
        )

    branch_contexts: List[Dict[str, Any]] = []
    all_artifacts = []
    processed_samples: Dict[int, int] = {}  # sample_idx -> branch_id

    # NOTE(review): branch ids are positions in the partition order derived
    # from the PREDICTION data. Presumably they have to equal the ids used
    # in training for get_step_binaries to return the right artifacts; that
    # only holds when the prediction data yields the same partitions in the
    # same order - confirm against the artifact loader.
    for branch_id, (partition_name, partition_values) in enumerate(
        partition_groups.items()
    ):
        partition_mask = np.isin(column_values, partition_values)
        partition_indices = np.where(partition_mask)[0]
        if len(partition_indices) == 0:
            logger.debug(f" Partition '{partition_name}': no samples in prediction data")
            continue

        logger.info(
            f" Partition '{partition_name}': {len(partition_indices)} samples"
        )

        # tolist() yields plain ints, so the routing map is not keyed by
        # numpy scalar objects.
        routed_indices = partition_indices.tolist()
        processed_samples.update(dict.fromkeys(routed_indices, branch_id))

        # V3: Enter branch context in trace recorder
        if recorder is not None:
            recorder.enter_branch(branch_id)

        # Every branch starts from the same initial feature state.
        self._restore_features(dataset, initial_features)

        # Isolated context for this branch, nested under the parent's path.
        branch_context = initial_context.copy()
        parent_branch_path = context.selector.branch_path or []
        new_branch_path = parent_branch_path + [branch_id]
        branch_context.selector = branch_context.selector.with_branch(
            branch_id=branch_id,
            branch_name=partition_name,
            branch_path=new_branch_path
        )
        branch_context.selector.processing = copy.deepcopy(initial_processing)

        # Partition info for downstream controllers; the "train" entries
        # mirror the full partition because there is no train split here.
        branch_context.custom["metadata_partition"] = {
            "sample_indices": routed_indices,
            "train_sample_indices": partition_indices.tolist(),  # Same in predict mode
            "partition_value": partition_name,
            "partition_values": partition_values,
            "column": config.column,
            "n_samples": len(partition_indices),
            "n_train_samples": len(partition_indices),
            "prediction_mode": True,
        }

        # Reset artifact load counter for this branch
        runtime_context.artifact_load_counter = {}

        # Prefer branch-specific artifacts; keep the step-level binaries
        # when the loader is missing or has nothing for this branch.
        branch_binaries = loaded_binaries
        if runtime_context.artifact_loader:
            branch_binaries = runtime_context.artifact_loader.get_step_binaries(
                runtime_context.step_number, branch_id=branch_id
            )
            if not branch_binaries:
                branch_binaries = loaded_binaries

        # Execute branch steps sequentially, threading the context through.
        for substep_idx, substep in enumerate(config.branch_steps):
            if runtime_context.step_runner:
                runtime_context.substep_number = substep_idx

                # Record substep in trace
                if recorder is not None:
                    op_type, op_class = self._extract_substep_info(substep)
                    recorder.start_branch_substep(
                        parent_step_index=runtime_context.step_number,
                        branch_id=branch_id,
                        operator_type=op_type,
                        operator_class=op_class,
                        substep_index=substep_idx,
                        branch_name=partition_name,
                    )

                result = runtime_context.step_runner.execute(
                    step=substep,
                    dataset=dataset,
                    context=branch_context,
                    runtime_context=runtime_context,
                    loaded_binaries=branch_binaries,
                    prediction_store=prediction_store
                )

                if recorder is not None:
                    is_model = op_type in ("model", "meta_model")
                    recorder.end_step(is_model=is_model)

                branch_context = result.updated_context
                all_artifacts.extend(result.artifacts)

        # Snapshot features after branch processing
        branch_features_snapshot = self._snapshot_features(dataset)

        # V3: Exit branch
        if recorder is not None:
            recorder.exit_branch()

        branch_contexts.append({
            "branch_id": branch_id,
            "name": partition_name,
            "context": branch_context,
            "partition_info": {
                "values": partition_values,
                "n_samples": len(partition_indices),
                "sample_indices": partition_indices.tolist(),
            },
            "features_snapshot": branch_features_snapshot,
        })
        logger.success(
            f" Partition '{partition_name}' (branch {branch_id}) completed"
        )

    # V3: End branch step
    if recorder is not None:
        recorder.end_step()

    # Samples matched by no partition (np.isin never matches NaN, for example).
    unprocessed = [i for i in range(n_samples) if i not in processed_samples]
    if unprocessed:
        logger.warning(
            f" {len(unprocessed)} samples were not processed "
            f"(no matching partition). Sample indices: {unprocessed[:10]}..."
        )

    result_context = context.copy()
    result_context.custom["branch_contexts"] = branch_contexts
    result_context.custom["in_branch_mode"] = True
    result_context.custom["metadata_partitioner_active"] = True
    result_context.custom["metadata_partitioner_config"] = {
        "column": config.column,
        "group_values": config.group_values,
        "min_samples": config.min_samples,
    }
    result_context.custom["sample_routing"] = {
        "processed_samples": processed_samples,
        "n_total_samples": n_samples,
        "n_processed": len(processed_samples),
        "n_unprocessed": len(unprocessed),
    }

    metadata_info = {
        "branch_count": len(branch_contexts),
        "metadata_partitioner": True,
        "prediction_mode": True,
        "column": config.column,
        "partitions": [bc["name"] for bc in branch_contexts],
        "sample_routing": {
            "n_total": n_samples,
            "n_processed": len(processed_samples),
            "n_unprocessed": len(unprocessed),
        },
    }
    logger.success(
        f"Metadata partitioner (predict mode) completed: "
        f"{len(branch_contexts)} branch(es), "
        f"{len(processed_samples)}/{n_samples} samples routed"
    )
    return result_context, StepOutput(
        artifacts=all_artifacts,
        metadata=metadata_info
    )
def _load_partition_routing_info(
self,
runtime_context: "RuntimeContext"
) -> Optional[Dict[str, Dict[str, Any]]]:
"""Load partition routing info from training trace/manifest.
Attempts to load the partition configuration used during training
to validate prediction data and handle missing partitions.
Args:
runtime_context: Runtime context with artifact loader
Returns:
Dict mapping partition names to their configuration, or None
if not available.
"""
if not runtime_context:
return None
# Try to get from trace
if hasattr(runtime_context, 'trace') and runtime_context.trace:
trace = runtime_context.trace
step_idx = runtime_context.step_number
step = trace.get_step(step_idx)
if step and step.metadata:
partition_info = step.metadata.get("partitions")
if partition_info:
return {name: {"name": name} for name in partition_info}
# Try to get from artifact loader metadata
if hasattr(runtime_context, 'artifact_loader') and runtime_context.artifact_loader:
loader = runtime_context.artifact_loader
if hasattr(loader, 'get_step_metadata'):
step_meta = loader.get_step_metadata(runtime_context.step_number)
if step_meta and "partitions" in step_meta:
partitions = step_meta["partitions"]
return {name: {"name": name} for name in partitions}
return None