# Source code for module nirs4all.controllers.transforms.transformer.
from typing import Any, Dict, List, Tuple, Optional, TYPE_CHECKING, Union
from sklearn.base import TransformerMixin
from nirs4all.controllers.controller import OperatorController
from nirs4all.controllers.registry import register_controller
from nirs4all.operators.base import SpectraTransformerMixin
from nirs4all.pipeline.config.context import ExecutionContext, RuntimeContext
from nirs4all.pipeline.storage.artifacts.types import ArtifactType
if TYPE_CHECKING:
    # Imported only for static type checking (TYPE_CHECKING is False at
    # runtime), so these project modules are not loaded when this module runs.
    from nirs4all.spectra.spectra_dataset import SpectroDataset
    from nirs4all.pipeline.steps.parser import ParsedStep
import numpy as np
from sklearn.base import clone
import pickle
# TODO: add parallel support for multi-source and multi-processing datasets.
@register_controller
class TransformerMixinController(OperatorController):
    """Controller for scikit-learn ``TransformerMixin`` operators.

    Handles three execution styles (see :meth:`execute`):

    * normal: fit on train (or all) data and replace the selected processings
      with their transformed versions;
    * feature augmentation (``context.metadata.add_feature``): add the
      transformed processings next to the existing ones;
    * sample augmentation (``context.metadata.augment_sample``): transform the
      target origin samples and insert the results as new augmented samples.

    Operators flagged by :meth:`_needs_wavelengths` receive the source
    wavelengths as a ``wavelengths=`` keyword argument to ``fit`` and
    ``transform``.
    """

    priority = 10

    # ------------------------------------------------------------------
    # Wavelength helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _needs_wavelengths(operator: Any) -> bool:
        """Check if the operator requires wavelengths.

        Args:
            operator: The operator to check.

        Returns:
            True if the operator is a SpectraTransformerMixin with _requires_wavelengths=True.
        """
        return (
            isinstance(operator, SpectraTransformerMixin) and
            getattr(operator, '_requires_wavelengths', False)
        )

    @staticmethod
    def _wavelength_kwargs(needs_wavelengths: bool, wavelengths: Optional[np.ndarray]) -> Dict[str, Any]:
        """Build the extra keyword arguments passed to ``fit``/``transform``.

        Args:
            needs_wavelengths: Result of :meth:`_needs_wavelengths` for the operator.
            wavelengths: Wavelengths of the current source (ignored when not needed).

        Returns:
            ``{"wavelengths": wavelengths}`` when wavelengths are needed, else ``{}``.
        """
        return {"wavelengths": wavelengths} if needs_wavelengths else {}

    @staticmethod
    def _extract_wavelengths(
        dataset: 'SpectroDataset',
        source_index: int,
        operator_name: str
    ) -> np.ndarray:
        """Extract wavelengths from dataset for a given source.

        Tries ``dataset.wavelengths_nm(source_index)`` first, then the legacy
        ``dataset.float_headers(source_index)``. A lookup that raises
        ``ValueError``/``AttributeError``, or that yields ``None`` or an empty
        sequence, is treated as "not available" and the next lookup is tried.

        Args:
            dataset: The SpectroDataset to extract wavelengths from.
            source_index: The source index for multi-source datasets.
            operator_name: Name of the operator (for error messages).

        Returns:
            Wavelength array in nm.

        Raises:
            ValueError: If wavelengths cannot be extracted from the dataset.
        """
        for accessor in ("wavelengths_nm", "float_headers"):
            try:
                wavelengths = getattr(dataset, accessor)(source_index)
            except (ValueError, AttributeError):
                continue
            # An empty/None result must not be handed to the operator; try the
            # next lookup instead.
            if wavelengths is not None and len(wavelengths) > 0:
                return wavelengths
        raise ValueError(
            f"Operator {operator_name} requires wavelengths but dataset has no "
            f"wavelength information for source {source_index}. Ensure the dataset "
            f"has wavelength headers (nm or cm⁻¹)."
        )

    # ------------------------------------------------------------------
    # Controller interface
    # ------------------------------------------------------------------

    @classmethod
    def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
        """Match TransformerMixin objects.

        The candidate is ``step['model']`` when *step* is a dict with a
        ``'model'`` key, otherwise *operator* when given, otherwise *step*.
        """
        if isinstance(step, dict) and 'model' in step:
            candidate = step['model']
        elif operator is not None:
            candidate = operator
        else:
            candidate = step
        # isinstance() already covers subclasses (it also consults __class__),
        # so no separate issubclass() check is needed.
        return isinstance(candidate, TransformerMixin)

    @classmethod
    def use_multi_source(cls) -> bool:
        """Check if the operator supports multi-source datasets."""
        return True

    @classmethod
    def supports_prediction_mode(cls) -> bool:
        """TransformerMixin controllers support prediction mode."""
        return True

    def execute(
        self,
        step_info: 'ParsedStep',
        dataset: 'SpectroDataset',
        context: ExecutionContext,
        runtime_context: 'RuntimeContext',
        source: int = -1,
        mode: str = "train",
        loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
        prediction_store: Optional[Any] = None
    ):
        """Execute transformer - handles normal, feature augmentation, and sample augmentation modes.

        Outside predict/explain mode, each selected processing of each source is
        fitted with a fresh clone of the operator; in train mode the fitted
        transformer is also persisted. In predict/explain mode the persisted
        transformers are loaded through ``runtime_context.artifact_provider``.

        Supports optional `fit_on_all` parameter in step configuration to fit the transformer
        on all data instead of just training data. This is useful for unsupervised preprocessing
        where you want the transformation to capture the full data distribution.

        Step format::

            # Standard (fit on train, transform all):
            StandardScaler()
            # Fit on ALL data (unsupervised preprocessing):
            {"preprocessing": StandardScaler(), "fit_on_all": True}

        Args:
            step_info: Parsed step (``operator`` and raw ``original_step``).
            dataset: Dataset whose features are transformed in place.
            context: Execution context (selector, processings, metadata).
            runtime_context: Runtime context (artifact registry/provider, counters).
            source: Not used by this controller.
            mode: Execution mode, e.g. ``"train"``, ``"predict"`` or ``"explain"``.
            loaded_binaries: Only forwarded to the sample-augmentation path.
            prediction_store: Only forwarded to the sample-augmentation path.

        Returns:
            ``(context, fitted_transformers)``: the context with updated
            processing names, and the values returned by
            :meth:`_persist_transformer` in train mode (empty otherwise).

        Raises:
            ValueError: If a persisted transformer is missing in predict/explain
                mode, or wavelengths are required but unavailable.
        """
        op = step_info.operator

        fit_on_all = False
        if isinstance(step_info.original_step, dict):
            fit_on_all = step_info.original_step.get("fit_on_all", False)

        if context.metadata.augment_sample and mode not in ("predict", "explain"):
            return self._execute_for_sample_augmentation(
                op, dataset, context, runtime_context, mode, loaded_binaries, prediction_store,
                fit_on_all=fit_on_all
            )

        operator_name = op.__class__.__name__
        is_inference = mode in ("predict", "explain")

        # Transform data includes excluded samples so that replaced feature
        # arrays keep a consistent shape; exclusion is applied at query time.
        all_data = dataset.x(context.selector, "3d", concat_source=False, include_excluded=True)
        # Fitting always excludes filtered samples so outliers do not influence it.
        fit_selector = context.selector if fit_on_all else context.with_partition("train").selector
        fit_data = dataset.x(fit_selector, "3d", concat_source=False, include_excluded=False)

        if not isinstance(fit_data, list):
            fit_data = [fit_data]
        if not isinstance(all_data, list):
            all_data = [all_data]

        fitted_transformers = []
        transformed_features_list = []
        new_processing_names = []
        processing_names = []

        # runtime_context.next_processing_index() is a counter shared across
        # sources, giving each (source, processing) pair a unique substep_index
        # even across feature_augmentation sub-operations.
        needs_wavelengths = self._needs_wavelengths(op)

        for sd_idx, (fit_x, all_x) in enumerate(zip(fit_data, all_data)):
            wavelengths = None
            if needs_wavelengths:
                wavelengths = self._extract_wavelengths(dataset, sd_idx, operator_name)
            wl_kwargs = self._wavelength_kwargs(needs_wavelengths, wavelengths)

            processing_ids = dataset.features_processings(sd_idx)
            source_processings = processing_ids
            if context.selector.processing:
                # After a source merge the processing list can be shorter than
                # the number of sources; trailing sources were merged into source 0.
                if sd_idx >= len(context.selector.processing):
                    continue
                source_processings = context.selector.processing[sd_idx]

            source_transformed_features = []
            source_new_processing_names = []
            source_processing_names = []

            # 3D layout is (samples, processings, features).
            for processing_idx in range(fit_x.shape[1]):
                processing_name = processing_ids[processing_idx]
                if processing_name not in source_processings:
                    continue

                if is_inference:
                    new_operator_name, transformer = self._load_step_transformer(
                        runtime_context, context, operator_name, sd_idx
                    )
                else:
                    new_operator_name = f"{operator_name}_{runtime_context.next_op()}"
                    transformer = clone(op)
                    transformer.fit(fit_x[:, processing_idx, :], **wl_kwargs)

                transformed_2d = transformer.transform(all_x[:, processing_idx, :], **wl_kwargs)

                source_transformed_features.append(transformed_2d)
                source_new_processing_names.append(f"{processing_name}_{new_operator_name}")
                source_processing_names.append(processing_name)

                if mode == "train":
                    fitted_transformers.append(self._persist_transformer(
                        runtime_context=runtime_context,
                        transformer=transformer,
                        name=new_operator_name,
                        context=context,
                        source_index=sd_idx,
                        processing_index=runtime_context.next_processing_index()
                    ))

            transformed_features_list.append(source_transformed_features)
            new_processing_names.append(source_new_processing_names)
            processing_names.append(source_processing_names)

        for sd_idx, (source_features, src_new_processing_names) in enumerate(
            zip(transformed_features_list, new_processing_names)
        ):
            if context.metadata.add_feature:
                dataset.add_features(source_features, src_new_processing_names, source=sd_idx)
            else:
                dataset.replace_features(
                    source_processings=processing_names[sd_idx],
                    features=source_features,
                    processings=src_new_processing_names,
                    source=sd_idx
                )
            # Point the context at the new processings (with_processing needs a new list).
            new_processing = list(context.selector.processing)
            new_processing[sd_idx] = src_new_processing_names
            context = context.with_processing(new_processing)

        context = context.with_metadata(add_feature=False)
        return context, fitted_transformers

    @staticmethod
    def _load_step_transformer(
        runtime_context: 'RuntimeContext',
        context: ExecutionContext,
        operator_name: str,
        source_index: int
    ) -> Tuple[str, Any]:
        """Load the next persisted transformer of the current step for one source.

        Artifacts are fetched for ``(step, branch_path, source_index)`` and picked
        by ``runtime_context.next_artifact_load_index(source_index)``. That index
        persists across feature_augmentation sub-operations.

        Returns:
            ``(operator_name, transformer)``: the loaded artifact name, or a
            freshly generated ``f"{operator_name}_{next_op()}"`` name when the
            loaded name is empty.

        Raises:
            ValueError: If no transformer could be loaded.
        """
        transformer = None
        loaded_artifact_name = None
        # None until an index is actually drawn, so the error message never
        # reports a stale value from a previous call.
        artifact_idx = None

        provider = runtime_context.artifact_provider
        if provider is not None:
            step_artifacts = provider.get_artifacts_for_step(
                runtime_context.step_number,
                branch_path=context.selector.branch_path,
                source_index=source_index,
                substep_index=None  # all artifacts of this source
            )
            if step_artifacts:
                artifacts_list = list(step_artifacts)
                artifact_idx = runtime_context.next_artifact_load_index(source_index)
                if artifact_idx < len(artifacts_list):
                    loaded_artifact_name, transformer = artifacts_list[artifact_idx]

        # Prefer the name of what was actually loaded over a fresh next_op() name.
        if loaded_artifact_name:
            new_operator_name = loaded_artifact_name
        else:
            new_operator_name = f"{operator_name}_{runtime_context.next_op()}"

        if transformer is None:
            available = []
            if provider is not None:
                step_artifacts = provider.get_artifacts_for_step(
                    runtime_context.step_number,
                    branch_path=context.selector.branch_path
                )
                available = [name for name, _ in step_artifacts] if step_artifacts else []
            raise ValueError(
                f"Transformer for {operator_name} not found at step {runtime_context.step_number} "
                f"(branch_path={context.selector.branch_path}, source={source_index}, "
                f"artifact_idx={artifact_idx if artifact_idx is not None else 'N/A'}). "
                f"Available artifacts: {available}"
            )
        return new_operator_name, transformer

    # ------------------------------------------------------------------
    # Sample augmentation
    # ------------------------------------------------------------------

    @staticmethod
    def _augmentation_fit_data(
        dataset: 'SpectroDataset',
        context: ExecutionContext,
        mode: str,
        fit_on_all: bool
    ) -> Optional[List[np.ndarray]]:
        """Return per-source 3D fitting data (non-augmented samples), or None in predict/explain mode.

        Uses all samples when *fit_on_all* is true, otherwise the train partition only.
        """
        if mode in ("predict", "explain"):
            return None
        base_selector = context.selector if fit_on_all else context.with_partition("train").selector
        fit_data = dataset.x(base_selector.with_augmented(False), "3d", concat_source=False)
        if not isinstance(fit_data, list):
            fit_data = [fit_data]
        return fit_data

    @staticmethod
    def _load_augmentation_transformer(
        runtime_context: 'RuntimeContext',
        context: ExecutionContext,
        operator_name: str,
        source_idx: int,
        proc_idx: int
    ) -> Any:
        """Load the persisted sample-augmentation transformer for one source/processing.

        Looks the artifact up by name ``f"{operator_name}_{source_idx}_{proc_idx}"``.
        If no artifact has that name, it falls back to the artifact at position
        *proc_idx*.

        Raises:
            ValueError: If no transformer could be found.
        """
        artifact_key = f"{operator_name}_{source_idx}_{proc_idx}"
        transformer = None
        if runtime_context.artifact_provider is not None:
            step_artifacts = runtime_context.artifact_provider.get_artifacts_for_step(
                runtime_context.step_number,
                branch_path=context.selector.branch_path,
                source_index=source_idx
            )
            if step_artifacts:
                # Materialize once: both the name lookup and the positional
                # fallback need it, and an iterator could only be consumed once.
                artifacts_list = list(step_artifacts)
                transformer = dict(artifacts_list).get(artifact_key)
                if transformer is None and proc_idx < len(artifacts_list):
                    _, transformer = artifacts_list[proc_idx]
        if transformer is None:
            raise ValueError(f"Transformer for {artifact_key} not found at step {runtime_context.step_number}")
        return transformer

    def _execute_for_sample_augmentation(
        self,
        operator: Any,
        dataset: 'SpectroDataset',
        context: ExecutionContext,
        runtime_context: 'RuntimeContext',
        mode: str,
        loaded_binaries: Optional[List[Tuple[str, Any]]],
        prediction_store: Optional[Any],
        fit_on_all: bool = False
    ) -> Tuple[ExecutionContext, List]:
        """
        Apply transformer to origin samples and add augmented samples.

        Batched implementation:
        - fetches all target samples with a single ``dataset.x`` call;
        - fits (or, in predict/explain mode, loads) one transformer per
          source/processing and persists it once in train mode;
        - transforms all samples of a source/processing in one call;
        - inserts all augmented samples with one ``add_samples_batch`` call.

        Falls back to :meth:`_execute_for_sample_augmentation_sequential` when the
        number of fetched rows differs from the number of target sample ids
        (e.g. some ids do not exist or are filtered out).

        Args:
            operator: The transformer operator to apply
            dataset: The dataset to operate on
            context: Execution context
            runtime_context: Runtime context with saver, step info, etc.
            mode: Execution mode ("train", "predict", "explain")
            loaded_binaries: Only forwarded to the sequential fallback
            prediction_store: Not used
            fit_on_all: If True, fit transformer on all data instead of train only

        Returns:
            ``(context, fitted_transformers)``; the context is returned unchanged.
        """
        target_sample_ids = context.metadata.target_samples
        if not target_sample_ids:
            return context, []

        # Fetch all origin (non-augmented) target samples in one call.
        origin_data = dataset.x(
            {"sample": list(target_sample_ids)}, "3d",
            concat_source=False, include_augmented=False
        )
        if not isinstance(origin_data, list):
            origin_data = [origin_data]

        n_rows = origin_data[0].shape[0] if origin_data else 0
        if n_rows != len(target_sample_ids):
            return self._execute_for_sample_augmentation_sequential(
                operator, dataset, context, runtime_context, mode, loaded_binaries, prediction_store,
                fit_on_all=fit_on_all
            )

        operator_name = operator.__class__.__name__
        needs_wavelengths = self._needs_wavelengths(operator)
        is_inference = mode in ("predict", "explain")
        fit_data = self._augmentation_fit_data(dataset, context, mode, fit_on_all)

        fitted_transformers = []
        transformed_sources = []  # one (n_samples, n_processings, n_features) array per source
        for source_idx, source_data in enumerate(origin_data):
            wavelengths = None
            if needs_wavelengths:
                wavelengths = self._extract_wavelengths(dataset, source_idx, operator_name)
            wl_kwargs = self._wavelength_kwargs(needs_wavelengths, wavelengths)

            transformed_processings = []
            # Use this source's own processing count: sources may differ.
            for proc_idx in range(source_data.shape[1]):
                if is_inference:
                    transformer = self._load_augmentation_transformer(
                        runtime_context, context, operator_name, source_idx, proc_idx
                    )
                else:
                    transformer = clone(operator)
                    transformer.fit(fit_data[source_idx][:, proc_idx, :], **wl_kwargs)
                    if mode == "train":
                        # processing_index keeps each processing's artifact
                        # distinct; otherwise all processings of a source
                        # would build an identical OperatorNode/chain path.
                        fitted_transformers.append(self._persist_transformer(
                            runtime_context=runtime_context,
                            transformer=transformer,
                            name=f"{operator_name}_{source_idx}_{proc_idx}",
                            context=context,
                            source_index=source_idx,
                            processing_index=runtime_context.next_processing_index()
                        ))
                transformed_processings.append(
                    transformer.transform(source_data[:, proc_idx, :], **wl_kwargs)
                )
            transformed_sources.append(np.stack(transformed_processings, axis=1))

        # Single source -> one 3D array; multi-source -> list of 3D arrays.
        batch_data = transformed_sources[0] if len(transformed_sources) == 1 else transformed_sources

        # NOTE(review): assumes dataset.x returns rows in target_sample_ids
        # order, so row i is the augmentation of target_sample_ids[i] — confirm.
        indexes_list = [
            {
                "partition": "train",
                "origin": sample_id,
                "augmentation": operator_name
            }
            for sample_id in target_sample_ids
        ]
        dataset.add_samples_batch(data=batch_data, indexes_list=indexes_list)
        return context, fitted_transformers

    def _execute_for_sample_augmentation_sequential(
        self,
        operator: Any,
        dataset: 'SpectroDataset',
        context: ExecutionContext,
        runtime_context: 'RuntimeContext',
        mode: str,
        loaded_binaries: Optional[List[Tuple[str, Any]]],
        prediction_store: Optional[Any],
        fit_on_all: bool = False
    ) -> Tuple[ExecutionContext, List]:
        """
        Fallback sequential implementation for sample augmentation.

        Used when batch processing is not possible due to data shape mismatches
        (typically target sample ids that are missing or filtered out). Such
        ids yield no rows and are skipped. Transformers are fitted or loaded
        once per source/processing and reused for every sample.

        Args:
            operator: The transformer operator to apply
            dataset: The dataset to operate on
            context: Execution context
            runtime_context: Runtime context with saver, step info, etc.
            mode: Execution mode ("train", "predict", "explain")
            loaded_binaries: Not used
            prediction_store: Not used
            fit_on_all: If True, fit transformer on all data instead of train only

        Returns:
            ``(context, fitted_transformers)``; the context is returned unchanged.
        """
        target_sample_ids = context.metadata.target_samples
        if not target_sample_ids:
            return context, []

        operator_name = operator.__class__.__name__
        needs_wavelengths = self._needs_wavelengths(operator)
        is_inference = mode in ("predict", "explain")
        fit_data = self._augmentation_fit_data(dataset, context, mode, fit_on_all)

        fitted_transformers = []
        transformer_cache: Dict[Tuple[int, int], Any] = {}  # (source, processing) -> transformer
        wavelengths_cache: Dict[int, Any] = {}  # source -> wavelengths

        for sample_id in target_sample_ids:
            origin_data = dataset.x(
                {"sample": [sample_id]}, "3d", concat_source=False, include_augmented=False
            )
            if not isinstance(origin_data, list):
                origin_data = [origin_data]
            # This path exists precisely for ids that may be missing or
            # filtered out; skip them instead of failing on an empty array.
            if not origin_data or origin_data[0].shape[0] == 0:
                continue

            transformed_sources = []
            for source_idx, source_data in enumerate(origin_data):
                wavelengths = None
                if needs_wavelengths:
                    if source_idx not in wavelengths_cache:
                        wavelengths_cache[source_idx] = self._extract_wavelengths(
                            dataset, source_idx, operator_name
                        )
                    wavelengths = wavelengths_cache[source_idx]
                wl_kwargs = self._wavelength_kwargs(needs_wavelengths, wavelengths)

                transformed_processings = []
                for proc_idx in range(source_data.shape[1]):
                    cache_key = (source_idx, proc_idx)
                    transformer = transformer_cache.get(cache_key)
                    if transformer is None:
                        if is_inference:
                            transformer = self._load_augmentation_transformer(
                                runtime_context, context, operator_name, source_idx, proc_idx
                            )
                        else:
                            transformer = clone(operator)
                            transformer.fit(fit_data[source_idx][:, proc_idx, :], **wl_kwargs)
                            if mode == "train":
                                # Persist once per source/processing, with a
                                # distinct processing_index (see batch path).
                                fitted_transformers.append(self._persist_transformer(
                                    runtime_context=runtime_context,
                                    transformer=transformer,
                                    name=f"{operator_name}_{source_idx}_{proc_idx}",
                                    context=context,
                                    source_index=source_idx,
                                    processing_index=runtime_context.next_processing_index()
                                ))
                        transformer_cache[cache_key] = transformer
                    transformed_processings.append(
                        transformer.transform(source_data[:, proc_idx, :], **wl_kwargs)
                    )
                transformed_sources.append(np.stack(transformed_processings, axis=1))

            index_dict = {
                "partition": "train",
                "origin": sample_id,
                "augmentation": operator_name
            }
            # Drop the single-sample axis: (n_processings, n_features) per source.
            if len(transformed_sources) == 1:
                data_to_add = transformed_sources[0][0, :, :]
            else:
                data_to_add = [src[0, :, :] for src in transformed_sources]
            dataset.add_samples(data=data_to_add, indexes=index_dict)

        return context, fitted_transformers

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _persist_transformer(
        self,
        runtime_context: 'RuntimeContext',
        transformer: Any,
        name: str,
        context: ExecutionContext,
        source_index: Optional[int] = None,
        processing_index: Optional[int] = None
    ) -> Any:
        """Persist fitted transformer using V3 chain-based artifact registry.

        Registers the transformer with ``runtime_context.artifact_registry`` under
        an id derived from its operator chain path. It also records the artifact
        in the execution trace via ``runtime_context.record_step_artifact``.

        Args:
            runtime_context: Runtime context with saver/registry instances.
            transformer: Fitted transformer to persist.
            name: Operator name for the transformer (e.g., "StandardScaler_3").
            context: Execution context with branch information.
            source_index: Source index for multi-source transformers.
            processing_index: Index of processing within source (for multi-processing steps).

        Returns:
            The record returned by ``registry.register``, or None when no
            artifact registry is configured (persistence is skipped).
        """
        registry = runtime_context.artifact_registry
        if registry is None:
            # No registry available (e.g. unit tests): skip persistence. In
            # production the runner is expected to set artifact_registry.
            return None

        from nirs4all.pipeline.storage.artifacts.operator_chain import OperatorNode, OperatorChain

        pipeline_id = runtime_context.saver.pipeline_id if runtime_context.saver else "unknown"
        step_index = runtime_context.step_number
        branch_path = context.selector.branch_path or []
        class_name = transformer.__class__.__name__

        # processing_index gives each processing within a source its own
        # substep_index (and thus its own artifact id). substep_number is the
        # fallback for branch contexts.
        if processing_index is not None:
            substep_index = processing_index
        elif runtime_context.substep_number >= 0:
            substep_index = runtime_context.substep_number
        else:
            substep_index = None

        # Extend the current chain (from the trace recorder when available).
        if runtime_context.trace_recorder is not None:
            current_chain = runtime_context.trace_recorder.current_chain()
        else:
            current_chain = OperatorChain(pipeline_id=pipeline_id)

        transformer_node = OperatorNode(
            step_index=step_index,
            operator_class=class_name,
            branch_path=branch_path,
            source_index=source_index,
            fold_id=None,  # Transformers are shared across folds
            substep_index=substep_index,
        )
        chain_path = current_chain.append(transformer_node).to_path()
        artifact_id = registry.generate_id(chain_path, None, pipeline_id)

        record = registry.register(
            obj=transformer,
            artifact_id=artifact_id,
            artifact_type=ArtifactType.TRANSFORMER,
            format_hint='sklearn',
            chain_path=chain_path,
            source_index=source_index,
        )

        runtime_context.record_step_artifact(
            artifact_id=artifact_id,
            is_primary=False,  # Transformers are not primary artifacts
            fold_id=None,
            chain_path=chain_path,
            branch_path=branch_path,
            source_index=source_index,
            metadata={"class_name": class_name, "name": name}
        )
        return record