Source code for nirs4all.pipeline.minimal_predictor

"""
Minimal Pipeline Predictor - Execute minimal pipeline for prediction (V3).

This module provides the MinimalPredictor class which executes a minimal
pipeline extracted from an execution trace. It reuses existing controllers
in predict mode with artifact injection.

The MinimalPredictor is the key component of Phase 5: it ensures that
prediction only runs the required steps, not the entire original pipeline.

V3 Features:
    - Chain-based artifact identification using chain_path
    - ArtifactRecord metadata for branch/substep info (no ID parsing)
    - Support for multi-source pipelines via source_index

Design Principles:
    1. Controller-Agnostic: Uses existing controllers without hardcoding types
    2. Minimal Execution: Only runs steps needed for the specific prediction
    3. Artifact Injection: Provides pre-loaded artifacts to controllers
    4. Deterministic: Same minimal pipeline -> same prediction

Usage:
    >>> from nirs4all.pipeline.minimal_predictor import MinimalPredictor
    >>> from nirs4all.pipeline.trace import TraceBasedExtractor
    >>>
    >>> # Extract minimal pipeline
    >>> extractor = TraceBasedExtractor()
    >>> minimal = extractor.extract(trace, full_pipeline_steps)
    >>>
    >>> # Predict using minimal pipeline
    >>> predictor = MinimalPredictor(artifact_loader, run_dir)
    >>> y_pred, predictions = predictor.predict(minimal, dataset)
"""

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from nirs4all.data.dataset import SpectroDataset
from nirs4all.data.predictions import Predictions
from nirs4all.pipeline.config.context import (
    DataSelector,
    ExecutionContext,
    PipelineState,
    RuntimeContext,
    StepMetadata,
    ArtifactProvider,
    LoaderArtifactProvider,
)
from nirs4all.pipeline.storage.artifacts.types import ArtifactRecord, ArtifactType
from nirs4all.pipeline.trace import MinimalPipeline, MinimalPipelineStep


# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class MinimalArtifactProvider(ArtifactProvider):
    """Artifact provider backed by a MinimalPipeline (V3).

    Provides artifacts from the minimal pipeline's artifact map, which
    contains StepArtifacts extracted from the execution trace.

    Filtering relies on V3 ArtifactRecord metadata (branch_path,
    substep_index, source_index, artifact_type, custom_name) fetched through
    the artifact loader. The one place an artifact ID is parsed is
    ``get_fold_artifacts`` in sub-index mode, which reads the fold id from
    the trailing ``:<fold_id>`` suffix of the ID.

    Attributes:
        minimal_pipeline: The source MinimalPipeline
        artifact_loader: ArtifactLoader for loading actual artifact objects
        target_sub_index: Filter model artifacts by substep_index
        target_model_name: Filter model artifacts by custom_name
    """

    def __init__(
        self,
        minimal_pipeline: MinimalPipeline,
        artifact_loader: Any,  # ArtifactLoader
        target_sub_index: Optional[int] = None,
        target_model_name: Optional[str] = None
    ):
        """Initialize minimal artifact provider.

        Args:
            minimal_pipeline: MinimalPipeline with artifact mappings
            artifact_loader: ArtifactLoader for loading artifact objects
            target_sub_index: Optional substep_index to filter model artifacts.
                Used when a subpipeline contains multiple models
                (e.g., [JaxMLPRegressor, nicon]) and we need to load
                artifacts for a specific one.
            target_model_name: Optional model name to filter artifacts by.
                Used as fallback when sub_index is not available
                (e.g., for avg/w_avg predictions).
        """
        self.minimal_pipeline = minimal_pipeline
        self.artifact_loader = artifact_loader
        self.target_sub_index = target_sub_index
        self.target_model_name = target_model_name
        # artifact_id -> loaded object; only successful loads are cached.
        self._cache: Dict[str, Any] = {}

    # ------------------------------------------------------------------
    # Record helpers
    # ------------------------------------------------------------------

    def _get_record(self, artifact_id: str) -> Optional[ArtifactRecord]:
        """Get artifact record from loader.

        Args:
            artifact_id: Artifact ID to look up

        Returns:
            ArtifactRecord or None if not found (or the loader has no
            ``get_record`` method).
        """
        try:
            return self.artifact_loader.get_record(artifact_id)
        except (KeyError, AttributeError):
            return None

    @staticmethod
    def _branch_of_record(record: Optional[ArtifactRecord]) -> Optional[int]:
        """Return the first (outermost) branch index stored on a record.

        Returns None when the record is missing or its branch_path is empty,
        i.e. the artifact is shared (created before any branch).
        """
        if record is None or not record.branch_path:
            return None
        return record.branch_path[0]

    @staticmethod
    def _is_model_record(record: ArtifactRecord) -> bool:
        """True if the record describes a model or meta-model artifact."""
        return record.artifact_type in (ArtifactType.MODEL, ArtifactType.META_MODEL)

    @staticmethod
    def _parse_fold_id(artifact_id: str) -> Optional[int]:
        """Extract the fold id from an ID of the form ``...:<fold_id>``.

        Returns None when there is no ``:`` or the suffix is not an integer.
        """
        _, sep, tail = artifact_id.rpartition(':')
        if not sep:
            return None
        try:
            return int(tail)
        except ValueError:
            return None

    def _branch_matches(
        self,
        artifact_id: str,
        record: Optional[ArtifactRecord],
        target_branch: Optional[int]
    ) -> bool:
        """Decide whether an artifact belongs to the requested branch.

        An artifact matches when no branch is requested, when it has no
        branch (shared/pre-branch artifact), or when its first branch index
        equals ``target_branch``.
        """
        if target_branch is None:
            return True
        artifact_branch = self._branch_of_record(record)
        if artifact_branch is not None and artifact_branch != target_branch:
            logger.debug(
                "Filtering artifact %s (branch=%s) - target branch is %s",
                artifact_id, artifact_branch, target_branch,
            )
            return False
        return True

    def _passes_model_filter(
        self,
        artifact_id: str,
        record: Optional[ArtifactRecord]
    ) -> bool:
        """Apply target_sub_index / target_model_name filtering to models.

        Non-model artifacts, artifacts without a record, and the case where
        no model target is configured always pass. target_sub_index takes
        precedence; target_model_name is only consulted when it is unset.
        """
        if self.target_sub_index is None and self.target_model_name is None:
            return True
        if record is None or not self._is_model_record(record):
            return True

        if self.target_sub_index is not None:
            artifact_sub_index = record.substep_index
            if artifact_sub_index is not None and artifact_sub_index != self.target_sub_index:
                logger.debug(
                    "Filtering artifact %s (substep_index=%s) - target sub_index is %s",
                    artifact_id, artifact_sub_index, self.target_sub_index,
                )
                return False
        elif record.custom_name and record.custom_name != self.target_model_name:
            # Fallback for avg/w_avg predictions where no sub_index exists.
            logger.debug(
                "Filtering artifact %s (custom_name=%s) - target model is %s",
                artifact_id, record.custom_name, self.target_model_name,
            )
            return False
        return True

    def _get_branch_from_record(self, artifact_id: str) -> Optional[int]:
        """Get branch index from artifact record.

        Uses ArtifactRecord.branch_path for V3 artifacts.

        Args:
            artifact_id: Artifact ID to look up

        Returns:
            First branch index (0, 1, ...) if artifact is from a branch,
            None if artifact is shared (pre-branch) or has no record.
        """
        return self._branch_of_record(self._get_record(artifact_id))

    def _get_substep_from_record(self, artifact_id: str) -> Optional[int]:
        """Get substep_index from artifact record.

        Uses ArtifactRecord.substep_index for V3 artifacts.

        Args:
            artifact_id: Artifact ID to look up

        Returns:
            Substep index or None if not present.
        """
        record = self._get_record(artifact_id)
        if record is None:
            return None
        return record.substep_index

    def _derive_operator_name(
        self,
        obj: Any,
        artifact_id: str,
        step_index: Optional[int] = None
    ) -> str:
        """Derive operator name from object class and artifact metadata.

        Reconstructs names like "MinMaxScaler_1" from the object's class and
        the substep_index from the artifact record. For steps whose minimal
        step has operator_type "y_processing", adds a "y_" prefix
        (e.g., "y_MinMaxScaler_1").

        Args:
            obj: The loaded artifact object.
            artifact_id: The artifact ID.
            step_index: Optional step index for checking if y_processing.

        Returns:
            "{ClassName}_{substep_index}" or "{ClassName}", optionally
            prefixed with "y_".
        """
        class_name = obj.__class__.__name__
        sub_index = self._get_substep_from_record(artifact_id)

        is_y_processing = False
        if step_index is not None and self.minimal_pipeline:
            minimal_step = self.minimal_pipeline.get_step(step_index)
            if minimal_step and minimal_step.operator_type == "y_processing":
                is_y_processing = True

        prefix = "y_" if is_y_processing else ""
        if sub_index is not None:
            return f"{prefix}{class_name}_{sub_index}"
        return f"{prefix}{class_name}"

    # ------------------------------------------------------------------
    # ArtifactProvider interface
    # ------------------------------------------------------------------

    def get_artifact(
        self,
        step_index: int,
        fold_id: Optional[int] = None
    ) -> Optional[Any]:
        """Get a single artifact for a step.

        Lookup order: fold-specific artifact (if fold_id given and mapped),
        then primary_artifact_id, then the first entry of artifact_ids.

        Args:
            step_index: 1-based step index
            fold_id: Optional fold ID for fold-specific artifacts

        Returns:
            Artifact object or None if not found
        """
        step_artifacts = self.minimal_pipeline.get_artifacts_for_step(step_index)
        if not step_artifacts:
            return None

        if fold_id is not None and step_artifacts.fold_artifact_ids:
            artifact_id = step_artifacts.fold_artifact_ids.get(fold_id)
            if artifact_id:
                return self._load_artifact(artifact_id)

        if step_artifacts.primary_artifact_id:
            return self._load_artifact(step_artifacts.primary_artifact_id)

        if step_artifacts.artifact_ids:
            return self._load_artifact(step_artifacts.artifact_ids[0])

        return None

    def get_artifacts_for_step(
        self,
        step_index: int,
        branch_path: Optional[List[int]] = None,
        branch_id: Optional[int] = None,
        source_index: Optional[int] = None,
        substep_index: Optional[int] = None
    ) -> List[Tuple[str, Any]]:
        """Get all artifacts for a step (V3).

        Filters artifacts using the ArtifactRecord of each artifact. This
        matters for multisource + branching reload, where branch substep
        artifacts are lumped together in the execution trace but can be
        distinguished by their records.

        Returns tuples of (operator_name, artifact_object) where
        operator_name is derived from the object class and substep_index
        (e.g., "MinMaxScaler_1") so transformer controllers can look
        artifacts up by name.

        Args:
            step_index: 1-based step index
            branch_path: Optional branch path filter (e.g., [0] for branch 0);
                only its first element is compared.
            branch_id: Optional branch ID filter (used when branch_path not available)
            source_index: Optional source/dataset index filter for multi-source
            substep_index: Optional substep index filter for branch substeps

        Returns:
            List of (operator_name, artifact_object) tuples ordered by
            substep_index (artifacts without one come last).
        """
        step_artifacts = self.minimal_pipeline.get_artifacts_for_step(step_index)
        if not step_artifacts:
            return []

        # branch_path wins over branch_id.
        target_branch: Optional[int] = None
        if branch_path:
            target_branch = branch_path[0]
        elif branch_id is not None:
            target_branch = branch_id

        logger.debug(
            "get_artifacts_for_step(%s): target_sub_index=%s, substep_index=%s, model_step_index=%s",
            step_index, self.target_sub_index, substep_index,
            self.minimal_pipeline.model_step_index,
        )

        collected: List[Tuple[str, Any, Optional[int]]] = []
        for artifact_id in step_artifacts.artifact_ids:
            record = self._get_record(artifact_id)

            if not self._branch_matches(artifact_id, record, target_branch):
                continue

            if record is not None:
                # Each transformer controller gets only its own substep artifact.
                artifact_substep = record.substep_index
                if (substep_index is not None and artifact_substep is not None
                        and artifact_substep != substep_index):
                    logger.debug(
                        "Filtering artifact %s (substep_index=%s) - target substep is %s",
                        artifact_id, artifact_substep, substep_index,
                    )
                    continue

                # Multi-source datasets: one transformer artifact per source.
                artifact_source = record.source_index
                if (source_index is not None and artifact_source is not None
                        and artifact_source != source_index):
                    logger.debug(
                        "Filtering artifact %s (source_index=%s) - target source is %s",
                        artifact_id, artifact_source, source_index,
                    )
                    continue

            # Subpipelines like [JaxMLPRegressor, nicon] produce several model
            # artifacts in one step; keep only the targeted model.
            if not self._passes_model_filter(artifact_id, record):
                continue

            obj = self._load_artifact(artifact_id)
            if obj is None:
                continue
            operator_name = self._derive_operator_name(obj, artifact_id, step_index)
            order_key = record.substep_index if record is not None else None
            collected.append((operator_name, obj, order_key))

        # Return artifacts in training creation order: multi-source
        # feature_augmentation controllers load transformers by position.
        collected.sort(key=lambda item: float('inf') if item[2] is None else item[2])
        results = [(name, obj) for name, obj, _ in collected]

        logger.debug(
            "get_artifacts_for_step(%s, branch_path=%s) -> %s artifacts from %s total",
            step_index, branch_path, len(results), len(step_artifacts.artifact_ids),
        )
        return results

    def get_fold_artifacts(
        self,
        step_index: int,
        branch_path: Optional[List[int]] = None
    ) -> List[Tuple[int, Any]]:
        """Get all fold-specific artifacts for a step.

        When target_sub_index is set, scans all artifact_ids (model
        artifacts only, matching substep_index) instead of
        fold_artifact_ids, because fold_artifact_ids keeps only the last
        model's artifacts when a subpipeline contains several models. In
        that mode the fold id is parsed from the ``:<fold_id>`` suffix of
        the artifact ID, and IDs without one are skipped.

        Args:
            step_index: 1-based step index
            branch_path: Optional branch path filter. Only its first element
                is compared; artifacts from another branch are dropped, while
                shared (branch-less) artifacts are kept.

        Returns:
            List of (fold_id, artifact_object) tuples, sorted by fold_id
        """
        step_artifacts = self.minimal_pipeline.get_artifacts_for_step(step_index)
        if not step_artifacts:
            return []

        target_branch: Optional[int] = branch_path[0] if branch_path else None
        results: List[Tuple[int, Any]] = []

        if self.target_sub_index is not None:
            for artifact_id in step_artifacts.artifact_ids:
                record = self._get_record(artifact_id)
                if record is None or not self._is_model_record(record):
                    continue

                artifact_sub_index = record.substep_index
                if artifact_sub_index is not None and artifact_sub_index != self.target_sub_index:
                    logger.debug(
                        "Filtering artifact %s (substep_index=%s) - target sub_index is %s",
                        artifact_id, artifact_sub_index, self.target_sub_index,
                    )
                    continue

                # Branch artifacts are lumped together in the trace, so the
                # branch filter is needed to avoid cross-branch fold models.
                if not self._branch_matches(artifact_id, record, target_branch):
                    continue

                fold_id = self._parse_fold_id(artifact_id)
                if fold_id is None:
                    continue

                obj = self._load_artifact(artifact_id)
                if obj is not None:
                    results.append((fold_id, obj))
        else:
            if not step_artifacts.fold_artifact_ids:
                return []
            for fold_id, artifact_id in step_artifacts.fold_artifact_ids.items():
                # Only fetch records when a branch filter was requested.
                if target_branch is not None and not self._branch_matches(
                        artifact_id, self._get_record(artifact_id), target_branch):
                    continue
                obj = self._load_artifact(artifact_id)
                if obj is not None:
                    results.append((fold_id, obj))

        return sorted(results, key=lambda item: item[0])

    def has_artifacts_for_step(self, step_index: int) -> bool:
        """Check if artifacts exist for a step.

        Args:
            step_index: 1-based step index

        Returns:
            True if the step has at least one artifact id
        """
        step_artifacts = self.minimal_pipeline.get_artifacts_for_step(step_index)
        return step_artifacts is not None and len(step_artifacts.artifact_ids) > 0

    def get_fold_weights(self) -> Dict[int, float]:
        """Get fold weights for CV ensemble averaging.

        Returns:
            A copy of the pipeline's fold_id -> weight mapping (empty if none)
        """
        return dict(self.minimal_pipeline.fold_weights or {})

    def _load_artifact(self, artifact_id: str) -> Optional[Any]:
        """Load an artifact by ID with caching.

        Args:
            artifact_id: Artifact ID to load

        Returns:
            Loaded artifact object, or None if the loader raised KeyError or
            FileNotFoundError (a warning is logged; failures are not cached).
        """
        if artifact_id in self._cache:
            return self._cache[artifact_id]
        try:
            obj = self.artifact_loader.load_by_id(artifact_id)
        except (KeyError, FileNotFoundError) as e:
            logger.warning(f"Failed to load artifact {artifact_id}: {e}")
            return None
        self._cache[artifact_id] = obj
        return obj
