"""
Execution Trace V3 - Records the exact path through the pipeline that produced a prediction.
This module provides the core data structures for recording execution traces,
which enable deterministic prediction replay and pipeline extraction.
V3 improvements:
- OperatorChain tracking for complete execution path
- Per-branch and per-source artifact indexing
- Support for nested branches and multi-source pipelines
- Chain-based artifact lookup for deterministic replay
Key Classes:
- StepArtifacts: Artifacts produced by a single step with V3 indexes
- ExecutionStep: Record of a single step's execution with chain tracking
- ExecutionTrace: Complete trace of a pipeline execution path
Architecture:
During training, each step execution is recorded in the trace:
1. Step starts -> record step_index, operator info, input chain
2. Step completes -> record artifacts and output chains
3. Model produces prediction -> trace_id is attached to prediction
During prediction, the trace is used to:
1. Identify the minimal set of steps needed
2. Load the correct artifacts for each step via chain lookup
3. Execute only required steps via existing controllers
"""
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from uuid import uuid4
if TYPE_CHECKING:
from nirs4all.pipeline.storage.artifacts.operator_chain import OperatorChain
[docs]
class StepExecutionMode(str, Enum):
    """How a pipeline step was run when it was recorded in a trace.

    Members subclass ``str``, so they compare equal to their raw values
    (``StepExecutionMode.TRAIN == "train"``) and can be rebuilt from them
    (``StepExecutionMode("train")``).

    Attributes:
        TRAIN: Step fitted on data (creates new artifacts).
        PREDICT: Step uses pre-fitted artifacts.
        SKIP: Step was skipped (no-op).
    """

    TRAIN = "train"
    PREDICT = "predict"
    SKIP = "skip"

    def __str__(self) -> str:
        """Return the bare value (e.g. ``"train"``) instead of the member name."""
        return str(self.value)
