"""Pipeline executor for executing a single pipeline on a single dataset."""
import hashlib
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from nirs4all.data.dataset import SpectroDataset
from nirs4all.data.predictions import Predictions
from nirs4all.core.logging import get_logger
from nirs4all.pipeline.config.context import ExecutionContext
from nirs4all.pipeline.storage.manifest_manager import ManifestManager
from nirs4all.pipeline.steps.step_runner import StepRunner
from nirs4all.pipeline.trace import TraceRecorder
logger = get_logger(__name__)
[docs]
class PipelineExecutor:
"""Executes a single pipeline configuration on a single dataset.
Handles:
- Step-by-step execution
- Context propagation
- Artifact management for one pipeline run
- Predictions accumulation for this pipeline
Attributes:
step_runner: Executes individual steps
manifest_manager: Manages pipeline manifests
verbose: Verbosity level
mode: Execution mode (train/predict/explain)
continue_on_error: Whether to continue on step failures
artifact_registry: Registry for v2 artifact management
"""
# Also set by __init__ but not listed in the docstring above:
#   saver, artifact_loader          - optional collaborators (default None)
#   step_number, substep_number,
#   operation_count                 - execution counters, initialised to 0 / -1 / 0
#                                     and reset again at the start of execute()
#                                     and execute_minimal().
def __init__(
    self,
    step_runner: StepRunner,
    manifest_manager: Optional[ManifestManager] = None,
    verbose: int = 0,
    mode: str = "train",
    continue_on_error: bool = False,
    saver: Any = None,
    artifact_loader: Any = None,
    artifact_registry: Any = None
):
    """Store the executor's collaborators and options and zero its counters.

    Args:
        step_runner: Object whose ``execute`` method runs one step.
        manifest_manager: Optional manifest manager.
        verbose: Verbosity level.
        mode: Execution mode ("train", "predict" or "explain").
        continue_on_error: When true, a failing step is logged and
            execution carries on instead of raising.
        saver: Optional object used for file operations.
        artifact_loader: Optional artifact source for predict/explain modes.
        artifact_registry: Optional registry for v2 artifact management.
    """
    # Execution counters (reset again at the start of every run).
    self.step_number = 0
    self.substep_number = -1
    self.operation_count = 0

    # Behaviour switches.
    self.mode = mode
    self.verbose = verbose
    self.continue_on_error = continue_on_error

    # Collaborators; everything except the step runner is optional.
    self.step_runner = step_runner
    self.manifest_manager = manifest_manager
    self.saver = saver
    self.artifact_loader = artifact_loader
    self.artifact_registry = artifact_registry
[docs]
def initialize_context(self, dataset: SpectroDataset) -> ExecutionContext:
    """Build the initial ExecutionContext for running a pipeline on ``dataset``.

    The selector starts with one ``["raw"]`` processing list per feature
    source (as reported by ``dataset.features_sources()``), a "2d" layout
    and source concatenation enabled. The pipeline state starts at step 0
    in this executor's mode, and the dataset's ``aggregate`` setting is
    copied onto the context.

    Args:
        dataset: Dataset to create the context for.

    Returns:
        A freshly initialised ExecutionContext.
    """
    from nirs4all.pipeline.config.context import DataSelector, PipelineState, StepMetadata

    # Build an independent ["raw"] list for every source. The shorter
    # `[["raw"]] * n` would place the *same* inner list object in every
    # slot, so an in-place change to one source's processing chain would
    # silently show up in all the others.
    processing = [["raw"] for _ in range(dataset.features_sources())]

    selector = DataSelector(
        partition=None,
        processing=processing,
        layout="2d",
        concat_source=True
    )
    state = PipelineState(
        y_processing="numeric",
        step_number=0,
        mode=self.mode
    )
    metadata = StepMetadata()

    # Carry the dataset's aggregate setting along with the context.
    aggregate_column = dataset.aggregate

    return ExecutionContext(
        selector=selector,
        state=state,
        metadata=metadata,
        aggregate_column=aggregate_column
    )