class MinimalPredictor:
    """Execute minimal pipeline for prediction.

    This class takes a MinimalPipeline (extracted from an ExecutionTrace)
    and executes only the required steps using existing controllers with
    artifact injection.

    The MinimalPredictor achieves the Phase 5 goal of "execute only needed
    steps" by:
    1. Using the minimal pipeline's step list (not the full original pipeline)
    2. Injecting pre-loaded artifacts via ArtifactProvider
    3. Running controllers in predict mode

    Attributes:
        artifact_loader: ArtifactLoader for loading artifacts
        run_dir: Path to run directory
        saver: Optional SimulationSaver for outputs
        manifest_manager: Optional ManifestManager
        verbose: Verbosity level

    Example:
        >>> predictor = MinimalPredictor(artifact_loader, run_dir)
        >>> y_pred, predictions = predictor.predict(minimal_pipeline, dataset)
    """

    # Keys of ``target_model`` forwarded to Predictions.filter_predictions.
    _TARGET_FILTER_KEYS = ("model_name", "step_idx", "fold_id", "branch_id")

    def __init__(
        self,
        artifact_loader: Any,  # ArtifactLoader
        run_dir: Union[str, Path],
        saver: Any = None,
        manifest_manager: Any = None,
        verbose: int = 0
    ):
        """Initialize minimal predictor.

        Args:
            artifact_loader: ArtifactLoader for loading artifacts
            run_dir: Path to run directory (converted to ``Path``)
            saver: Optional SimulationSaver for outputs
            manifest_manager: Optional ManifestManager
            verbose: Verbosity level
        """
        self.artifact_loader = artifact_loader
        self.run_dir = Path(run_dir)
        self.saver = saver
        self.manifest_manager = manifest_manager
        self.verbose = verbose

    def _get_substep_from_artifact(self, artifact_id: str) -> Optional[int]:
        """Get substep_index from artifact record (V3).

        Args:
            artifact_id: Artifact ID to look up.

        Returns:
            substep_index or None if the record is missing or the loader has
            no ``get_record`` method.
        """
        try:
            record = self.artifact_loader.get_record(artifact_id)
        except (KeyError, AttributeError):
            return None
        return record.substep_index if record is not None else None

    @classmethod
    def _select_y_pred(
        cls,
        predictions: Predictions,
        target_model: Optional[Dict[str, Any]]
    ) -> Optional[Any]:
        """Pick the first non-empty y_pred from the prediction store.

        When ``target_model`` is given, candidates are narrowed with
        ``filter_predictions`` using only the keys in _TARGET_FILTER_KEYS.
        Entries whose y_pred is missing, None, or empty are skipped.

        Returns:
            The selected y_pred, or None if nothing usable was produced.
        """
        if predictions.num_predictions <= 0:
            return None

        if target_model:
            filters = {
                key: value for key, value in target_model.items()
                if key in cls._TARGET_FILTER_KEYS
            }
            candidates = predictions.filter_predictions(**filters)
        else:
            candidates = predictions.to_dicts()

        for candidate in candidates:
            y_pred = candidate.get("y_pred")
            if y_pred is not None and len(y_pred) > 0:
                return y_pred
        return None

    def predict(
        self,
        minimal_pipeline: MinimalPipeline,
        dataset: SpectroDataset,
        target_model: Optional[Dict[str, Any]] = None
    ) -> Tuple[np.ndarray, Predictions]:
        """Execute minimal pipeline and return predictions.

        Runs only the steps in the minimal pipeline, using pre-loaded
        artifacts from the execution trace.

        Args:
            minimal_pipeline: MinimalPipeline to execute
            dataset: Dataset to predict on
            target_model: Optional target model metadata. If it carries a
                'model_artifact_id', that artifact's substep_index is used to
                select the model inside multi-model subpipelines; the keys
                model_name/step_idx/fold_id/branch_id filter the results.

        Returns:
            Tuple of (y_pred array, Predictions object). y_pred is an empty
            array when no matching non-empty prediction was produced.
        """
        from nirs4all.pipeline.execution.builder import ExecutorBuilder

        logger.info(f"Minimal prediction: {minimal_pipeline.get_step_count()} steps")

        # Subpipelines with several models need the target model's substep.
        target_sub_index = None
        if target_model:
            model_artifact_id = target_model.get('model_artifact_id')
            if model_artifact_id:
                target_sub_index = self._get_substep_from_artifact(model_artifact_id)

        artifact_provider = MinimalArtifactProvider(
            minimal_pipeline=minimal_pipeline,
            artifact_loader=self.artifact_loader,
            target_sub_index=target_sub_index
        )

        # One *independent* processing list per source: `[["raw"]] * n` would
        # alias a single inner list across all sources.
        n_sources = dataset.features_sources()
        context = ExecutionContext(
            selector=DataSelector(
                partition="all",
                processing=[["raw"] for _ in range(n_sources)],
                layout="2d",
                concat_source=True
            ),
            state=PipelineState(
                y_processing="numeric",
                step_number=0,
                mode="predict"
            ),
            metadata=StepMetadata()
        )

        executor = (ExecutorBuilder()
                    .with_run_directory(self.run_dir)
                    .with_verbose(self.verbose)
                    .with_mode("predict")
                    .with_save_artifacts(False)
                    .with_save_charts(False)
                    .with_continue_on_error(False)
                    .with_show_spinner(False)
                    .with_plots_visible(False)
                    .with_artifact_loader(self.artifact_loader)
                    .with_saver(self.saver)
                    .with_manifest_manager(self.manifest_manager)
                    .build())

        runtime_context = RuntimeContext(
            saver=self.saver,
            manifest_manager=self.manifest_manager,
            artifact_loader=self.artifact_loader,
            artifact_provider=artifact_provider,
            step_runner=executor.step_runner,
            target_model=target_model
        )

        steps = [step.step_config for step in minimal_pipeline.steps]

        predictions = Predictions()
        executor.execute_minimal(
            steps=steps,
            minimal_pipeline=minimal_pipeline,
            dataset=dataset,
            context=context,
            runtime_context=runtime_context,
            prediction_store=predictions
        )

        y_pred = self._select_y_pred(predictions, target_model)
        if y_pred is None:
            return np.array([]), predictions

        # The stdlib logging.Logger has no `success` method; use it only when
        # a custom logger class provides one, otherwise log at INFO.
        log_success = getattr(logger, "success", logger.info)
        log_success(f"Prediction complete: {len(y_pred)} samples")
        return np.array(y_pred), predictions

    def predict_with_fold_ensemble(
        self,
        minimal_pipeline: MinimalPipeline,
        dataset: SpectroDataset,
        fold_strategy: str = "weighted_average"
    ) -> Tuple[np.ndarray, Predictions]:
        """Execute minimal pipeline with fold ensemble averaging.

        For cross-validation models, runs ``predict`` once per fold in
        ``minimal_pipeline.fold_weights`` and combines the non-empty results.

        Args:
            minimal_pipeline: MinimalPipeline to execute
            dataset: Dataset to predict on
            fold_strategy: How to combine folds. "weighted_average" uses the
                fold weights (missing weights default to 1.0; a non-positive
                or NaN weight total falls back to a simple mean). Any other
                value uses a simple mean.

        Returns:
            Tuple of (y_pred array, Predictions object)
        """
        fold_weights = minimal_pipeline.fold_weights or {}
        if not fold_weights:
            # No folds: plain prediction.
            return self.predict(minimal_pipeline, dataset)

        fold_predictions: Dict[int, np.ndarray] = {}
        for fold_id in sorted(fold_weights):
            y_pred, _ = self.predict(minimal_pipeline, dataset, {"fold_id": fold_id})
            if len(y_pred) > 0:
                fold_predictions[fold_id] = y_pred

        if not fold_predictions:
            return np.array([]), Predictions()

        fold_ids = list(fold_predictions)
        fold_arrays = list(fold_predictions.values())

        y_pred_combined = None
        if fold_strategy == "weighted_average":
            weights = np.array([fold_weights.get(fid, 1.0) for fid in fold_ids], dtype=float)
            total = weights.sum()
            if total > 0:
                y_pred_combined = np.average(fold_arrays, axis=0, weights=weights / total)
            else:
                logger.warning(
                    "Fold weights sum to %s; falling back to simple average", total
                )
        if y_pred_combined is None:
            y_pred_combined = np.mean(fold_arrays, axis=0)

        predictions = Predictions()
        predictions.add_prediction(
            dataset_name="prediction",
            model_name=minimal_pipeline.preprocessing_chain,
            step_idx=minimal_pipeline.model_step_index or 0,
            fold_id="ensemble",
            y_pred=y_pred_combined
        )

        return y_pred_combined, predictions

    def validate_minimal_pipeline(
        self,
        minimal_pipeline: MinimalPipeline
    ) -> Tuple[bool, List[str]]:
        """Validate that minimal pipeline can be executed.

        Checks that:
        - All step configs are present
        - A model step is set and included in the pipeline
        - All artifacts in the artifact map are loadable (each one is loaded
          through the artifact loader; KeyError/FileNotFoundError count as
          not loadable)

        Args:
            minimal_pipeline: MinimalPipeline to validate

        Returns:
            Tuple of (is_valid, list of issues)
        """
        issues: List[str] = [
            f"Step {step.step_index} has no config"
            for step in minimal_pipeline.steps
            if step.step_config is None
        ]

        model_step = minimal_pipeline.model_step_index
        if model_step is None:
            issues.append("No model step in minimal pipeline")
        elif not minimal_pipeline.has_step(model_step):
            issues.append(f"Model step {model_step} not in pipeline")

        for step_index, step_artifacts in minimal_pipeline.artifact_map.items():
            for artifact_id in step_artifacts.artifact_ids:
                try:
                    self.artifact_loader.load_by_id(artifact_id)
                except (KeyError, FileNotFoundError):
                    issues.append(
                        f"Artifact {artifact_id} for step {step_index} not loadable"
                    )

        return not issues, issues