Source code for nirs4all.controllers.transforms.y_transformer
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from sklearn.base import TransformerMixin
from nirs4all.controllers.controller import OperatorController
from nirs4all.controllers.registry import register_controller
from nirs4all.pipeline.config.context import ExecutionContext, RuntimeContext
from nirs4all.pipeline.storage.artifacts.types import ArtifactType
if TYPE_CHECKING:
from nirs4all.data.dataset import SpectroDataset
from nirs4all.pipeline.steps.parser import ParsedStep
import numpy as np
def _is_transformer_like(obj: Any) -> bool:
    """Tell whether ``obj`` can be used as an sklearn-style transformer.

    Accepted forms are a ``TransformerMixin`` instance, or a class deriving
    from ``TransformerMixin`` (e.g. ``StandardScaler`` written without
    parentheses).

    Args:
        obj: Candidate operator taken from a pipeline step.

    Returns:
        True when ``obj`` is a TransformerMixin instance or subclass.
    """
    # Uninstantiated transformer class.
    given_as_class = isinstance(obj, type) and issubclass(obj, TransformerMixin)
    # Ready-to-use transformer instance.
    if given_as_class or isinstance(obj, TransformerMixin):
        return True

    # Last resort: look for TransformerMixin directly in the MRO of the
    # object's class; objects exposing no MRO are rejected.
    lineage = getattr(getattr(obj, '__class__', None), '__mro__', None)
    if lineage is None:
        return False
    return TransformerMixin in lineage