[docs]
def execute(
    self,
    steps: List[Any],
    config_name: str,
    dataset: SpectroDataset,
    context: ExecutionContext,
    runtime_context: Any,  # RuntimeContext
    prediction_store: Optional[Predictions] = None,
    generator_choices: Optional[List[Dict[str, Any]]] = None
) -> None:
    """Run every step of one pipeline configuration against one dataset.

    In train mode with a manifest manager the pipeline is registered in the
    manifest and its UID is shared with the saver and the runtime context;
    otherwise a temporary ``temp_<hash>`` UID is used. Outside
    predict/explain modes the step list is written to ``pipeline.json``
    through the saver both before and after the steps run. In train mode
    with a runtime context, a TraceRecorder is attached to it and, when a
    manifest manager is present, the finalized trace is saved.

    Args:
        steps: Pipeline steps to execute, in order.
        config_name: Pipeline configuration name.
        dataset: Dataset to process.
        context: Initial execution context.
        runtime_context: Runtime infrastructure context (may be falsy).
        prediction_store: Store that accumulates predictions; a new
            ``Predictions`` is created when omitted.
        generator_choices: Generator choices that produced this pipeline.

    Raises:
        Exception: Whatever the step execution raised, re-raised after
            being logged and its traceback printed.
    """
    # Every run starts from clean counters.
    self.step_number = 0
    self.substep_number = -1
    self.operation_count = 0

    logger.starting(f"Starting pipeline {config_name} on dataset {dataset.name}")

    # Short hash that identifies this exact step list.
    pipeline_hash = self._compute_pipeline_hash(steps)

    if self.mode == "train" and self.manifest_manager:
        # Register the pipeline in the manifest and hand its UID around.
        pipeline_uid, _pipeline_dir = self.manifest_manager.create_pipeline(
            name=config_name,
            dataset=dataset.name,
            pipeline_config={"steps": steps},
            pipeline_hash=pipeline_hash,
            generator_choices=generator_choices
        )
        if self.saver:
            self.saver.register(pipeline_uid)
        if runtime_context:
            runtime_context.pipeline_uid = pipeline_uid
    else:
        # No manifest registration: fall back to a throw-away UID.
        pipeline_uid = f"temp_{pipeline_hash}"

    # Persist the configuration up front (never in predict/explain).
    if self.mode not in ("predict", "explain") and self.saver:
        self.saver.save_json("pipeline.json", steps)

    if prediction_store is None:
        prediction_store = Predictions()

    # Execution trace is only recorded for training runs.
    trace_recorder = None
    if self.mode == "train" and runtime_context:
        trace_recorder = TraceRecorder(
            pipeline_uid=pipeline_uid or "",
            metadata={"dataset": dataset.name, "config_name": config_name}
        )
        runtime_context.trace_recorder = trace_recorder

    collected_artifacts = []
    try:
        self._execute_steps(
            steps,
            dataset,
            context,
            runtime_context,
            prediction_store,
            collected_artifacts
        )

        # Write the configuration again now that the steps have run.
        if self.mode not in ("predict", "explain") and self.saver:
            self.saver.save_json("pipeline.json", steps)

        # Close the trace and store it next to the pipeline manifest.
        if trace_recorder is not None and self.manifest_manager and pipeline_uid:
            trace_meta = {"n_steps": len(steps), "n_artifacts": len(collected_artifacts)}
            trace = trace_recorder.finalize(
                preprocessing_chain=dataset.short_preprocessings_str(),
                metadata=trace_meta
            )
            self.manifest_manager.save_execution_trace(pipeline_uid, trace)

        # Report the best prediction, if any were produced.
        if prediction_store.num_predictions > 0:
            # ascending=None lets the ranker infer the direction from the metric.
            best = prediction_store.get_best(ascending=None)
            if best:
                logger.success(f"Pipeline Best: {Predictions.pred_short_string(best)}")

        logger.debug(f"Pipeline {config_name} completed successfully on dataset {dataset.name}")
    except Exception as exc:
        logger.error(f"Pipeline {config_name} on dataset {dataset.name} failed: {str(exc)}")
        import traceback
        traceback.print_exc()
        raise
def _execute_steps(
    self,
    steps: List[Any],
    dataset: SpectroDataset,
    context: ExecutionContext,
    runtime_context: Any,
    prediction_store: Predictions,
    all_artifacts: List[Any]
) -> ExecutionContext:
    """Run all steps in order, threading the context from step to step.

    While the current context carries a non-empty ``branch_contexts`` list
    in its ``custom`` dict, each following step is dispatched to
    ``_execute_step_on_branches``. Steps that are dicts containing a
    "branch" or "merge" key are always run once through
    ``_execute_single_step``.

    Args:
        steps: Steps to execute.
        dataset: Dataset to process.
        context: Current execution context.
        runtime_context: Runtime infrastructure context (may be falsy).
        prediction_store: Prediction store passed on to each step.
        all_artifacts: List that accumulates processed artifacts.

    Returns:
        The context produced by the last executed step (or the incoming
        context when ``steps`` is empty).
    """
    for step in steps:
        self.step_number += 1
        self.substep_number = 0
        self.operation_count = 0

        # Mirror the counters onto the runtime context.
        if runtime_context:
            runtime_context.step_number = self.step_number
            runtime_context.substep_number = self.substep_number
            runtime_context.operation_count = self.operation_count
            # Reset for unique artifact IDs within the step.
            runtime_context.reset_processing_counter()

        if isinstance(context, ExecutionContext):
            context = context.with_step_number(self.step_number)

        # In predict/explain mode, fetch the binaries stored for this step.
        loaded_binaries = None
        if self.mode in ("predict", "explain") and self.artifact_loader:
            loaded_binaries = self.artifact_loader.get_step_binaries(self.step_number)
            if self.verbose > 1 and loaded_binaries:
                # Goes through the module logger like every other
                # diagnostic in this class, rather than print().
                logger.debug(
                    f"🔍 Loaded {', '.join(b[0] for b in loaded_binaries)} "
                    f"binaries for step {self.step_number}"
                )

        branch_contexts = context.custom.get("branch_contexts", [])
        is_branch_step = isinstance(step, dict) and "branch" in step
        # Merge steps need access to all branch contexts, so they run globally.
        is_merge_step = isinstance(step, dict) and "merge" in step

        if branch_contexts and not is_branch_step and not is_merge_step:
            # Branch mode: run the step once per branch context.
            context = self._execute_step_on_branches(
                step=step,
                dataset=dataset,
                context=context,
                runtime_context=runtime_context,
                loaded_binaries=loaded_binaries,
                prediction_store=prediction_store,
                all_artifacts=all_artifacts
            )
        else:
            # Normal execution on the single current context.
            context = self._execute_single_step(
                step=step,
                dataset=dataset,
                context=context,
                runtime_context=runtime_context,
                loaded_binaries=loaded_binaries,
                prediction_store=prediction_store,
                all_artifacts=all_artifacts
            )
    return context
