"""
Context classes for pipeline execution.
This module provides typed context components that replace the Dict[str, Any] context
pattern used throughout the pipeline system. It separates the following concerns:
1. DataSelector: Mutable, dict-like data selection parameters for dataset.x() and dataset.y()
2. PipelineState: Mutable pipeline state that evolves through transformations
3. StepMetadata: Metadata for controller coordination and step tracking
4. ExecutionContext: Composite context with custom data extensibility
5. ArtifactProvider: Interface for providing artifacts during prediction replay
The separation enables:
- Type safety throughout the codebase
- Clear interfaces between components
- Better testability
- Explicit controller communication
- Future extensibility via custom dict
Example:
>>> selector = DataSelector(partition="train", processing=[["raw"]])
>>> state = PipelineState(y_processing="numeric")
>>> metadata = StepMetadata(keyword="transform")
>>> context = ExecutionContext(selector=selector, state=state, metadata=metadata)
>>> new_context = context.with_partition("test")
"""
from dataclasses import dataclass, field, replace as dataclass_replace, fields
from typing import Any, Dict, List, Optional, Iterator, Protocol, Tuple, Union
from copy import deepcopy
from collections.abc import MutableMapping
from abc import ABC, abstractmethod
@dataclass
class DataSelector(MutableMapping):
    """
    Mutable data selection parameters for dataset operations.

    This class replaces the dict-based Selector pattern used by dataset.x() and
    dataset.y(). It implements the MutableMapping protocol, so it can be used as a
    dictionary: declared fields are exposed as keys, and any other key is kept in
    the internal ``_extra`` dict.

    Mapping semantics:
        - A declared field whose value is None is treated as absent: it is skipped
          by iteration and ``len()``, and reading it raises KeyError unless an
          extra entry with the same name exists.
        - ``del selector[field]`` resets the field to None instead of removing it.
        - Only declared dataclass fields are routed to attributes. Keys that merely
          collide with a method name (e.g. "copy", "keys", "items") are stored in
          ``_extra`` and never shadow or overwrite the method.

    Attributes:
        partition: Data partition to select ("train", "test", "all", "val")
        processing: List of processing chains (one per data source)
        layout: Data layout for X retrieval ("2d", "3d", "4d")
        concat_source: Whether to concatenate multiple sources
        fold_id: Optional fold identifier for cross-validation
        include_augmented: Whether to include augmented samples
        y: Optional target processing version (e.g. "numeric", "scaled")
        branch_id: Optional branch identifier for pipeline branching (0-indexed)
            DEPRECATED: Use branch_path instead for nested branch support.
        branch_path: List of branch indices for nested branching (e.g., [0, 2] for
            branch 2 inside branch 0). Empty list means pre-branch/shared artifacts.
        branch_name: Optional human-readable branch name for tracking

    Example:
        >>> selector = DataSelector(partition="train", processing=[["raw"]])
        >>> selector["y"] = "scaled"  # Direct modification
        >>> selector["custom_key"] = "value"  # Arbitrary keys supported
        >>> print(selector["partition"])
        train
    """

    partition: str = "all"
    processing: List[List[str]] = field(default_factory=lambda: [["raw"]])
    layout: str = "2d"
    concat_source: bool = True
    fold_id: Optional[int] = None
    include_augmented: bool = False
    y: Optional[str] = None
    branch_id: Optional[int] = None
    branch_path: List[int] = field(default_factory=list)
    branch_name: Optional[str] = None
    _extra: Dict[str, Any] = field(default_factory=dict, repr=False)

    def _is_field(self, key: Any) -> bool:
        """Return True if ``key`` names a declared selector field.

        ``_extra`` is the backing store for arbitrary keys and is never exposed
        as a mapping key. Checking the dataclass fields (instead of ``hasattr``)
        keeps method names such as "copy" or "items" out of the attribute path.
        """
        if not isinstance(key, str) or key == "_extra":
            return False
        return any(f.name == key for f in fields(self))

    def __iter__(self) -> Iterator[str]:
        """Iterate over the names of non-None fields, then over extra keys."""
        for f in fields(self):
            if f.name == "_extra":
                continue
            if getattr(self, f.name) is not None:
                yield f.name
        yield from self._extra

    def __getitem__(self, key: str) -> Any:
        """Return the value of a non-None field or of an extra key.

        Raises:
            KeyError: If ``key`` is neither a non-None field nor an extra key.
        """
        if self._is_field(key):
            value = getattr(self, key)
            if value is not None:
                return value
        # Falls through for None-valued fields so an extra entry can still answer.
        return self._extra[key]

    def __setitem__(self, key: str, value: Any) -> None:
        """Assign a declared field, or store ``value`` under an extra key."""
        if self._is_field(key):
            setattr(self, key, value)
        else:
            self._extra[key] = value

    def __delitem__(self, key: str) -> None:
        """Reset a declared field to None, or remove an extra key.

        Raises:
            KeyError: If ``key`` is neither a declared field nor an extra key.
        """
        if self._is_field(key):
            setattr(self, key, None)
        elif key in self._extra:
            del self._extra[key]
        else:
            raise KeyError(key)

    def __len__(self) -> int:
        """Return the number of non-None fields plus extra keys."""
        return sum(1 for _ in self)

    def copy(self) -> "DataSelector":
        """Create a copy of the selector.

        ``processing`` and the extra keys are deep-copied and ``branch_path`` is
        copied as a new list, so the copy can be modified independently.
        """
        return DataSelector(
            partition=self.partition,
            processing=deepcopy(self.processing),
            layout=self.layout,
            concat_source=self.concat_source,
            fold_id=self.fold_id,
            include_augmented=self.include_augmented,
            y=self.y,
            branch_id=self.branch_id,
            branch_path=list(self.branch_path),
            branch_name=self.branch_name,
            _extra=deepcopy(self._extra),
        )

    def with_partition(self, partition: str) -> "DataSelector":
        """
        Create new selector with updated partition.

        Args:
            partition: New partition value

        Returns:
            New DataSelector with updated partition
        """
        new_selector = self.copy()
        new_selector.partition = partition
        return new_selector

    def with_processing(self, processing: List[List[str]]) -> "DataSelector":
        """
        Create new selector with updated processing chains.

        The given list is stored as-is (not copied) on the new selector.

        Args:
            processing: New processing chains

        Returns:
            New DataSelector with updated processing
        """
        new_selector = self.copy()
        new_selector.processing = processing
        return new_selector

    def with_layout(self, layout: str) -> "DataSelector":
        """
        Create new selector with updated layout.

        Args:
            layout: New layout value

        Returns:
            New DataSelector with updated layout
        """
        new_selector = self.copy()
        new_selector.layout = layout
        return new_selector

    def with_fold(self, fold_id: Optional[int]) -> "DataSelector":
        """
        Create new selector with updated fold_id.

        Args:
            fold_id: New fold identifier

        Returns:
            New DataSelector with updated fold_id
        """
        new_selector = self.copy()
        new_selector.fold_id = fold_id
        return new_selector

    def with_augmented(self, include_augmented: bool) -> "DataSelector":
        """
        Create new selector with updated include_augmented flag.

        Args:
            include_augmented: Whether to include augmented samples

        Returns:
            New DataSelector with updated include_augmented
        """
        new_selector = self.copy()
        new_selector.include_augmented = include_augmented
        return new_selector

    def with_branch(
        self,
        branch_id: Optional[int] = None,
        branch_name: Optional[str] = None,
        branch_path: Optional[List[int]] = None
    ) -> "DataSelector":
        """
        Create new selector with updated branch information.

        ``branch_id`` and ``branch_name`` are always overwritten with the given
        values, so omitting them clears them on the new selector.

        Args:
            branch_id: Branch identifier (0-indexed). DEPRECATED: Use branch_path.
            branch_name: Human-readable branch name
            branch_path: List of branch indices for nested branching

        Returns:
            New DataSelector with updated branch info
        """
        new_selector = self.copy()
        new_selector.branch_id = branch_id
        new_selector.branch_name = branch_name
        # Support both old branch_id and new branch_path
        if branch_path is not None:
            new_selector.branch_path = list(branch_path)
        elif branch_id is not None:
            # Convert single branch_id to branch_path for compatibility
            new_selector.branch_path = [branch_id]
        else:
            new_selector.branch_path = []
        return new_selector