[docs]
@register_controller
class YTransformerMixinController(OperatorController):
    """Apply sklearn ``TransformerMixin`` operators to targets (y) instead of features (X).

    Triggered by the ``"y_processing"`` keyword. In training mode each
    transformer is fitted on the train targets and applied to all targets.
    In predict/explain mode the previously fitted transformers are looked up
    in the stored step artifacts and re-attached to the dataset.

    Supports both single transformers and chained transformers (list syntax):
        - Single: ``{"y_processing": StandardScaler()}``
        - Chained: ``{"y_processing": [StandardScaler, QuantileTransformer(n_quantiles=30)]}``

    When using chained transformers, each transformer is applied sequentially,
    with ancestry tracking and one persisted artifact per transformer.
    """

    priority = 5

    @classmethod
    def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
        """Match if keyword is 'y_processing' and operator is TransformerMixin or list thereof.

        Args:
            step: Original step configuration (not inspected here).
            operator: Parsed operator (TransformerMixin instance, class, or
                a non-empty list/tuple of those).
            keyword: Step keyword.

        Returns:
            True if this controller should handle the step.
        """
        if keyword != "y_processing":
            return False

        # Single transformer (instance or class)
        if _is_transformer_like(operator):
            return True

        # Non-empty list/tuple in which every element is transformer-like
        if isinstance(operator, (list, tuple)) and len(operator) > 0:
            return all(_is_transformer_like(item) for item in operator)

        return False

    @classmethod
    def use_multi_source(cls) -> bool:
        """Return False: target processing does not depend on multiple sources."""
        return False

    @classmethod
    def supports_prediction_mode(cls) -> bool:
        """Return True: this controller also runs in predict/explain mode.

        In those modes ``execute`` does not fit anything; it looks up the
        transformer fitted during training and registers it on the dataset.
        """
        return True

    def execute(
        self,
        step_info: 'ParsedStep',
        dataset: 'SpectroDataset',
        context: ExecutionContext,
        runtime_context: 'RuntimeContext',
        source: int = -1,
        mode: str = "train",
        loaded_binaries: Any = None,
        prediction_store: Any = None
    ) -> Tuple[ExecutionContext, List[Any]]:
        """Execute transformer(s) on dataset targets.

        Supports both single transformers and chained transformers (list).
        Each transformer is applied sequentially; the context returned by one
        transformer is the input context of the next.

        Args:
            step_info: Parsed step containing operator and metadata.
            dataset: Dataset containing targets to transform.
            context: Pipeline context with partition information.
            runtime_context: Runtime context containing infrastructure components.
            source: Source index (not used for target processing).
            mode: Execution mode ("train", "predict", or "explain").
            loaded_binaries: Not read by this controller; in predict/explain
                mode fitted transformers are fetched through
                ``runtime_context.artifact_provider``.
            prediction_store: Not used for y_processing.

        Returns:
            Tuple of (updated_context, artifacts_list). The list holds what
            ``_persist_y_transformer`` returned for each transformer in
            training mode and is empty in predict/explain mode.

        Raises:
            ValueError: In predict/explain mode, when no fitted transformer
                can be found for one of the operators.
        """
        # Normalize to list and instantiate class types
        operators = self._normalize_operators(step_info.operator)

        current_context = context
        all_artifacts: List[Any] = []
        # Position of each transformer among the transformers of the SAME
        # class in this chain. Needed to pick the right stored artifact when
        # artifact names cannot be matched exactly in predict mode.
        seen_per_class: Dict[str, int] = {}

        for idx, op in enumerate(operators):
            class_name = op.__class__.__name__
            occurrence = seen_per_class.get(class_name, 0)
            seen_per_class[class_name] = occurrence + 1

            current_context, artifacts = self._execute_single_transformer(
                transformer=op,
                transformer_index=idx,
                dataset=dataset,
                context=current_context,
                runtime_context=runtime_context,
                mode=mode,
                class_occurrence=occurrence
            )
            all_artifacts.extend(artifacts)

        return current_context, all_artifacts

    def _normalize_operators(self, operator: Any) -> List[TransformerMixin]:
        """Normalize operator(s) to a list of instantiated TransformerMixin objects.

        Args:
            operator: Single transformer, class, or list/tuple thereof.

        Returns:
            List in which TransformerMixin classes have been instantiated
            with their default arguments; other entries are passed through.
        """
        if isinstance(operator, (list, tuple)):
            candidates = list(operator)
        else:
            candidates = [operator]

        # Instantiate class types (e.g., StandardScaler vs StandardScaler())
        return [
            op() if isinstance(op, type) and issubclass(op, TransformerMixin) else op
            for op in candidates
        ]

    def _execute_single_transformer(
        self,
        transformer: TransformerMixin,
        transformer_index: int,
        dataset: 'SpectroDataset',
        context: ExecutionContext,
        runtime_context: 'RuntimeContext',
        mode: str,
        class_occurrence: Optional[int] = None
    ) -> Tuple[ExecutionContext, List[Any]]:
        """Execute a single transformer on dataset targets.

        Args:
            transformer: The transformer to apply.
            transformer_index: Index of the transformer in the chain.
            dataset: Dataset containing targets.
            context: Current execution context.
            runtime_context: Runtime context for artifact persistence.
            mode: Execution mode.
            class_occurrence: Zero-based rank of this transformer among the
                transformers of the same class in the chain. Used only by the
                fallback artifact lookups in predict/explain mode. When
                omitted, ``transformer_index`` is used instead.

        Returns:
            Tuple of (updated_context, artifacts_list).

        Raises:
            ValueError: In predict/explain mode, when no fitted transformer
                is found for this operator.
        """
        from sklearn.base import clone

        # Generate unique name for this transformer
        operator_name = transformer.__class__.__name__
        op_id = runtime_context.next_op()
        current_y_processing = context.state.y_processing
        new_processing_name = f"{current_y_processing}_{operator_name}{op_id}"

        # Artifact name for saving/loading; it embeds the operation counter
        # returned by runtime_context.next_op(), not the chain index.
        artifact_name = f"y_{operator_name}_{op_id}"

        # Handle prediction/explain mode: load pre-fitted transformer
        if mode in ("predict", "explain"):
            occurrence = transformer_index if class_occurrence is None else class_occurrence
            fitted_transformer = None

            # V3: Use artifact_provider for chain-based loading
            provider = runtime_context.artifact_provider
            if provider is not None:
                step_artifacts = provider.get_artifacts_for_step(
                    runtime_context.step_number,
                    branch_path=context.selector.branch_path
                )
                if step_artifacts:
                    fitted_transformer = self._find_fitted_transformer(
                        step_artifacts=step_artifacts,
                        artifact_name=artifact_name,
                        operator_name=operator_name,
                        occurrence=occurrence
                    )

            if fitted_transformer is None:
                # No pre-fitted transformer found - this shouldn't happen in proper predict mode
                raise ValueError(
                    f"No fitted transformer found for '{artifact_name}' at step {runtime_context.step_number}"
                )

            # NOTE(review): this path calls dataset._targets.add_processed_targets
            # with ``ancestor=``/``mode=`` while the training path below calls
            # dataset.add_processed_targets with ``ancestor_processing=``.
            # Confirm both signatures are intentional.
            dataset._targets.add_processed_targets(
                processing_name=new_processing_name,
                targets=np.array([]),
                ancestor=current_y_processing,
                transformer=fitted_transformer,
                mode=mode
            )
            return context.with_y(new_processing_name), []

        # Training mode: fit and transform
        # Get train targets for fitting (excluding filtered samples)
        train_context = context.with_partition("train")
        train_y_selector = dict(train_context.selector)
        train_y_selector['y'] = current_y_processing
        train_data = dataset.y(train_y_selector, include_excluded=False)

        # Get all targets for transformation (INCLUDING excluded samples)
        # This is necessary because add_processed_targets expects targets for ALL samples
        all_y_selector = dict(context.selector)
        all_y_selector['y'] = current_y_processing
        all_data = dataset.y(all_y_selector, include_excluded=True)

        # Fit a clone so the operator object from the step config is left unfitted
        fitted_transformer = clone(transformer)
        fitted_transformer.fit(train_data)
        transformed_targets = fitted_transformer.transform(all_data)

        # Add processed targets to dataset with proper ancestry
        dataset.add_processed_targets(
            processing_name=new_processing_name,
            targets=transformed_targets,
            ancestor_processing=current_y_processing,
            transformer=fitted_transformer
        )

        # Update context to use new processing
        updated_context = context.with_y(new_processing_name)

        # Persist fitted transformer
        artifacts = []
        if mode == "train":
            # The returned record is None when no artifact registry is configured.
            artifact = self._persist_y_transformer(
                runtime_context=runtime_context,
                transformer=fitted_transformer,
                name=artifact_name,
                context=context
            )
            artifacts.append(artifact)

        return updated_context, artifacts

    def _find_fitted_transformer(
        self,
        step_artifacts: List[Tuple[str, Any]],
        artifact_name: str,
        operator_name: str,
        occurrence: int
    ) -> Optional[Any]:
        """Locate the stored transformer that corresponds to one chain position.

        Lookup order:
            1. Exact artifact name.
            2. Names of the form ``y_<operator_name>_<counter>``, ranked by
               counter; the ``occurrence``-th one is taken. This covers the
               case where the operation counter differs between training and
               prediction while keeping repeated classes in chain order.
            3. The ``occurrence``-th artifact whose class name equals
               ``operator_name``.

        Args:
            step_artifacts: (name, object) pairs stored for the current step.
            artifact_name: Name the artifact would have with the current counter.
            operator_name: Class name of the transformer.
            occurrence: Zero-based rank among same-class transformers in the chain.

        Returns:
            The fitted transformer, or None when nothing matches.
        """
        import re

        # Dict for name-based lookup
        by_name = dict(step_artifacts)

        exact = by_name.get(artifact_name)
        if exact is not None:
            return exact

        pattern = re.compile(rf'^y_{re.escape(operator_name)}_(\d+)$')
        numbered = []
        for key in by_name:
            found = pattern.match(key)
            if found:
                numbered.append((int(found.group(1)), key))
        numbered.sort()

        if occurrence < len(numbered):
            candidate = by_name[numbered[occurrence][1]]
            if candidate is not None:
                return candidate

        return self._match_transformer_by_class(
            operator_name,
            step_artifacts,
            occurrence
        )

    def _match_transformer_by_class(
        self,
        class_name: str,
        artifacts: List[Tuple[str, Any]],
        target_index: int = 0
    ) -> Optional[Any]:
        """Select the nth transformer whose class matches class_name.

        Args:
            class_name: Class name to look for (compared to ``obj.__class__.__name__``).
            artifacts: (name, object) pairs, scanned in order.
            target_index: Zero-based rank of the wanted match.

        Returns:
            The matching object, or None when there are not enough matches.
        """
        match_count = 0
        for _, obj in artifacts:
            if obj.__class__.__name__ != class_name:
                continue
            if match_count == target_index:
                return obj
            match_count += 1
        return None

    def _persist_y_transformer(
        self,
        runtime_context: 'RuntimeContext',
        transformer: Any,
        name: str,
        context: ExecutionContext
    ) -> Any:
        """Persist fitted Y transformer using V3 artifact registry.

        Builds an operator chain ending with a node for this transformer,
        derives an artifact id from the chain path, registers the object and
        records the artifact in the execution trace.

        Args:
            runtime_context: Runtime context with saver/registry instances.
            transformer: Fitted transformer to persist.
            name: Artifact name for the transformer.
            context: Execution context with branch information.

        Returns:
            The record returned by ``registry.register``, or None if the
            runtime context has no artifact registry.
        """
        registry = runtime_context.artifact_registry
        if registry is None:
            # No registry available - skip persistence (for unit tests)
            return None

        pipeline_id = runtime_context.saver.pipeline_id if runtime_context.saver else "unknown"
        step_index = runtime_context.step_number
        branch_path = context.selector.branch_path or []

        # Extract the operation counter from the name (e.g., "y_MinMaxScaler_1" -> 1)
        substep_index = None
        if "_" in name:
            try:
                substep_index = int(name.rsplit("_", 1)[1])
            except ValueError:
                # Trailing part is not a number: leave the substep index unset.
                pass

        # V3: Build operator chain for this artifact
        from nirs4all.pipeline.storage.artifacts.operator_chain import OperatorNode, OperatorChain

        # Get the current chain from trace recorder or build new one
        if runtime_context.trace_recorder is not None:
            current_chain = runtime_context.trace_recorder.current_chain()
        else:
            current_chain = OperatorChain(pipeline_id=pipeline_id)

        # Create node for this Y transformer
        transformer_node = OperatorNode(
            step_index=step_index,
            operator_class=f"y_{transformer.__class__.__name__}",  # Prefix with y_ to distinguish
            branch_path=branch_path,
            source_index=None,  # Y transformers don't have source index
            fold_id=None,  # Shared across folds
            substep_index=substep_index,
        )

        # Build chain path for this artifact
        artifact_chain = current_chain.append(transformer_node)
        chain_path = artifact_chain.to_path()

        # Generate V3 artifact ID using chain
        artifact_id = registry.generate_id(chain_path, None, pipeline_id)

        # Register artifact with registry (use ENCODER type for y transformers)
        record = registry.register(
            obj=transformer,
            artifact_id=artifact_id,
            artifact_type=ArtifactType.ENCODER,
            format_hint='sklearn',
            chain_path=chain_path,
        )

        # Record artifact in execution trace with V3 chain info
        runtime_context.record_step_artifact(
            artifact_id=artifact_id,
            is_primary=False,
            fold_id=None,
            chain_path=chain_path,
            branch_path=branch_path,
            metadata={"class_name": transformer.__class__.__name__, "name": name}
        )

        return record