def _execute_step_on_branches(
self,
step: Any,
dataset: SpectroDataset,
context: ExecutionContext,
runtime_context: Any,
loaded_binaries: Optional[List] = None,
prediction_store: Optional[Predictions] = None,
all_artifacts: Optional[List] = None
) -> ExecutionContext:
"""Execute a step on all branch contexts.
Args:
step: Step to execute
dataset: Dataset to process
context: Context containing branch_contexts in custom dict
runtime_context: Runtime infrastructure context
loaded_binaries: Pre-loaded binaries for predict mode
prediction_store: Prediction store
all_artifacts: List to accumulate artifacts
Returns:
Updated context with updated branch contexts
"""
# NOTE(review): ArtifactMeta is not referenced anywhere in this method.
from nirs4all.pipeline.execution.result import ArtifactMeta
branch_contexts = context.custom.get("branch_contexts", [])
if not branch_contexts:
# No branches, execute normally
return self._execute_single_step(
step, dataset, context, runtime_context,
loaded_binaries, prediction_store, all_artifacts
)
logger.debug(f"Executing step on {len(branch_contexts)} branch(es)")
# Rebuilt from scratch: one entry per incoming branch, in the same order.
updated_branch_contexts = []
for branch_info in branch_contexts:
# Each branch_info dict must provide "branch_id", "name" and "context";
# "features_snapshot" and "chain_snapshot" are optional.
branch_id = branch_info["branch_id"]
branch_name = branch_info["name"]
branch_context = branch_info["context"]
# Restore dataset features from branch snapshot if available
# This ensures each branch's post-branch steps (like model) use the correct
# feature data that was produced by that branch's preprocessing steps
features_snapshot = branch_info.get("features_snapshot")
if features_snapshot is not None:
# Deep copy so the stored snapshot itself is never handed to the dataset.
import copy
dataset._features.sources = copy.deepcopy(features_snapshot)
# V3: Restore chain state from branch snapshot if available
# This ensures each branch's post-branch steps use the correct operator chain
# for artifact ID generation (fixes MetaModel chain_path issues across branches)
chain_snapshot = branch_info.get("chain_snapshot")
if chain_snapshot is not None and runtime_context and runtime_context.trace_recorder:
runtime_context.trace_recorder.reset_chain_to(chain_snapshot)
logger.debug(f"Branch {branch_id} ({branch_name})")
# Update step number on branch context
branch_context = branch_context.with_step_number(self.step_number)
# For predict mode, load binaries specifically for this branch
branch_binaries = None
if self.mode in ("predict", "explain"):
# Get the full branch_path from context (handles nested branches)
branch_path = getattr(branch_context.selector, 'branch_path', None)
# First try artifact_provider (for minimal pipeline)
if runtime_context and hasattr(runtime_context, 'artifact_provider') and runtime_context.artifact_provider:
# Use artifact_provider for minimal pipeline prediction
# Try to get branch-specific artifacts
branch_binaries = runtime_context.artifact_provider.get_artifacts_for_step(
self.step_number, branch_path=branch_path, branch_id=branch_id
)
if not branch_binaries:
# Try without branch qualifier (may be a non-branch step artifact)
branch_binaries = runtime_context.artifact_provider.get_artifacts_for_step(self.step_number)
elif self.artifact_loader:
# Fallback to artifact_loader for traditional prediction
if branch_path:
# Use full branch_path for proper nested branch matching
branch_binaries = self.artifact_loader.get_step_binaries(
self.step_number, branch_path=branch_path
)
else:
# Fallback to simple branch_id
branch_binaries = self.artifact_loader.get_step_binaries(
self.step_number, branch_id=branch_id
)
if not branch_binaries:
# Fallback to non-branch binaries if no branch-specific ones exist
branch_binaries = loaded_binaries
# Extract operator info for trace recording
operator_type, operator_class, operator_config = self._extract_step_info(step)
# Get branch_path from context
# (re-read here with a [] default; the predict-mode lookup above used None)
branch_path = getattr(branch_context.selector, 'branch_path', [])
branch_name_ctx = getattr(branch_context.selector, 'branch_name', '') or ''
# Record step start in execution trace for this branch
if runtime_context:
runtime_context.record_step_start(
step_index=self.step_number,
operator_type=operator_type,
operator_class=operator_class,
operator_config=operator_config,
branch_path=branch_path,
# The selector's branch name wins; branch_info["name"] is the fallback.
branch_name=branch_name_ctx or branch_name,
mode=self.mode
)
# Record input shapes for branch step (parity with non-branch execution)
self._record_dataset_shapes(
dataset,
branch_context,
runtime_context,
is_input=True,
)
# Execute step on this branch
try:
step_result = self.step_runner.execute(
step=step,
dataset=dataset,
context=branch_context,
runtime_context=runtime_context,
loaded_binaries=branch_binaries,
prediction_store=prediction_store
)
# Record output shapes after execution for branch steps
if runtime_context:
self._record_dataset_shapes(
dataset,
step_result.updated_context,
runtime_context,
is_input=False,
)
# Record step end in execution trace
if runtime_context:
is_model = operator_type in ("model", "meta_model")
runtime_context.record_step_end(is_model=is_model)
# Process artifacts
processed_artifacts = self._process_step_artifacts(
step_result.artifacts,
branch_id=branch_id,
branch_name=branch_name
)
if all_artifacts is not None:
all_artifacts.extend(processed_artifacts)
# Append artifacts to manifest
# NOTE(review): runtime_context is dereferenced below without the
# `if runtime_context` guard used elsewhere in this method; with a None
# runtime_context in train mode this raises AttributeError, which is then
# handled by the except clause further down.
if (self.mode == "train" and
self.manifest_manager and
runtime_context.pipeline_uid and
processed_artifacts):
self.manifest_manager.append_artifacts(
runtime_context.pipeline_uid,
processed_artifacts
)
# Update branch context
updated_branch_contexts.append({
"branch_id": branch_id,
"name": branch_name,
"context": step_result.updated_context,
# Preserve any additional metadata
**{k: v for k, v in branch_info.items()
if k not in ("branch_id", "name", "context")}
})
except Exception as e:
# Record step end even on failure
# NOTE(review): record_step_end(is_model=...) may already have been called
# above if the failure happened while processing or appending artifacts.
if runtime_context:
runtime_context.record_step_end(skip_trace=True)
if self.continue_on_error:
logger.warning(f"Branch {branch_id} step {self.step_number} failed: {str(e)}")
# Keep original context on failure
updated_branch_contexts.append(branch_info)
else:
raise RuntimeError(
f"Pipeline step {self.step_number} failed on branch {branch_id}: {str(e)}"
) from e
# Update context with new branch contexts
result_context = context.copy()
result_context.custom["branch_contexts"] = updated_branch_contexts
# Sync operation_count back from runtime_context
if runtime_context:
self.operation_count = runtime_context.operation_count
return result_context
def _execute_single_step(
    self,
    step: Any,
    dataset: SpectroDataset,
    context: ExecutionContext,
    runtime_context: Any,
    loaded_binaries: Optional[List] = None,
    prediction_store: Optional[Predictions] = None,
    all_artifacts: Optional[List] = None
) -> ExecutionContext:
    """Execute one step against a single (non-branched) context.

    The step is delegated to ``self.step_runner.execute``. Afterwards its
    artifacts go through ``_process_step_artifacts``, tuple outputs are
    handed to the saver, and in train mode the processed artifacts are
    appended to the manifest. When a runtime context is supplied, step
    start/end and dataset shapes are recorded on it.

    Args:
        step: Step to execute.
        dataset: Dataset to process.
        context: Current execution context.
        runtime_context: Runtime infrastructure context (may be falsy).
        loaded_binaries: Pre-loaded binaries for predict mode.
        prediction_store: Prediction store.
        all_artifacts: List that accumulates processed artifacts.

    Returns:
        The step's updated context. If the step fails while
        ``continue_on_error`` is set, the context held at the moment of
        failure is returned instead.

    Raises:
        RuntimeError: If the step fails and ``continue_on_error`` is False;
            the original exception is chained as the cause.
    """
    # Operator details are only needed for trace recording.
    operator_type, operator_class, operator_config = self._extract_step_info(step)

    if runtime_context:
        branch_path = getattr(context.selector, 'branch_path', [])
        branch_name = getattr(context.selector, 'branch_name', '') or ''
        runtime_context.record_step_start(
            step_index=self.step_number,
            operator_type=operator_type,
            operator_class=operator_class,
            operator_config=operator_config,
            branch_path=branch_path,
            branch_name=branch_name,
            mode=self.mode
        )
        # Input shapes, captured before the step touches the dataset.
        self._record_dataset_shapes(dataset, context, runtime_context, is_input=True)

    try:
        step_result = self.step_runner.execute(
            step=step,
            dataset=dataset,
            context=context,
            runtime_context=runtime_context,
            loaded_binaries=loaded_binaries,
            prediction_store=prediction_store
        )
        logger.debug(f"Step {self.step_number} completed with {len(step_result.artifacts)} artifacts")
        if self.verbose > 1:
            logger.debug(str(dataset))

        # Output shapes, captured after execution.
        if runtime_context:
            self._record_dataset_shapes(dataset, step_result.updated_context, runtime_context, is_input=False)

        # Process artifacts (persisting them if needed).
        processed_artifacts = self._process_step_artifacts(step_result.artifacts)
        if all_artifacts is not None:
            all_artifacts.extend(processed_artifacts)

        # Save outputs. Dict outputs are the legacy form and were already
        # saved by the step, so only (data, name, type) tuples are handled.
        for output in step_result.outputs:
            if isinstance(output, tuple) and len(output) >= 3:
                # Slice before unpacking: the length check admits tuples
                # longer than three, which a bare 3-name unpack rejects
                # with ValueError.
                data, name, type_hint = output[:3]
                if self.saver:
                    self.saver.save_output(
                        step_number=self.step_number,
                        name=name,
                        data=data,
                        extension=f".{type_hint}" if not type_hint.startswith('.') else type_hint
                    )

        context = step_result.updated_context

        # Sync operation_count back from runtime_context.
        if runtime_context:
            self.operation_count = runtime_context.operation_count

        # Append artifacts to the manifest in train mode. runtime_context
        # is checked first because it may be None, as every other access
        # in this method already assumes.
        if (self.mode == "train" and
                self.manifest_manager and
                runtime_context and
                runtime_context.pipeline_uid and
                processed_artifacts):
            self.manifest_manager.append_artifacts(
                runtime_context.pipeline_uid,
                processed_artifacts
            )
            logger.debug(f"Appended {len(processed_artifacts)} artifacts to manifest")

        # Close the step in the execution trace.
        if runtime_context:
            is_model = operator_type in ("model", "meta_model")
            runtime_context.record_step_end(is_model=is_model)
    except Exception as e:
        # Record step end even on failure.
        if runtime_context:
            runtime_context.record_step_end(skip_trace=True)
        if self.continue_on_error:
            logger.warning(f"Step {self.step_number} failed but continuing: {str(e)}")
        else:
            raise RuntimeError(f"Pipeline step {self.step_number} failed: {str(e)}") from e
    return context
