# Source code for nirs4all.pipeline.steps.step_runner

"""Step runner for executing individual pipeline steps."""
from typing import Any, List, Optional, Tuple

from nirs4all.data.dataset import SpectroDataset
from nirs4all.data.predictions import Predictions
from nirs4all.pipeline.config.context import ExecutionContext
from nirs4all.core.logging import get_logger
from nirs4all.pipeline.execution.result import ArtifactMeta, StepResult
from nirs4all.pipeline.steps.parser import ParsedStep, StepParser, StepType
from nirs4all.pipeline.steps.router import ControllerRouter

logger = get_logger(__name__)


class StepRunner:
    """Executes a single pipeline step.

    Handles:
    - Step parsing (delegates to StepParser)
    - Controller selection (delegates to ControllerRouter)
    - Controller execution
    - Forwarding pre-loaded binaries and the prediction store to the controller

    Attributes:
        parser: Parses step configuration
        router: Routes to appropriate controller
        verbose: Verbosity level
        mode: Execution mode (train/predict/explain)
        show_spinner: Whether to show spinner for long operations
        plots_visible: Whether to display plots
    """

    # Modes in which a controller must support prediction and in which
    # model subpipelines may be narrowed to a single target substep.
    _PREDICTION_MODES = ("predict", "explain")

    def __init__(
        self,
        parser: Optional[StepParser] = None,
        router: Optional[ControllerRouter] = None,
        verbose: int = 0,
        mode: str = "train",
        show_spinner: bool = True,
        plots_visible: bool = False
    ):
        """Initialize step runner.

        Args:
            parser: Step parser (creates new if None)
            router: Controller router (creates new if None)
            verbose: Verbosity level
            mode: Execution mode (train/predict/explain)
            show_spinner: Whether to show spinner for long operations
            plots_visible: Whether to display plots
        """
        self.parser = parser or StepParser()
        self.router = router or ControllerRouter()
        self.verbose = verbose
        self.mode = mode
        self.show_spinner = show_spinner
        self.plots_visible = plots_visible
        # Starts empty; not read or written anywhere else in this class.
        # NOTE(review): presumably keeps figure objects alive - confirm with callers.
        self._figure_refs: List[Any] = []

    def execute(
        self,
        step: Any,
        dataset: SpectroDataset,
        context: ExecutionContext,
        runtime_context: Any,  # RuntimeContext
        loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
        prediction_store: Optional[Predictions] = None
    ) -> StepResult:
        """Execute a single pipeline step.

        The step is parsed first. Skip steps return immediately, subpipelines
        are executed substep by substep (recursively through this method), and
        every other step is routed to a controller and executed.

        Args:
            step: Raw step configuration
            dataset: Dataset to process
            context: Execution context
            runtime_context: Runtime infrastructure context
            loaded_binaries: Pre-loaded artifacts for this step
            prediction_store: Prediction store for accumulating results

        Returns:
            StepResult with updated context and artifacts

        Raises:
            RuntimeError: If the controller raises, or returns something other
                than a ``(context, output)`` tuple. Errors raised while parsing
                or routing the step happen before the guarded section and
                propagate unchanged.
        """
        parsed_step = self.parser.parse(step)

        # Handle None/skip steps (e.g., from _or_: [None, SNV(), ...] generator syntax)
        if parsed_step.metadata.get("skip", False):
            step_num = context.state.step_number if context and context.state else "?"
            logger.debug(f"Step {step_num}: None operator selected by generator, skipping (this is normal for _or_/_cartesian_ with None option).")
            return StepResult(updated_context=context, artifacts=[])

        # Handle subpipelines (nested lists)
        if parsed_step.step_type == StepType.SUBPIPELINE:
            return self._execute_subpipeline(
                substeps=parsed_step.metadata["steps"],
                dataset=dataset,
                context=context,
                runtime_context=runtime_context,
                loaded_binaries=loaded_binaries,
                prediction_store=prediction_store,
            )

        # Route to controller
        controller = self.router.route(parsed_step, step)
        controller_name = controller.__class__.__name__

        if parsed_step.operator is not None:
            operator_name = parsed_step.operator.__class__.__name__
            logger.debug(f"Executing controller {controller_name} with operator {operator_name}")
        else:
            logger.debug(f"Executing controller {controller_name} without operator")

        # Controllers that cannot run in predict/explain mode are skipped,
        # leaving the context untouched.
        if self.mode in self._PREDICTION_MODES and not controller.supports_prediction_mode():
            logger.warning(
                f"Controller {controller_name} "
                f"does not support prediction mode, skipping step"
            )
            return StepResult(updated_context=context, artifacts=[])

        # Update context with step metadata
        if parsed_step.keyword:
            context = context.with_metadata(keyword=parsed_step.keyword)

        try:
            result = controller.execute(
                step_info=parsed_step,
                dataset=dataset,
                context=context,
                runtime_context=runtime_context,
                source=-1,
                mode=self.mode,
                loaded_binaries=loaded_binaries,
                prediction_store=prediction_store
            )
            # Normalised inside the try so a malformed controller return is
            # reported as a RuntimeError like any other step failure.
            return self._to_step_result(result, controller_name)
        except Exception as e:
            raise RuntimeError(f"Step execution failed: {str(e)}") from e
        finally:
            # Reset ephemeral metadata flags to prevent leakage between steps
            context.metadata.reset_ephemeral_flags()

    def _execute_subpipeline(
        self,
        substeps: list,
        dataset: SpectroDataset,
        context: ExecutionContext,
        runtime_context: Any,
        loaded_binaries: Optional[List[Tuple[str, Any]]],
        prediction_store: Optional[Predictions],
    ) -> StepResult:
        """Run the substeps of a nested pipeline in order, threading the context.

        Args:
            substeps: Raw substep configurations of the subpipeline
            dataset: Dataset to process
            context: Execution context handed to the first executed substep
            runtime_context: Runtime infrastructure context; when truthy its
                ``substep_number`` is set to the index of the running substep
            loaded_binaries: Pre-loaded artifacts, forwarded to every substep
            prediction_store: Prediction store, forwarded to every substep

        Returns:
            StepResult with the context returned by the last executed substep
            and the concatenated artifacts of all executed substeps.
        """
        target_sub_index = self._resolve_target_sub_index(substeps, runtime_context)
        current_context = context
        all_artifacts = []

        try:
            # The enumerate index is published on runtime_context as
            # substep_number for the duration of each substep.
            for substep_idx, substep in enumerate(substeps):
                # In predict mode with target_sub_index, skip substeps that don't match
                if target_sub_index is not None and substep_idx != target_sub_index:
                    logger.debug(f"Skipping substep {substep_idx} (target is {target_sub_index})")
                    continue

                if runtime_context:
                    runtime_context.substep_number = substep_idx

                result = self.execute(
                    step=substep,
                    dataset=dataset,
                    context=current_context,
                    runtime_context=runtime_context,
                    loaded_binaries=loaded_binaries,
                    prediction_store=prediction_store
                )
                current_context = result.updated_context
                all_artifacts.extend(result.artifacts)
        finally:
            # Reset even when a substep fails, so a stale index is not left
            # behind on the shared runtime context.
            if runtime_context:
                runtime_context.substep_number = -1

        # NOTE(review): only artifacts are aggregated here; any `outputs`
        # carried by substep results are not forwarded - confirm this is intended.
        return StepResult(updated_context=current_context, artifacts=all_artifacts)

    def _resolve_target_sub_index(self, substeps: list, runtime_context: Any) -> Optional[int]:
        """Return the single substep index to execute, or None to run them all.

        Filtering applies only in predict/explain mode, when the runtime
        context exposes an ``artifact_provider`` with a ``target_sub_index``
        attribute, and only for subpipelines that contain a model. For
        subpipelines with multiple models (e.g. [JaxMLPRegressor, nicon]) this
        runs only the selected model. Transformer subpipelines (e.g.
        [SNV, SG] from feature_augmentation) always run every substep because
        they represent parallel feature channels.

        Args:
            substeps: Raw substep configurations of the subpipeline
            runtime_context: Runtime infrastructure context (may be None)

        Returns:
            The provider's ``target_sub_index`` (which may itself be None), or
            None when filtering does not apply.
        """
        if self.mode not in self._PREDICTION_MODES or not runtime_context:
            return None

        provider = getattr(runtime_context, "artifact_provider", None)
        if provider is None or not hasattr(provider, "target_sub_index"):
            return None

        # Only model subpipelines are narrowed to one substep.
        if not self._subpipeline_contains_model(substeps):
            return None
        return provider.target_sub_index

    @staticmethod
    def _to_step_result(result: Any, controller_name: str) -> StepResult:
        """Convert a controller's return value into a StepResult.

        Two tuple shapes are accepted: the new ``(context, StepOutput)`` form
        and the legacy ``(context, artifacts)`` form where the second item is
        a list of artifacts (or a falsy value, treated as no artifacts).
        Per the original author's notes, legacy controllers return
        already-persisted ArtifactMeta entries while StepOutput carries raw
        objects whose persistence is left to the executor.

        Args:
            result: Value returned by ``controller.execute``
            controller_name: Controller class name, used in the error message

        Returns:
            StepResult built from the tuple.

        Raises:
            TypeError: If ``result`` is not a tuple. Previously this case
                surfaced as an UnboundLocalError on ``updated_context``.
        """
        from nirs4all.pipeline.execution.result import StepOutput

        if not isinstance(result, tuple):
            raise TypeError(
                f"Controller {controller_name} returned {type(result).__name__}; "
                f"expected a (context, output) tuple"
            )

        updated_context, output_data = result

        if isinstance(output_data, StepOutput):
            return StepResult(
                updated_context=updated_context,
                artifacts=output_data.artifacts,
                outputs=output_data.outputs
            )

        # Legacy format: output_data is list of artifacts
        return StepResult(
            updated_context=updated_context,
            artifacts=output_data or [],
            outputs=[]
        )

    def _subpipeline_contains_model(self, substeps: list) -> bool:
        """Check if a subpipeline contains any model substeps.

        This is used to determine whether target_sub_index filtering should
        be applied during prediction. Model subpipelines (e.g., [model1, model2])
        need filtering to run only the best model. Transformer subpipelines
        (e.g., [SNV, SavGol] from feature_augmentation) need to run all substeps.

        Args:
            substeps: List of substep configurations

        Returns:
            True if any substep routes to a BaseModelController, False otherwise
        """
        from nirs4all.controllers.models.base_model import BaseModelController

        for substep in substeps:
            parsed = self.parser.parse(substep)
            if parsed.metadata.get("skip", False):
                continue

            # Best effort: a substep that cannot be routed is treated as
            # "not a model" rather than aborting the check.
            try:
                controller = self.router.route(parsed, substep)
            except Exception as exc:
                logger.debug(f"Could not route substep while checking for models: {exc}")
                continue

            if isinstance(controller, BaseModelController):
                return True

        return False