[docs]
@dataclass
class StepArtifacts:
    """Artifacts produced by a single step (V3).

    Records every artifact ID registered for a step, together with lookup
    indexes keyed by chain path, branch path, source index and fold.

    Attributes:
        artifact_ids: All artifact IDs produced by this step, in insertion
            order and without duplicates.
        primary_artifact_id: Main artifact (e.g., model) if applicable.
        fold_artifact_ids: Per-fold artifact IDs for CV models, keyed by
            fold index.
        primary_artifacts: Map of chain path to artifact ID for artifacts
            that were added with ``is_primary=True``.
        by_branch: Artifact IDs indexed by branch path tuple. The empty
            tuple holds shared (pre-branch) artifacts.
        by_source: Artifact IDs indexed by source index.
        by_chain: Artifact ID indexed by chain path.
        metadata: Additional artifact metadata (types, paths, etc.).
    """

    artifact_ids: List[str] = field(default_factory=list)
    primary_artifact_id: Optional[str] = None
    fold_artifact_ids: Dict[int, str] = field(default_factory=dict)

    # V3 indexes
    primary_artifacts: Dict[str, str] = field(default_factory=dict)
    by_branch: Dict[Tuple[int, ...], List[str]] = field(default_factory=dict)
    by_source: Dict[int, List[str]] = field(default_factory=dict)
    by_chain: Dict[str, str] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for YAML serialization.

        Branch path tuples are flattened to dot-separated strings so they can
        be used as mapping keys: ``(0, 1)`` becomes ``"0.1"`` and the empty
        (shared) path becomes ``""``. The other containers are returned as-is
        (not copied).

        Returns:
            Dictionary suitable for manifest storage.
        """
        by_branch_serialized = {
            ".".join(str(b) for b in branch_key): ids
            for branch_key, ids in self.by_branch.items()
        }
        return {
            "artifact_ids": self.artifact_ids,
            "primary_artifact_id": self.primary_artifact_id,
            "fold_artifact_ids": self.fold_artifact_ids,
            "primary_artifacts": self.primary_artifacts,
            "by_branch": by_branch_serialized,
            "by_source": self.by_source,
            "by_chain": self.by_chain,
            "metadata": self.metadata,
        }

    @staticmethod
    def _parse_branch_key(key: Any) -> Tuple[int, ...]:
        """Normalise a serialized ``by_branch`` key back into a branch tuple.

        Accepts the dot-separated string written by ``to_dict`` (``""`` is
        the shared path), a bare integer (treated as a single-level path, so
        ``0`` is branch ``(0,)`` and not the shared path), ``None`` (shared
        path), or any iterable of indices.
        """
        if isinstance(key, str):
            return tuple(int(part) for part in key.split(".")) if key else ()
        if key is None:
            return ()
        if isinstance(key, int):
            return (key,)
        return tuple(key)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StepArtifacts":
        """Create StepArtifacts from a dictionary.

        Missing keys and explicit ``None`` values (YAML ``null``) both fall
        back to empty containers. Integer-like keys that were stringified by
        the serializer are converted back to ``int``. Lists and dicts are
        shallow-copied so the new instance does not share state with ``data``.

        Args:
            data: Dictionary from manifest.

        Returns:
            StepArtifacts instance.
        """
        # ``dict.get(key, default)`` only uses the default when the key is
        # absent, so ``or`` is used to also cover keys present with None.
        fold_artifacts = {
            int(fold_id): artifact_id
            for fold_id, artifact_id in (data.get("fold_artifact_ids") or {}).items()
        }

        by_branch: Dict[Tuple[int, ...], List[str]] = {}
        for raw_key, ids in (data.get("by_branch") or {}).items():
            by_branch[cls._parse_branch_key(raw_key)] = list(ids or [])

        by_source = {
            int(source_idx): list(ids or [])
            for source_idx, ids in (data.get("by_source") or {}).items()
        }

        return cls(
            artifact_ids=list(data.get("artifact_ids") or []),
            primary_artifact_id=data.get("primary_artifact_id"),
            fold_artifact_ids=fold_artifacts,
            primary_artifacts=dict(data.get("primary_artifacts") or {}),
            by_branch=by_branch,
            by_source=by_source,
            by_chain=dict(data.get("by_chain") or {}),
            metadata=dict(data.get("metadata") or {}),
        )

    def add_artifact(
        self,
        artifact_id: str,
        is_primary: bool = False,
        chain_path: Optional[str] = None,
        branch_path: Optional[List[int]] = None,
        source_index: Optional[int] = None,
    ) -> None:
        """Add an artifact ID to this step's artifacts (V3).

        The ID is appended to ``artifact_ids`` once, then registered in each
        index for which a key was supplied.

        Args:
            artifact_id: The artifact ID to add.
            is_primary: Whether this is the primary artifact.
            chain_path: V3 operator chain path; an empty string is treated
                the same as ``None`` (no chain indexing).
            branch_path: Branch path for indexing; ``[]`` indexes the
                artifact under the shared (empty) branch.
            source_index: Source index for multi-source indexing.
        """
        if artifact_id not in self.artifact_ids:
            self.artifact_ids.append(artifact_id)

        if is_primary:
            self.primary_artifact_id = artifact_id

        # V3 indexing
        if chain_path:
            self.by_chain[chain_path] = artifact_id
            if is_primary:
                self.primary_artifacts[chain_path] = artifact_id

        if branch_path is not None:
            branch_ids = self.by_branch.setdefault(tuple(branch_path), [])
            if artifact_id not in branch_ids:
                branch_ids.append(artifact_id)

        if source_index is not None:
            source_ids = self.by_source.setdefault(source_index, [])
            if artifact_id not in source_ids:
                source_ids.append(artifact_id)

    def add_fold_artifact(
        self,
        fold_id: int,
        artifact_id: str,
        chain_path: Optional[str] = None,
        branch_path: Optional[List[int]] = None,
    ) -> None:
        """Add a fold-specific artifact.

        Overwrites any artifact previously recorded for ``fold_id`` and
        registers the ID as a regular, non-primary artifact.

        Args:
            fold_id: CV fold index.
            artifact_id: Artifact ID for this fold.
            chain_path: V3 operator chain path.
            branch_path: Branch path for indexing.
        """
        self.fold_artifact_ids[fold_id] = artifact_id
        self.add_artifact(
            artifact_id,
            is_primary=False,
            chain_path=chain_path,
            branch_path=branch_path,
        )

    def get_artifacts_for_branch(
        self,
        branch_path: List[int]
    ) -> List[str]:
        """Get artifact IDs matching a branch path.

        Includes artifacts from:
        - Exact branch match
        - Empty branch (shared/pre-branch)
        - Parent branches (for nested branches)

        Args:
            branch_path: Target branch path.

        Returns:
            List of matching artifact IDs, de-duplicated, in index order.
        """
        target = tuple(branch_path)
        matches: List[str] = []
        for branch_key, ids in self.by_branch.items():
            # One prefix test covers all three cases: the empty key is a
            # prefix of everything, an equal key is its own prefix, and a
            # parent branch is a strict prefix. Longer keys never match.
            if target[:len(branch_key)] == branch_key:
                matches.extend(ids)
        # dict preserves insertion order, giving an order-stable de-duplication
        return list(dict.fromkeys(matches))

    def get_artifacts_for_source(self, source_index: int) -> List[str]:
        """Get artifact IDs for a specific source.

        Args:
            source_index: Source index to filter.

        Returns:
            A new list of artifact IDs for that source (empty if unknown).
        """
        return list(self.by_source.get(source_index, []))

    def get_artifact_by_chain(self, chain_path: str) -> Optional[str]:
        """Get artifact ID by exact chain path match.

        Args:
            chain_path: Operator chain path.

        Returns:
            Artifact ID or None if not found.
        """
        return self.by_chain.get(chain_path)

    def merge(self, other: "StepArtifacts") -> None:
        """Merge another StepArtifacts into this one.

        Used when multiple substeps share the same step_index and their
        artifacts need to be combined in the artifact_map.

        Conflict policy: entries already present in this object win for the
        primary artifact, fold artifacts, ``primary_artifacts`` and
        ``by_chain``. For ``metadata``, values from ``other`` overwrite
        existing keys.

        Args:
            other: StepArtifacts to merge into this one.
        """
        for artifact_id in other.artifact_ids:
            if artifact_id not in self.artifact_ids:
                self.artifact_ids.append(artifact_id)

        # Primary artifact: keep existing if set, otherwise use other
        if not self.primary_artifact_id and other.primary_artifact_id:
            self.primary_artifact_id = other.primary_artifact_id

        # Fold artifacts: existing entries take precedence on conflicts
        for fold_id, artifact_id in other.fold_artifact_ids.items():
            self.fold_artifact_ids.setdefault(fold_id, artifact_id)

        # Primary artifacts per chain: existing entries take precedence
        for chain_path, artifact_id in other.primary_artifacts.items():
            self.primary_artifacts.setdefault(chain_path, artifact_id)

        # Branch index: union of IDs per branch, preserving order
        for branch_key, ids in other.by_branch.items():
            branch_ids = self.by_branch.setdefault(branch_key, [])
            for artifact_id in ids:
                if artifact_id not in branch_ids:
                    branch_ids.append(artifact_id)

        # Source index: union of IDs per source, preserving order
        for source_idx, ids in other.by_source.items():
            source_ids = self.by_source.setdefault(source_idx, [])
            for artifact_id in ids:
                if artifact_id not in source_ids:
                    source_ids.append(artifact_id)

        # Chain index: existing entries take precedence
        for chain_path, artifact_id in other.by_chain.items():
            self.by_chain.setdefault(chain_path, artifact_id)

        # Metadata: values from other overwrite existing keys
        self.metadata.update(other.metadata)
[docs]
@dataclass
class ExecutionStep:
    """Record of a single step's execution in the trace (V3).

    Captures the information recorded for one step: operator identity and
    configuration, execution mode, produced artifacts, branch context,
    operator-chain paths and data shapes.

    Attributes:
        step_index: 1-based step number in the pipeline.
        operator_type: Type of operation (e.g., "transform", "model", "splitter").
        operator_class: Class name of the operator (e.g., "PLSRegression", "SNV").
        operator_config: Serialized operator configuration.
        execution_mode: How the step was executed (train/predict/skip).
        artifacts: Artifacts produced by this step.
        branch_path: Branch indices if in a branch context.
        branch_name: Human-readable branch name.
        duration_ms: Execution duration in milliseconds.
        metadata: Additional step-specific metadata.
        input_chain_path: Serialized operator chain up to this step's input.
        output_chain_paths: Chain paths produced by this step (for branching).
        source_count: Number of X sources processed.
        produces_branches: True if this is a branch operator.
        substep_index: Index within substep (for [model1, model2]).
        input_shape: 2D layout ``(samples, features)`` at step input.
        output_shape: 2D layout ``(samples, features)`` at step output.
        input_features_shape: Per-source 3D shapes
            ``(samples, processings, features)`` at step input.
        output_features_shape: Per-source 3D shapes
            ``(samples, processings, features)`` at step output.
    """

    step_index: int
    operator_type: str = ""
    operator_class: str = ""
    operator_config: Dict[str, Any] = field(default_factory=dict)
    execution_mode: StepExecutionMode = StepExecutionMode.TRAIN
    artifacts: StepArtifacts = field(default_factory=StepArtifacts)
    branch_path: List[int] = field(default_factory=list)
    branch_name: str = ""
    duration_ms: float = 0.0
    metadata: Dict[str, Any] = field(default_factory=dict)

    # V3 chain tracking
    input_chain_path: str = ""
    output_chain_paths: List[str] = field(default_factory=list)
    source_count: int = 1
    produces_branches: bool = False
    substep_index: Optional[int] = None

    # V4 shape tracking
    # Input/output shapes are 2D layout (samples, features)
    input_shape: Optional[Tuple[int, int]] = None
    output_shape: Optional[Tuple[int, int]] = None
    # Features shape is 3D per-source: List of (samples, processings, features) per source
    input_features_shape: Optional[List[Tuple[int, int, int]]] = None
    output_features_shape: Optional[List[Tuple[int, int, int]]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for YAML serialization.

        The execution mode is written as its string value, artifacts are
        serialized through ``StepArtifacts.to_dict`` and shape tuples are
        written as lists. Unset or empty shapes are written as ``None``.

        Returns:
            Dictionary suitable for manifest storage.
        """
        return {
            "step_index": self.step_index,
            "operator_type": self.operator_type,
            "operator_class": self.operator_class,
            "operator_config": self.operator_config,
            "execution_mode": str(self.execution_mode),
            "artifacts": self.artifacts.to_dict(),
            "branch_path": self.branch_path,
            "branch_name": self.branch_name,
            "duration_ms": self.duration_ms,
            "metadata": self.metadata,
            "input_chain_path": self.input_chain_path,
            "output_chain_paths": self.output_chain_paths,
            "source_count": self.source_count,
            "produces_branches": self.produces_branches,
            "substep_index": self.substep_index,
            "input_shape": list(self.input_shape) if self.input_shape else None,
            "output_shape": list(self.output_shape) if self.output_shape else None,
            "input_features_shape": (
                [list(s) for s in self.input_features_shape]
                if self.input_features_shape else None
            ),
            "output_features_shape": (
                [list(s) for s in self.output_features_shape]
                if self.output_features_shape else None
            ),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ExecutionStep":
        """Create ExecutionStep from a dictionary.

        Missing keys and explicit ``None`` values (YAML ``null``) both fall
        back to the field defaults for the execution mode, strings and
        containers. Containers are shallow-copied so the new step does not
        share state with ``data``.

        Args:
            data: Dictionary from manifest.

        Returns:
            ExecutionStep instance.

        Raises:
            ValueError: If ``execution_mode`` is a string that is not a
                valid ``StepExecutionMode`` value.
        """
        # ``dict.get(key, default)`` only applies the default when the key is
        # absent; a key stored as null would otherwise leak None into the step.
        mode_value = data.get("execution_mode")
        if mode_value is None:
            execution_mode = StepExecutionMode.TRAIN
        elif isinstance(mode_value, str):
            execution_mode = StepExecutionMode(mode_value)
        else:
            execution_mode = mode_value

        artifacts_data = data.get("artifacts")
        if isinstance(artifacts_data, dict):
            artifacts = StepArtifacts.from_dict(artifacts_data)
        else:
            artifacts = StepArtifacts()

        # Shapes are stored as lists; restore the tuple form used in memory
        input_shape = data.get("input_shape")
        output_shape = data.get("output_shape")
        input_features_shape = data.get("input_features_shape")
        output_features_shape = data.get("output_features_shape")

        return cls(
            step_index=data.get("step_index", 0),
            operator_type=data.get("operator_type") or "",
            operator_class=data.get("operator_class") or "",
            operator_config=dict(data.get("operator_config") or {}),
            execution_mode=execution_mode,
            artifacts=artifacts,
            branch_path=list(data.get("branch_path") or []),
            branch_name=data.get("branch_name") or "",
            duration_ms=data.get("duration_ms", 0.0),
            metadata=dict(data.get("metadata") or {}),
            input_chain_path=data.get("input_chain_path") or "",
            output_chain_paths=list(data.get("output_chain_paths") or []),
            source_count=data.get("source_count", 1),
            produces_branches=data.get("produces_branches", False),
            substep_index=data.get("substep_index"),
            input_shape=tuple(input_shape) if input_shape else None,
            output_shape=tuple(output_shape) if output_shape else None,
            input_features_shape=(
                [tuple(s) for s in input_features_shape]
                if input_features_shape else None
            ),
            output_features_shape=(
                [tuple(s) for s in output_features_shape]
                if output_features_shape else None
            ),
        )

    def has_artifacts(self) -> bool:
        """Check if this step produced any artifacts.

        Returns:
            True if the step has at least one artifact.
        """
        return bool(self.artifacts.artifact_ids)

    def add_output_chain(self, chain_path: str) -> None:
        """Add an output chain path to this step.

        Empty paths and paths already recorded are ignored.

        Args:
            chain_path: Operator chain path to add.
        """
        if chain_path and chain_path not in self.output_chain_paths:
            self.output_chain_paths.append(chain_path)
[docs]
@dataclass
class ExecutionTrace:
    """Complete trace of a pipeline execution path.

    Records the exact sequence of steps and artifacts that produced a prediction,
    enabling deterministic replay for prediction, transfer, and export.

    The trace is controller-agnostic: it records what happened without encoding
    specific controller logic, so any controller (existing or custom) can be
    replayed using the same infrastructure.

    Attributes:
        trace_id: Unique identifier for this trace (first 12 characters of a
            random UUID4 string by default).
        pipeline_uid: Parent pipeline UID.
        created_at: ISO timestamp of trace creation (UTC by default).
        steps: Ordered list of execution steps.
        model_step_index: Index of the model step that produced predictions.
        fold_weights: Per-fold weights for CV ensemble (None for single model).
        preprocessing_chain: Summary of preprocessing steps for quick reference.
        metadata: Additional trace metadata (e.g., dataset info, run parameters).
    """

    trace_id: str = field(default_factory=lambda: str(uuid4())[:12])
    pipeline_uid: str = ""
    created_at: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    steps: List[ExecutionStep] = field(default_factory=list)
    model_step_index: Optional[int] = None
    fold_weights: Optional[Dict[int, float]] = None
    preprocessing_chain: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for YAML serialization.

        Returns:
            Dictionary suitable for manifest storage, with each step
            serialized through ``ExecutionStep.to_dict``.
        """
        return {
            "trace_id": self.trace_id,
            "pipeline_uid": self.pipeline_uid,
            "created_at": self.created_at,
            "steps": [step.to_dict() for step in self.steps],
            "model_step_index": self.model_step_index,
            "fold_weights": self.fold_weights,
            "preprocessing_chain": self.preprocessing_chain,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ExecutionTrace":
        """Create ExecutionTrace from a dictionary.

        Missing keys and explicit ``None`` values (YAML ``null``) are treated
        alike: ``steps`` and ``metadata`` become empty, string fields become
        ``""`` and a missing ``trace_id`` is replaced by a fresh random one.
        A missing ``created_at`` is left empty rather than set to "now".

        Args:
            data: Dictionary from manifest.

        Returns:
            ExecutionTrace instance.
        """
        # ``dict.get(key, default)`` only applies the default when the key is
        # absent, so ``or`` is used to also cover keys stored as null.
        steps = [
            ExecutionStep.from_dict(step_data)
            for step_data in (data.get("steps") or [])
        ]

        # Fold ids may come back as strings from YAML; normalise key/value types
        fold_weights = data.get("fold_weights")
        if fold_weights is not None:
            fold_weights = {
                int(fold_id): float(weight)
                for fold_id, weight in fold_weights.items()
            }

        return cls(
            # Only generate a fallback id when none was stored
            trace_id=data.get("trace_id") or str(uuid4())[:12],
            pipeline_uid=data.get("pipeline_uid") or "",
            created_at=data.get("created_at") or "",
            steps=steps,
            model_step_index=data.get("model_step_index"),
            fold_weights=fold_weights,
            preprocessing_chain=data.get("preprocessing_chain") or "",
            metadata=dict(data.get("metadata") or {}),
        )

    def add_step(self, step: ExecutionStep) -> None:
        """Add a step to the end of the trace.

        Args:
            step: ExecutionStep to add.
        """
        self.steps.append(step)

    def get_step(self, step_index: int) -> Optional[ExecutionStep]:
        """Get a step by its index.

        If several recorded steps share the index, the first one recorded is
        returned.

        Args:
            step_index: 1-based step index to find.

        Returns:
            ExecutionStep or None if not found.
        """
        return next(
            (step for step in self.steps if step.step_index == step_index),
            None,
        )

    def get_steps_before(self, step_index: int) -> List[ExecutionStep]:
        """Get all steps before a given step index.

        Args:
            step_index: 1-based step index (exclusive).

        Returns:
            List of steps with step_index < given index, in recorded order.
        """
        return [s for s in self.steps if s.step_index < step_index]

    def get_steps_up_to_model(self) -> List[ExecutionStep]:
        """Get all steps up to and including the model step.

        Returns:
            List of steps needed to reproduce the prediction. When no model
            step is set, a copy of all steps is returned.
        """
        if self.model_step_index is None:
            return list(self.steps)
        return [s for s in self.steps if s.step_index <= self.model_step_index]

    def get_artifact_ids(self) -> List[str]:
        """Get all artifact IDs in this trace.

        Returns:
            Artifact IDs of every step, concatenated in step order (an ID
            recorded by more than one step appears more than once).
        """
        return [
            artifact_id
            for step in self.steps
            for artifact_id in step.artifacts.artifact_ids
        ]

    def get_artifacts_by_step(self, step_index: int) -> Optional[StepArtifacts]:
        """Get artifacts for a specific step.

        Args:
            step_index: 1-based step index.

        Returns:
            StepArtifacts or None if step not found.
        """
        step = self.get_step(step_index)
        return step.artifacts if step else None

    def get_model_artifact_id(self) -> Optional[str]:
        """Get the primary model artifact ID.

        Returns:
            Primary artifact ID of the model step, or None if no model step
            is set or it cannot be found.
        """
        if self.model_step_index is None:
            return None
        step = self.get_step(self.model_step_index)
        if step and step.artifacts:
            return step.artifacts.primary_artifact_id
        return None

    def get_fold_artifact_ids(self) -> Dict[int, str]:
        """Get per-fold model artifact IDs.

        Returns:
            A copy of the model step's fold_id -> artifact_id mapping, or an
            empty dict if no model step is set or it cannot be found.
        """
        if self.model_step_index is None:
            return {}
        step = self.get_step(self.model_step_index)
        if step and step.artifacts:
            return dict(step.artifacts.fold_artifact_ids)
        return {}

    def set_model_step(
        self,
        step_index: int,
        fold_weights: Optional[Dict[int, float]] = None
    ) -> None:
        """Set the model step index and optional fold weights.

        Any previously stored fold weights are replaced, including by
        ``None`` when ``fold_weights`` is omitted.

        Args:
            step_index: Index of the model step.
            fold_weights: Optional per-fold weights for CV.
        """
        self.model_step_index = step_index
        self.fold_weights = fold_weights

    def finalize(
        self,
        preprocessing_chain: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Finalize the trace with summary information.

        Call this after all steps have been recorded to add summary info.
        Empty or omitted arguments leave the current values untouched.

        Args:
            preprocessing_chain: Summary string of preprocessing (e.g., "SNV>SG>MinMax").
            metadata: Additional metadata to merge into the existing metadata.
        """
        if preprocessing_chain:
            self.preprocessing_chain = preprocessing_chain
        if metadata:
            self.metadata.update(metadata)

    def __repr__(self) -> str:
        """Return a compact summary: id, step count, artifact count, model step."""
        n_steps = len(self.steps)
        n_artifacts = len(self.get_artifact_ids())
        return (
            f"ExecutionTrace(id={self.trace_id!r}, "
            f"steps={n_steps}, artifacts={n_artifacts}, "
            f"model_step={self.model_step_index})"
        )