def _process_step_artifacts(
    self,
    artifacts: List[Any],
    branch_id: Optional[int] = None,
    branch_name: Optional[str] = None
) -> List[Any]:
    """Turn raw step artifacts into the metadata entries stored for a run.

    Three artifact forms are recognised; anything else is ignored:

    * ``ArtifactRecord`` - converted with ``to_dict()``.
    * ``ArtifactMeta`` or ``dict`` - passed through; dicts get a copy
      tagged with ``branch_id`` / ``branch_name`` when a branch is given.
    * ``tuple`` of at least ``(obj, name)`` with an optional format hint -
      persisted through the saver (skipped when there is no saver).

    Args:
        artifacts: Raw artifacts from step execution.
        branch_id: Optional branch ID, used for naming and metadata.
        branch_name: Optional branch name for metadata.

    Returns:
        List of processed artifact metadata.
    """
    from nirs4all.pipeline.execution.result import ArtifactMeta
    from nirs4all.pipeline.storage.artifacts.types import ArtifactRecord

    on_branch = branch_id is not None
    results = []
    for item in artifacts:
        # v2 system: record coming from registry.register().
        if isinstance(item, ArtifactRecord):
            results.append(item.to_dict())
            continue

        # Legacy form: already persisted, only metadata remains.
        if isinstance(item, (ArtifactMeta, dict)):
            if on_branch and isinstance(item, dict):
                # Work on a copy so the caller's dict is not mutated.
                item = dict(item)
                item["branch_id"] = branch_id
                item["branch_name"] = branch_name
            results.append(item)
            continue

        # Tuple form: (obj, name[, format_hint]) still has to be persisted.
        if not (isinstance(item, tuple) and len(item) >= 2):
            continue
        obj = item[0]
        name = item[1]
        format_hint = item[2] if len(item) > 2 else None
        if on_branch:
            name = f"{name}_b{branch_id}"
        if not self.saver:
            continue
        results.append(
            self.saver.persist_artifact(
                step_number=self.step_number,
                name=name,
                obj=obj,
                format_hint=format_hint,
                branch_id=branch_id,
                branch_name=branch_name
            )
        )
    return results