@dataclass
class PipelineState:
    """
    Mutable pipeline state that evolves through execution.

    Tracks the values that change while a pipeline runs: the identifier of the
    current y transformation (e.g., "encoded_LabelEncoder_001"), the current
    step number, and the execution mode. Fields are plain attributes and are
    updated by direct assignment.

    Attributes:
        y_processing: Current y transformation identifier
        step_number: Current step number (1-indexed)
        mode: Execution mode ("train", "predict", "explain")

    Example:
        >>> state = PipelineState(y_processing="numeric")
        >>> state.step_number = 2  # Mutable update
        >>> state.y_processing = "encoded_LabelEncoder_001"
    """

    y_processing: str = "numeric"
    step_number: int = 0
    mode: str = "train"

    def copy(self) -> "PipelineState":
        """
        Return an independent PipelineState carrying the same field values.

        Returns:
            New PipelineState equal to this one
        """
        snapshot = {
            "y_processing": self.y_processing,
            "step_number": self.step_number,
            "mode": self.mode,
        }
        return PipelineState(**snapshot)
class ArtifactProvider(ABC):
    """Abstract interface for providing artifacts during prediction replay.

    Controllers ask a provider for artifacts by step index rather than by name
    matching, so lookup is deterministic and independent of the controller type.
    The provider hides the artifact storage system from the controllers that
    consume pre-loaded artifacts (transformers, models, etc.) in prediction mode.

    Implementations:
        - MapArtifactProvider: In-memory dictionary-based provider
        - LoaderArtifactProvider: Wraps ArtifactLoader for lazy loading

    Example:
        >>> provider = MapArtifactProvider(artifact_map)
        >>> artifacts = provider.get_artifacts_for_step(step_index=2)
        >>> for artifact_id, obj in artifacts:
        ...     process(obj)
    """

    @abstractmethod
    def get_artifact(self, step_index: int, fold_id: Optional[int] = None) -> Optional[Any]:
        """Get a single artifact for a step.

        Args:
            step_index: 1-based step index
            fold_id: Optional fold ID for fold-specific artifacts

        Returns:
            Artifact object or None if not found
        """

    @abstractmethod
    def get_artifacts_for_step(
        self,
        step_index: int,
        branch_path: Optional[List[int]] = None,
        branch_id: Optional[int] = None,
        source_index: Optional[int] = None,
        substep_index: Optional[int] = None
    ) -> List[Tuple[str, Any]]:
        """Get all artifacts for a step.

        Args:
            step_index: 1-based step index
            branch_path: Optional branch path filter
            branch_id: Optional branch ID filter
            source_index: Optional source/dataset index filter for multi-source
            substep_index: Optional substep index filter for branch substeps

        Returns:
            List of (artifact_id, artifact_object) tuples
        """

    @abstractmethod
    def get_fold_artifacts(
        self,
        step_index: int,
        branch_path: Optional[List[int]] = None
    ) -> List[Tuple[int, Any]]:
        """Get all fold-specific artifacts for a step.

        Args:
            step_index: 1-based step index
            branch_path: Optional branch path filter

        Returns:
            List of (fold_id, artifact_object) tuples, sorted by fold_id
        """

    @abstractmethod
    def has_artifacts_for_step(self, step_index: int) -> bool:
        """Check if artifacts exist for a step.

        Args:
            step_index: 1-based step index

        Returns:
            True if artifacts are available for this step
        """

    def get_primary_artifact(self, step_index: int) -> Optional[Any]:
        """Get the primary artifact for a step.

        This default implementation takes the object of the first entry returned
        by ``get_artifacts_for_step(step_index)``.

        Args:
            step_index: 1-based step index

        Returns:
            Primary artifact object or None if the step has no artifacts
        """
        step_artifacts = self.get_artifacts_for_step(step_index)
        if not step_artifacts:
            return None
        _first_id, first_obj = step_artifacts[0][0], step_artifacts[0][1]
        return first_obj

    def get_artifact_by_chain(self, chain_path: str) -> Optional[Any]:
        """Get artifact by V3 chain path (optional V3 method).

        Args:
            chain_path: Full operator chain path (e.g., "s1.MinMaxScaler>s3.PLS")

        Returns:
            Artifact object or None if not found; this default always returns None
        """
        return None

    def get_artifacts_for_chain_prefix(
        self,
        chain_prefix: str
    ) -> List[Tuple[str, Any]]:
        """Get all artifacts matching a chain path prefix (optional V3 method).

        Args:
            chain_prefix: Chain path prefix to match

        Returns:
            List of (chain_path, artifact_object) tuples; this default always
            returns an empty list
        """
        return []