def _filter_binaries_for_branch(
    self,
    loaded_binaries: List,
    branch_id: int
) -> List:
    """Return the binaries to use for a branch (currently a pass-through).

    No filtering happens here: the argument is handed back unchanged
    whatever ``branch_id`` is. The method is retained for backward
    compatibility; branch selection is expected to be done by the artifact
    loader through ``get_step_binaries(step, branch_id)``.

    Args:
        loaded_binaries: Binaries for this step as (name, obj) tuples.
        branch_id: Target branch ID (unused).

    Returns:
        ``loaded_binaries`` itself.
    """
    # Empty, None and populated inputs are all returned as received.
    return loaded_binaries
def _extract_step_info(self, step: Any) -> tuple:
    """Describe a pipeline step for trace recording.

    Dict steps are matched against a fixed list of keywords ("model",
    "transform", "branch", ...); the first keyword present decides the
    operator type. Dicts with a ``'class'`` entry, plain strings and raw
    objects have their type inferred from substrings of the module path
    or class name. Numbers, booleans and ``None`` yield empty values.

    Args:
        step: Pipeline step configuration.

    Returns:
        Tuple of (operator_type, operator_class, operator_config).
    """
    model_markers = ('cross_decomposition', 'linear_model', 'ensemble')
    step_keywords = (
        "model", "meta_model", "transform", "y_processing",
        "feature_augmentation", "concat_transform", "sample_partitioner",
        "sample_augmentation", "resampler", "feature_selection",
        "outlier_excluder", "sample_filter", "splitter", "branch", "merge",
        "source_branch", "merge_sources", "preprocessing",
    )

    operator_type = ""
    operator_class = ""
    operator_config = {}

    if isinstance(step, dict):
        # First keyword (in declaration order) found in the dict wins.
        matched = next((kw for kw in step_keywords if kw in step), None)
        if matched is not None:
            operator_type = matched
            operator_class = self._extract_operator_class_name(step[matched], matched)
            # Keep the step as config only if it survives JSON encoding.
            try:
                json.dumps(step, default=str)
                operator_config = step
            except (TypeError, ValueError):
                operator_config = {"_type": str(type(step))}
        elif 'class' in step:
            # Serialized form: {'class': 'module.ClassName', 'params': {...}}
            class_path = step['class']
            if isinstance(class_path, str) and '.' in class_path:
                operator_class = class_path.rpartition('.')[2]
                path_lc = class_path.lower()
                name_lc = operator_class.lower()
                if 'model_selection' in path_lc or 'fold' in name_lc or 'split' in name_lc:
                    operator_type = "splitter"
                elif any(marker in path_lc for marker in model_markers):
                    operator_type = "model"
                elif 'preprocessing' in path_lc or 'scaler' in name_lc:
                    operator_type = "transform"
                elif 'feature_selection' in path_lc:
                    operator_type = "feature_selection"
                else:
                    operator_type = "operator"
                operator_config = step
            else:
                operator_type = "config"
                operator_class = str(class_path) if class_path else "config"
        else:
            # Dict without any recognised key.
            operator_type = "config"
            operator_class = str(type(step).__name__)
    elif hasattr(step, '__class__') and not isinstance(step, (str, int, float, bool, type(None))):
        # Raw object (e.g. a transformer or cross-validator instance).
        # Matching here is case-sensitive and checks models first.
        cls = step.__class__
        operator_class = cls.__name__
        module = cls.__module__
        if any(marker in module for marker in model_markers):
            operator_type = "model"
        elif 'model_selection' in module or 'Fold' in operator_class or 'Split' in operator_class:
            operator_type = "splitter"
        elif 'preprocessing' in module or 'Scaler' in operator_class:
            operator_type = "transform"
        elif 'feature_selection' in module:
            operator_type = "feature_selection"
        else:
            operator_type = "operator"
    elif isinstance(step, str):
        if '.' not in step:
            # Simple command name (e.g. 'chart_2d').
            operator_type = "command"
            operator_class = step
        else:
            # Dotted string: treated as a fully qualified class path.
            operator_class = step.rpartition('.')[2]
            path_lc = step.lower()
            name_lc = operator_class.lower()
            if any(marker in path_lc for marker in model_markers):
                operator_type = "model"
            elif 'model_selection' in path_lc or 'fold' in name_lc or 'split' in name_lc:
                operator_type = "splitter"
            elif 'preprocessing' in path_lc or 'scaler' in name_lc:
                operator_type = "transform"
            elif 'feature_selection' in path_lc:
                operator_type = "feature_selection"
            elif 'nirs4all.operators.transforms' in step:
                operator_type = "transform"
            else:
                operator_type = "operator"

    return operator_type, operator_class, operator_config