class MapArtifactProvider(ArtifactProvider):
    """In-memory artifact provider backed by a dictionary.

    Serves artifacts from a pre-loaded dictionary keyed by step index. Used when
    artifacts are resolved from an ExecutionTrace or when loading from a bundle.

    Fold information is read from the artifact id itself: for a colon-separated
    id with at least three segments, an integer last segment is the fold id
    (e.g., "0001:2:0"); any other last segment (e.g., "all") means the artifact
    is not fold-specific.

    Attributes:
        artifact_map: Dictionary mapping step_index to list of (artifact_id, object) tuples
        fold_weights: Fold weights for CV ensemble averaging (empty dict if not given)
        primary_artifacts: Mapping of step_index to primary artifact_id (empty dict if not given)

    Example:
        >>> artifact_map = {
        ...     1: [("0001:1:all", snv_transformer)],
        ...     2: [("0001:2:0", model_fold0), ("0001:2:1", model_fold1)]
        ... }
        >>> provider = MapArtifactProvider(artifact_map)
        >>> transformer = provider.get_artifact(step_index=1)
    """

    def __init__(
        self,
        artifact_map: Dict[int, List[Tuple[str, Any]]],
        fold_weights: Optional[Dict[int, float]] = None,
        primary_artifacts: Optional[Dict[int, str]] = None
    ):
        """Initialize map-based artifact provider.

        Args:
            artifact_map: Mapping of step_index -> list of (artifact_id, object)
            fold_weights: Optional fold weights for CV models
            primary_artifacts: Optional mapping of step_index -> primary artifact_id
        """
        self.artifact_map = artifact_map
        self.fold_weights = fold_weights or {}
        self.primary_artifacts = primary_artifacts or {}

    @staticmethod
    def _fold_of(artifact_id: str) -> Optional[int]:
        """Return the fold id encoded in ``artifact_id``, or None if it has none."""
        segments = artifact_id.split(":")
        if len(segments) < 3:
            return None
        try:
            return int(segments[-1])
        except ValueError:
            # Non-numeric suffix such as "all": not a fold-specific artifact.
            return None

    def get_artifact(self, step_index: int, fold_id: Optional[int] = None) -> Optional[Any]:
        """Get a single artifact for a step.

        With ``fold_id`` given, only an artifact whose id encodes that fold is
        returned (there is no fallback to the primary artifact). Without it, the
        configured primary artifact is preferred, then the first artifact.

        Args:
            step_index: 1-based step index
            fold_id: Optional fold ID for fold-specific artifacts

        Returns:
            Artifact object or None if not found
        """
        entries = self.artifact_map.get(step_index, [])
        if not entries:
            return None

        if fold_id is not None:
            for artifact_id, obj in entries:
                if self._fold_of(artifact_id) == fold_id:
                    return obj
            return None

        if step_index in self.primary_artifacts:
            wanted_id = self.primary_artifacts[step_index]
            for artifact_id, obj in entries:
                if artifact_id == wanted_id:
                    return obj

        # No usable primary id: fall back to the first registered artifact.
        return entries[0][1]

    def get_artifacts_for_step(
        self,
        step_index: int,
        branch_path: Optional[List[int]] = None,
        branch_id: Optional[int] = None,
        source_index: Optional[int] = None,
        substep_index: Optional[int] = None
    ) -> List[Tuple[str, Any]]:
        """Get all artifacts for a step.

        Args:
            step_index: 1-based step index
            branch_path: Optional branch path filter (not used in map provider)
            branch_id: Optional branch ID filter (not used in map provider)
            source_index: Optional source/dataset index filter (not used in map provider)
            substep_index: Optional substep index filter (not used in map provider)

        Returns:
            List of (artifact_id, artifact_object) tuples; the list stored in
            ``artifact_map`` is returned directly, not a copy
        """
        return self.artifact_map.get(step_index, [])

    def get_fold_artifacts(
        self,
        step_index: int,
        branch_path: Optional[List[int]] = None
    ) -> List[Tuple[int, Any]]:
        """Get all fold-specific artifacts for a step.

        Args:
            step_index: 1-based step index
            branch_path: Optional branch path filter (not used in map provider)

        Returns:
            List of (fold_id, artifact_object) tuples, sorted by fold_id
        """
        per_fold = []
        for artifact_id, obj in self.artifact_map.get(step_index, []):
            fold = self._fold_of(artifact_id)
            if fold is not None:
                per_fold.append((fold, obj))
        per_fold.sort(key=lambda pair: pair[0])
        return per_fold

    def has_artifacts_for_step(self, step_index: int) -> bool:
        """Check if artifacts exist for a step.

        Args:
            step_index: 1-based step index

        Returns:
            True if the step is present in the map with a non-empty artifact list
        """
        if step_index not in self.artifact_map:
            return False
        return len(self.artifact_map[step_index]) > 0

    def get_fold_weights(self) -> Dict[int, float]:
        """Get fold weights for CV ensemble averaging.

        Returns:
            Copy of the dictionary mapping fold_id to weight
        """
        return self.fold_weights.copy()