def _extract_operator_class_name(self, value: Any, keyword: str = "") -> str:
    """Produce a short, human-readable name for an operator value.

    Supported inputs:
    - ``None`` -> ``"None"``
    - strings -> the part after the last dot
    - classes -> their ``__name__``
    - lists/tuples -> names of the first three items, with a ``(+N)``
      suffix when more items exist
    - dicts -> name taken from a 'class'/'function'/'type'/'model'/
      'operator' entry, or a summary of 'transformers' / 'operations'
    - any other object -> its class name

    Args:
        value: The operator value to name.
        keyword: Keyword context (e.g. 'model'), forwarded when recursing
            into lists and 'operations'.

    Returns:
        Human-readable operator class name.
    """
    if value is None:
        return "None"

    # 'pkg.module.Name' -> 'Name'; strings without a dot are returned whole.
    if isinstance(value, str):
        return value.rpartition('.')[2]

    # A class reference rather than an instance.
    if isinstance(value, type):
        return value.__name__

    if isinstance(value, (list, tuple)):
        if not value:
            return "[]"
        # Only the first three entries are named, for readability.
        labels = []
        for entry in value[:3]:
            label = self._extract_operator_class_name(entry, keyword)
            if label and label not in ('dict', 'list', 'tuple'):
                labels.append(label)
        if not labels:
            return f"[{len(value)} items]"
        summary = ", ".join(labels)
        if len(value) > 3:
            summary += f" (+{len(value) - 3})"
        return summary

    if isinstance(value, dict):
        # The first class-like key present decides the name.
        for class_key in ('class', 'function', 'type', 'model', 'operator'):
            if class_key not in value:
                continue
            class_val = value[class_key]
            if isinstance(class_val, str):
                return class_val.rpartition('.')[2]
            if hasattr(class_val, '__name__'):
                return class_val.__name__
            if hasattr(class_val, '__class__'):
                return class_val.__class__.__name__

        # sample_augmentation style dicts: transformer names and count.
        if 'transformers' in value:
            transformers = value['transformers']
            count = value.get('count', 1)
            if isinstance(transformers, (list, tuple)):
                shown = [self._extract_operator_class_name(t) for t in transformers[:2]]
                shown = [label for label in shown if label not in ('dict', 'list')]
                if shown:
                    joined = ', '.join(shown)
                    return f"{joined} ×{count}"
                return f"[{len(transformers)} aug] ×{count}"

        # concat_transform style dicts: describe the nested operations.
        if 'operations' in value:
            return self._extract_operator_class_name(value['operations'], keyword)

        return "config"

    # Any other object: report its class, unless that is a generic builtin.
    generic_types = ('dict', 'list', 'tuple', 'set', 'str', 'int', 'float', 'bool', 'NoneType')
    class_name = value.__class__.__name__
    if class_name not in generic_types:
        return class_name
    return str(type(value).__name__)