class LoaderArtifactProvider(ArtifactProvider):
    """Artifact provider backed by an ArtifactLoader.

    Wraps an ArtifactLoader to provide artifacts on demand. When an execution
    trace is supplied, artifact ids are taken from the trace and loaded by id;
    otherwise every lookup is delegated to the loader's step-based methods.
    Used when loading from a manifest for prediction.

    Attributes:
        loader: The underlying ArtifactLoader
        trace: Optional ExecutionTrace for step-to-artifact mapping
    """

    def __init__(
        self,
        loader: Any,  # ArtifactLoader
        trace: Optional[Any] = None  # ExecutionTrace
    ):
        """Initialize loader-based artifact provider.

        Args:
            loader: ArtifactLoader instance for loading artifacts
            trace: Optional ExecutionTrace for step mapping
        """
        self.loader = loader
        self.trace = trace

    def get_artifact(self, step_index: int, fold_id: Optional[int] = None) -> Optional[Any]:
        """Get a single artifact for a step.

        With a trace: if ``fold_id`` is given and the step records fold artifact
        ids, only that fold's artifact is considered. Otherwise the step's
        primary artifact id is used, then its first artifact id.
        Without a trace: the first result of the loader's step lookup is used.

        Args:
            step_index: 1-based step index
            fold_id: Optional fold ID for fold-specific artifacts

        Returns:
            Artifact object or None if not found
        """
        if self.trace is None:
            # Fallback: use loader's step-based lookup
            loaded = self.loader.load_for_step(step_index=step_index, fold_id=fold_id)
            return loaded[0][1] if loaded else None

        step = self.trace.get_step(step_index)
        if not (step and step.artifacts):
            return None

        recorded = step.artifacts
        if fold_id is not None and recorded.fold_artifact_ids:
            # Fold-specific request: no fallback to the primary artifact.
            fold_artifact_id = recorded.fold_artifact_ids.get(fold_id)
            if fold_artifact_id:
                return self.loader.load_by_id(fold_artifact_id)
            return None
        if recorded.primary_artifact_id:
            return self.loader.load_by_id(recorded.primary_artifact_id)
        if recorded.artifact_ids:
            return self.loader.load_by_id(recorded.artifact_ids[0])
        return None

    def _in_target_branch(self, artifact_id: str, target_branch: Optional[int]) -> bool:
        """Tell whether an artifact passes the top-level branch filter.

        An artifact passes when no target branch is requested, when the loader
        has no record for it, when its record carries no branch (shared
        artifact), or when the first index of its branch path equals the target.
        """
        if target_branch is None:
            return True
        record = self.loader.get_record(artifact_id)
        if record is None:
            return True
        artifact_branch = None
        if record.branch_path and len(record.branch_path) > 0:
            artifact_branch = record.branch_path[0]
        if artifact_branch is not None and artifact_branch != target_branch:
            return False
        return True

    def get_artifacts_for_step(
        self,
        step_index: int,
        branch_path: Optional[List[int]] = None,
        branch_id: Optional[int] = None,
        source_index: Optional[int] = None,
        substep_index: Optional[int] = None
    ) -> List[Tuple[str, Any]]:
        """Get all artifacts for a step.

        With a trace, artifacts are filtered on the first element of
        ``branch_path`` (or on ``branch_id`` when no path is given); artifacts
        whose lookup or load raises KeyError or FileNotFoundError are skipped.
        Without a trace, the loader's step lookup is called with ``branch_path``
        only.

        Args:
            step_index: 1-based step index
            branch_path: Optional branch path filter
            branch_id: Optional branch ID filter
            source_index: Optional source index filter for multi-source pipelines
            substep_index: Optional substep index filter (not used in loader provider)

        Returns:
            List of (artifact_id, artifact_object) tuples
        """
        # Determine target branch for filtering
        if branch_path is not None and len(branch_path) > 0:
            target_branch: Optional[int] = branch_path[0]
        elif branch_id is not None:
            target_branch = branch_id
        else:
            target_branch = None

        if self.trace is None:
            # Fallback: use loader's step-based lookup
            return self.loader.load_for_step(step_index=step_index, branch_path=branch_path)

        step = self.trace.get_step(step_index)
        if not (step and step.artifacts):
            return []

        # Prefer the per-source index when a source filter is requested.
        if source_index is not None and step.artifacts.by_source:
            candidate_ids = step.artifacts.by_source.get(source_index, [])
        else:
            candidate_ids = step.artifacts.artifact_ids

        loaded = []
        for artifact_id in candidate_ids:
            try:
                if not self._in_target_branch(artifact_id, target_branch):
                    continue  # Skip - wrong branch
                loaded.append((artifact_id, self.loader.load_by_id(artifact_id)))
            except (KeyError, FileNotFoundError):
                continue
        return loaded

    def get_fold_artifacts(
        self,
        step_index: int,
        branch_path: Optional[List[int]] = None
    ) -> List[Tuple[int, Any]]:
        """Get all fold-specific artifacts for a step.

        Args:
            step_index: 1-based step index
            branch_path: Optional branch path filter (only forwarded to the
                loader when no trace is available)

        Returns:
            List of (fold_id, artifact_object) tuples, sorted by fold_id
        """
        if self.trace is None:
            # Fallback: use loader's fold model lookup
            return self.loader.load_fold_models(step_index=step_index, branch_path=branch_path)

        step = self.trace.get_step(step_index)
        if not (step and step.artifacts and step.artifacts.fold_artifact_ids):
            return []

        per_fold = []
        for fold_id, artifact_id in step.artifacts.fold_artifact_ids.items():
            try:
                per_fold.append((fold_id, self.loader.load_by_id(artifact_id)))
            except (KeyError, FileNotFoundError):
                continue
        per_fold.sort(key=lambda pair: pair[0])
        return per_fold

    def has_artifacts_for_step(self, step_index: int) -> bool:
        """Check if artifacts exist for a step.

        Args:
            step_index: 1-based step index

        Returns:
            True if artifacts are available for this step
        """
        if self.trace is None:
            # Fallback: use loader's step check
            return self.loader.has_binaries_for_step(step_index)
        step = self.trace.get_step(step_index)
        return step is not None and step.has_artifacts()

    def get_artifact_by_chain(self, chain_path: str) -> Optional[Any]:
        """Get artifact by V3 chain path.

        Args:
            chain_path: Full operator chain path (e.g., "s1.MinMaxScaler>s3.PLS")

        Returns:
            Artifact object, or None if not found or if the loader has no
            ``load_by_chain`` method
        """
        if not hasattr(self.loader, 'load_by_chain'):
            return None
        return self.loader.load_by_chain(chain_path)

    def get_artifacts_for_chain_prefix(
        self,
        chain_prefix: str
    ) -> List[Tuple[str, Any]]:
        """Get all artifacts matching a chain path prefix.

        Args:
            chain_prefix: Chain path prefix to match

        Returns:
            List of (chain_path, artifact_object) tuples; empty if the loader has
            no ``load_by_chain_prefix`` method
        """
        if not hasattr(self.loader, 'load_by_chain_prefix'):
            return []
        return self.loader.load_by_chain_prefix(chain_prefix)
class ExecutionContext:
    """
    Composite execution context with extensibility.

    This class bundles the context components and provides:
    - Data selection via DataSelector
    - Mutable state tracking via PipelineState
    - Controller coordination via StepMetadata
    - Custom data storage for controller-specific needs

    ``copy()`` and every ``with_*`` method return a new context whose selector,
    state, metadata and custom dict are copies, so controllers can modify the
    result without affecting the context they started from.

    Attributes:
        selector: Data selector
        state: Mutable pipeline state
        metadata: Mutable step metadata
        custom: Dict for controller-specific custom data
        aggregate_column: Sample aggregation column for prediction aggregation.
            - None: No aggregation (default)
            - 'y': Aggregate by y_true values
            - str: Aggregate by specified metadata column

    Example:
        >>> context = ExecutionContext(
        ...     selector=DataSelector(partition="train"),
        ...     state=PipelineState(y_processing="numeric"),
        ...     metadata=StepMetadata(keyword="transform")
        ... )
        >>> context.custom["my_controller"] = {"threshold": 0.5}
        >>> train_ctx = context.with_partition("train")
    """

    def __init__(
        self,
        selector: Optional[DataSelector] = None,
        state: Optional[PipelineState] = None,
        metadata: Optional[StepMetadata] = None,
        custom: Optional[Dict[str, Any]] = None,
        aggregate_column: Optional[str] = None
    ):
        """
        Initialize execution context.

        Args:
            selector: Data selector (default: DataSelector())
            state: Pipeline state (default: PipelineState())
            metadata: Step metadata (default: StepMetadata())
            custom: Custom data dict (default: {})
            aggregate_column: Sample aggregation column for prediction aggregation.
                - None: No aggregation (default)
                - 'y': Aggregate by y_true values
                - str: Aggregate by specified metadata column
        """
        self.selector = selector if selector is not None else DataSelector()
        self.state = state if state is not None else PipelineState()
        self.metadata = metadata if metadata is not None else StepMetadata()
        self.custom = custom if custom is not None else {}
        self.aggregate_column = aggregate_column

    def copy(self) -> "ExecutionContext":
        """
        Create a deep copy of this context.

        Selector, state and metadata are duplicated through their own ``copy()``
        methods and the custom dict through ``deepcopy``. This preserves the copy
        semantics expected by controllers.

        Returns:
            Deep copy of ExecutionContext
        """
        return ExecutionContext(
            selector=self.selector.copy(),
            state=self.state.copy(),
            metadata=self.metadata.copy(),
            custom=deepcopy(self.custom),
            aggregate_column=self.aggregate_column
        )

    def with_partition(self, partition: str) -> "ExecutionContext":
        """
        Create new context with updated partition.

        Args:
            partition: New partition value

        Returns:
            New ExecutionContext with updated partition
        """
        new_ctx = self.copy()
        new_ctx.selector = new_ctx.selector.with_partition(partition)
        return new_ctx

    def with_processing(self, processing: List[List[str]]) -> "ExecutionContext":
        """
        Create new context with updated processing chains.

        Args:
            processing: New processing chains

        Returns:
            New ExecutionContext with updated processing
        """
        new_ctx = self.copy()
        new_ctx.selector = new_ctx.selector.with_processing(processing)
        return new_ctx

    def with_layout(self, layout: str) -> "ExecutionContext":
        """
        Create new context with updated layout.

        Args:
            layout: New layout value

        Returns:
            New ExecutionContext with updated layout
        """
        new_ctx = self.copy()
        new_ctx.selector = new_ctx.selector.with_layout(layout)
        return new_ctx

    def with_step_number(self, step_number: int) -> "ExecutionContext":
        """
        Create new context with updated step number.

        Args:
            step_number: New step number

        Returns:
            New ExecutionContext with updated step number
        """
        new_ctx = self.copy()
        new_ctx.state = dataclass_replace(new_ctx.state, step_number=step_number)
        return new_ctx

    def with_y(self, y_processing: str) -> "ExecutionContext":
        """
        Create new context with updated y processing.

        Args:
            y_processing: New y processing value

        Returns:
            New ExecutionContext with updated y processing
        """
        new_ctx = self.copy()
        new_ctx.state = dataclass_replace(new_ctx.state, y_processing=y_processing)
        return new_ctx

    def with_metadata(self, **kwargs) -> "ExecutionContext":
        """
        Create new context with updated metadata fields.

        Args:
            **kwargs: Metadata fields to update (forwarded to dataclasses.replace)

        Returns:
            New ExecutionContext with updated metadata
        """
        new_ctx = self.copy()
        new_ctx.metadata = dataclass_replace(new_ctx.metadata, **kwargs)
        return new_ctx

    def with_branch(
        self,
        branch_id: Optional[int] = None,
        branch_name: Optional[str] = None,
        branch_path: Optional[List[int]] = None
    ) -> "ExecutionContext":
        """
        Create new context with updated branch information.

        The arguments are forwarded to ``DataSelector.with_branch``, so a nested
        branch path can be set directly instead of only a single branch id.

        Args:
            branch_id: Branch identifier (0-indexed). DEPRECATED: Use branch_path.
            branch_name: Human-readable branch name
            branch_path: List of branch indices for nested branching

        Returns:
            New ExecutionContext with updated branch info
        """
        new_ctx = self.copy()
        new_ctx.selector = new_ctx.selector.with_branch(
            branch_id=branch_id,
            branch_name=branch_name,
            branch_path=branch_path
        )
        return new_ctx

    def get_selector(self) -> DataSelector:
        """
        Get the data selector.

        Returns:
            The DataSelector instance held by this context (not a copy)
        """
        return self.selector