def _record_dataset_shapes(
    self,
    dataset: SpectroDataset,
    context: ExecutionContext,
    runtime_context: Any,
    is_input: bool = True
) -> None:
    """Record the dataset's current shapes on the runtime context.

    Two shapes are captured through ``dataset.x``: the "2d" layout shape
    (for a list result, the first array's row count and the summed column
    counts) and the list of per-source "3d" shapes. They are passed to
    ``runtime_context.record_input_shapes`` or ``record_output_shapes``.

    Recording is best-effort: any failure is logged at debug level and
    never propagates to the caller.

    Args:
        dataset: The dataset to measure.
        context: Execution context providing the selector.
        runtime_context: Runtime context receiving the shapes. When it is
            None nothing is computed.
        is_input: True to record input shapes, False for output shapes.
    """
    # Nothing to record on: avoid pulling the data out of the dataset twice
    # only to fail on the missing recorder afterwards.
    if runtime_context is None:
        return

    try:
        # 2D layout shape (samples x features).
        X_2d = dataset.x(context.selector, layout="2d", include_excluded=False)
        if isinstance(X_2d, list):
            # Multi-source result: rows from the first block, columns summed.
            layout_shape = (X_2d[0].shape[0], sum(x.shape[1] for x in X_2d))
        else:
            layout_shape = X_2d.shape

        # 3D per-source shapes (samples x processings x features).
        X_3d = dataset.x(context.selector, layout="3d", concat_source=False, include_excluded=False)
        if not isinstance(X_3d, list):
            X_3d = [X_3d]
        features_shapes = [x.shape for x in X_3d]

        if is_input:
            runtime_context.record_input_shapes(
                input_shape=layout_shape,
                features_shape=features_shapes
            )
        else:
            runtime_context.record_output_shapes(
                output_shape=layout_shape,
                features_shape=features_shapes
            )
    except Exception as exc:
        # Shape recording is non-critical, so the step must not fail because
        # of it - but leave a trace instead of swallowing the error silently.
        logger.debug(f"Could not record {'input' if is_input else 'output'} shapes: {exc}")
def _compute_pipeline_hash(self, steps: List[Any]) -> str:
    """Derive a short identifier from the pipeline configuration.

    The steps are serialized to JSON with sorted keys (objects that JSON
    cannot encode are stringified), hashed with MD5, and the first six
    hexadecimal characters of the digest are kept.

    Args:
        steps: Pipeline steps.

    Returns:
        6-character hash string.
    """
    serialized = json.dumps(steps, sort_keys=True, default=str)
    digest = hashlib.md5(serialized.encode('utf-8')).hexdigest()
    return digest[:6]
[docs]
def next_op(self) -> int:
    """Advance the operation counter by one and return its new value (kept for compatibility)."""
    bumped = self.operation_count + 1
    self.operation_count = bumped
    return bumped