@dataclass
class RuntimeContext:
    """
    Runtime infrastructure components for pipeline execution.

    Holds references to the infrastructure objects needed while a pipeline runs
    but that are not part of the data flow or pipeline state, together with a few
    counters. It replaces the "God Object" pattern of passing the runner
    everywhere. All ``record_*`` and ``get_*trace*`` helpers are no-ops (or
    return None) when no trace recorder is attached.

    Attributes:
        saver: SimulationSaver for file operations
        manifest_manager: ManifestManager for pipeline tracking
        artifact_loader: ArtifactLoader for predict/explain modes
        artifact_provider: ArtifactProvider for controller-agnostic artifact injection (Phase 3)
        artifact_registry: ArtifactRegistry for artifact management (v2 system)
        pipeline_uid: Current pipeline unique identifier
        step_runner: StepRunner for executing sub-steps
        step_number: Step number (not modified by the methods of this class)
        operation_count: Counter for operation IDs
        substep_number: Current substep number
        processing_counter: Counter for unique processing indices within a step
        artifact_load_counter: Per-source count of artifacts loaded within a step
        target_model: Optional dict (not used by the methods of this class)
        explainer: Optional object (not used by the methods of this class)
        trace_recorder: TraceRecorder for recording execution traces (Phase 2)
        retrain_config: RetrainConfig for retrain mode control (Phase 7)
    """

    saver: Any = None
    manifest_manager: Any = None
    artifact_loader: Any = None
    artifact_provider: Optional["ArtifactProvider"] = None  # Phase 3: controller-agnostic artifact injection
    artifact_registry: Any = None
    pipeline_uid: Optional[str] = None
    step_runner: Any = None
    step_number: int = 0
    operation_count: int = 0
    substep_number: int = -1
    processing_counter: int = 0  # Global counter for unique processing indices within a step
    artifact_load_counter: Dict[int, int] = field(default_factory=dict)  # Per-source artifact load counter
    target_model: Optional[Dict[str, Any]] = None
    explainer: Any = None
    trace_recorder: Any = None  # TraceRecorder instance for execution trace recording
    retrain_config: Any = None  # Phase 7: RetrainConfig for retrain mode control

    def next_op(self) -> int:
        """Increment the operation counter and return its new value."""
        self.operation_count = self.operation_count + 1
        return self.operation_count

    def next_processing_index(self) -> int:
        """Get the next unique processing index for artifact identification.

        The counter persists across all sub-operations within a step (e.g.,
        feature_augmentation), so each transformer gets a unique substep_index.
        It is cleared by ``reset_processing_counter``.

        Returns:
            int: The counter value before it is incremented.
        """
        current = self.processing_counter
        self.processing_counter = current + 1
        return current

    def reset_processing_counter(self) -> None:
        """Reset the processing counter and the per-source artifact load counter."""
        self.processing_counter = 0
        self.artifact_load_counter = {}

    def next_artifact_load_index(self, source_index: int) -> int:
        """Get the next artifact load index for a source during prediction.

        Tracks how many artifacts have been requested for each source across all
        sub-operations within a step (e.g., feature_augmentation).

        Args:
            source_index: The source index to track.

        Returns:
            int: The next artifact index to load for this source (starts at 0).
        """
        current = self.artifact_load_counter.get(source_index, 0)
        self.artifact_load_counter[source_index] = current + 1
        return current

    def should_train_step(self, step_index: int, is_model: bool = False) -> bool:
        """Determine if a step should train based on retrain configuration.

        Phase 7 Feature:
            When a retrain_config is present, the decision is delegated to it.
            Otherwise the step trains unless an artifact provider reports
            existing artifacts for it.

        Args:
            step_index: 1-based step index
            is_model: Whether this is the model step

        Returns:
            True if the step should train, False if it should use existing artifacts
        """
        if self.retrain_config is not None:
            return self.retrain_config.should_train_step(step_index, is_model)
        if self.artifact_provider is None:
            # No artifacts available: train
            return True
        # Existing artifacts for this step mean it does not need to train.
        return not self.artifact_provider.has_artifacts_for_step(step_index)

    def record_step_start(
        self,
        step_index: int,
        operator_type: str = "",
        operator_class: str = "",
        operator_config: Optional[Dict[str, Any]] = None,
        branch_path: Optional[List[int]] = None,
        branch_name: str = "",
        mode: str = "train"
    ) -> None:
        """Record the start of a step execution in the trace.

        Args:
            step_index: 1-based step index
            operator_type: Type of operator (e.g., "transform", "model")
            operator_class: Class name of operator
            operator_config: Serialized operator configuration
            branch_path: Branch indices if in branch context
            branch_name: Human-readable branch name
            mode: Execution mode ("train", "predict", "explain"); "predict" and
                "explain" are both recorded as PREDICT, anything else as TRAIN
        """
        if self.trace_recorder is None:
            return

        # Imported only when a recorder is attached.
        from nirs4all.pipeline.trace import StepExecutionMode

        if mode in ("predict", "explain"):
            exec_mode = StepExecutionMode.PREDICT
        else:
            exec_mode = StepExecutionMode.TRAIN

        self.trace_recorder.start_step(
            step_index=step_index,
            operator_type=operator_type,
            operator_class=operator_class,
            operator_config=operator_config,
            execution_mode=exec_mode,
            branch_path=branch_path,
            branch_name=branch_name
        )

    def record_step_artifact(
        self,
        artifact_id: str,
        is_primary: bool = False,
        fold_id: Optional[int] = None,
        chain_path: Optional[str] = None,
        branch_path: Optional[List[int]] = None,
        source_index: Optional[int] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Record an artifact created during the current step (V3).

        Args:
            artifact_id: The V3 artifact ID
            is_primary: Whether this is the primary artifact (e.g., model)
            fold_id: CV fold ID if fold-specific artifact
            chain_path: V3 operator chain path for this artifact
            branch_path: Branch path for indexing
            source_index: Source index for multi-source artifacts
            metadata: Additional artifact metadata
        """
        if self.trace_recorder is None:
            return
        self.trace_recorder.record_artifact(
            artifact_id=artifact_id,
            is_primary=is_primary,
            fold_id=fold_id,
            chain_path=chain_path,
            branch_path=branch_path,
            source_index=source_index,
            metadata=metadata
        )

    def record_step_end(
        self,
        is_model: bool = False,
        fold_weights: Optional[Dict[int, float]] = None,
        skip_trace: bool = False
    ) -> None:
        """Record the end of a step execution.

        Args:
            is_model: Whether this is the model step
            fold_weights: Per-fold weights for CV models
            skip_trace: If True, don't add this step to the trace
        """
        if self.trace_recorder is None:
            return
        self.trace_recorder.end_step(
            is_model=is_model,
            fold_weights=fold_weights,
            skip_trace=skip_trace
        )

    def record_input_shapes(
        self,
        input_shape: Optional[tuple] = None,
        features_shape: Optional[List[tuple]] = None
    ) -> None:
        """Record input shapes for the current step.

        Args:
            input_shape: 2D layout shape (samples, features)
            features_shape: List of 3D shapes per source (samples, processings, features)
        """
        if self.trace_recorder is None:
            return
        self.trace_recorder.record_input_shapes(
            input_shape=input_shape,
            features_shape=features_shape
        )

    def record_output_shapes(
        self,
        output_shape: Optional[tuple] = None,
        features_shape: Optional[List[tuple]] = None
    ) -> None:
        """Record output shapes for the current step.

        Args:
            output_shape: 2D layout shape (samples, features)
            features_shape: List of 3D shapes per source (samples, processings, features)
        """
        if self.trace_recorder is None:
            return
        self.trace_recorder.record_output_shapes(
            output_shape=output_shape,
            features_shape=features_shape
        )

    def get_trace_id(self) -> Optional[str]:
        """Get the current trace ID.

        Returns:
            Trace ID or None if no trace recorder
        """
        if self.trace_recorder is None:
            return None
        return self.trace_recorder.trace_id

    def get_execution_trace(self) -> Optional[Any]:
        """Get the current execution trace.

        Returns the trace object held by the recorder. This can be used to
        generate post-execution diagrams with actual shapes.

        Returns:
            ExecutionTrace object or None if no trace recorder
        """
        if self.trace_recorder is None:
            return None
        return self.trace_recorder.trace