[docs]
def execute_minimal(
    self,
    steps: List[Any],
    minimal_pipeline: Any,  # MinimalPipeline
    dataset: SpectroDataset,
    context: ExecutionContext,
    runtime_context: Any,  # RuntimeContext
    prediction_store: Optional[Predictions] = None
) -> None:
    """Execute the steps of a minimal pipeline.

    A minimal pipeline is the subset of an original pipeline needed to
    replay a prediction. Compared with ``execute``:

    1. Step numbers come from ``minimal_pipeline.steps[i].step_index``
       (falling back to 1-based position), so they match the indices used
       when artifacts were stored.
    2. Binaries are taken from ``runtime_context.artifact_provider`` when
       one is set; ``self.artifact_loader`` is only used otherwise, and
       only in predict/explain mode.
    3. Steps whose config is ``None`` are skipped.
    4. No manifest entry, trace recorder or ``pipeline.json`` is written.

    Args:
        steps: Step configs (from minimal_pipeline.steps[i].step_config).
        minimal_pipeline: MinimalPipeline with artifact mappings.
        dataset: Dataset to process.
        context: Execution context.
        runtime_context: Runtime context, normally carrying an
            ``artifact_provider``.
        prediction_store: Optional prediction store; a new ``Predictions``
            is created when omitted.

    Note:
        The artifact provider is expected to serve artifacts by original
        step index, optionally filtered by branch path and substep index.
    """
    logger.info(f"Executing minimal pipeline: {len(steps)} steps")

    # Reset state.
    self.step_number = 0
    self.substep_number = -1
    self.operation_count = 0

    if prediction_store is None:
        prediction_store = Predictions()

    # Branch path of the model step; non-branch steps filter their
    # artifacts by it.
    # NOTE(review): a model_step_index of 0 is treated as "absent" by the
    # truthiness test below - confirm step indices are always >= 1.
    target_branch_path = None
    if minimal_pipeline and hasattr(minimal_pipeline, 'model_step_index'):
        model_idx = minimal_pipeline.model_step_index
        if model_idx and hasattr(minimal_pipeline, 'get_step'):
            model_step = minimal_pipeline.get_step(model_idx)
            if model_step and model_step.branch_path:
                target_branch_path = model_step.branch_path

    # Remembered so counters are not reset between substeps that share a
    # step_index.
    prev_step_idx = None

    for list_idx, step in enumerate(steps):
        # Resolve the matching minimal-pipeline entry once per iteration.
        has_entry = hasattr(minimal_pipeline, 'steps') and list_idx < len(minimal_pipeline.steps)
        entry = minimal_pipeline.steps[list_idx] if has_entry else None

        # Original step index, or 1-based position as a fallback.
        step_idx = entry.step_index if has_entry else list_idx + 1

        if step is None:
            logger.debug(f"Step {step_idx}: skipped (no config)")
            continue

        self.step_number = step_idx
        self.substep_number = 0
        self.operation_count = 0

        # Sync to runtime_context.
        if runtime_context:
            runtime_context.step_number = self.step_number
            runtime_context.substep_number = self.substep_number
            runtime_context.operation_count = self.operation_count
            # Only reset counters when step_index changes.
            if step_idx != prev_step_idx:
                runtime_context.reset_processing_counter()
                prev_step_idx = step_idx

        context = context.with_step_number(self.step_number)

        # Branch steps must see every artifact; their inner operators
        # filter by branch context themselves. All other steps are
        # restricted to the model's branch path.
        is_branch_step = isinstance(step, dict) and "branch" in step
        is_merge_step = isinstance(step, dict) and "merge" in step
        branch_path = None if is_branch_step else target_branch_path

        # Substep index from the minimal pipeline, for artifact filtering.
        substep_index = entry.substep_index if has_entry else None

        # Prefer the artifact provider over the artifact loader.
        loaded_binaries = None
        if runtime_context and runtime_context.artifact_provider:
            artifacts = runtime_context.artifact_provider.get_artifacts_for_step(
                step_idx, branch_path=branch_path, substep_index=substep_index
            )
            if artifacts:
                loaded_binaries = artifacts  # Already in (name, obj) format
                logger.debug(f"Loaded {len(artifacts)} artifact(s) for step {step_idx} (substep={substep_index})")
        elif self.mode in ("predict", "explain") and self.artifact_loader:
            # Fallback to artifact_loader.
            loaded_binaries = self.artifact_loader.get_step_binaries(self.step_number)

        branch_contexts = context.custom.get("branch_contexts", [])

        # Post-branch steps run once per branch, except:
        # - branch steps (they create branches)
        # - merge steps (they consume branches and exit branch mode)
        if branch_contexts and not is_branch_step and not is_merge_step:
            context = self._execute_step_on_branches(
                step=step,
                dataset=dataset,
                context=context,
                runtime_context=runtime_context,
                loaded_binaries=loaded_binaries,
                prediction_store=prediction_store,
                all_artifacts=[]
            )
        else:
            context = self._execute_single_step(
                step=step,
                dataset=dataset,
                context=context,
                runtime_context=runtime_context,
                loaded_binaries=loaded_binaries,
                prediction_store=prediction_store,
                all_artifacts=[]
            )

    logger.success(f"Minimal pipeline completed: {prediction_store.num_predictions} predictions")