Source code for nirs4all.pipeline.storage.io_resolver

"""Prediction target resolver - resolves prediction targets for predict mode.

Note: This is the legacy target resolver for finding predictions by ID.
For the comprehensive Phase 3 prediction source resolver that normalizes
sources to all components needed for replay, see:
    nirs4all.pipeline.resolver.PredictionResolver
"""
import re
from pathlib import Path
from typing import Any, Dict, Optional, Union

from nirs4all.data.predictions import Predictions


class TargetResolver:
    """Resolves prediction targets for predict mode.

    Focused responsibility: finding and resolving prediction targets by ID.

    Note:
        For the comprehensive Phase 3 resolver that handles multiple source
        types (prediction dict, folder, Run, artifact_id, bundle), see
        nirs4all.pipeline.resolver.PredictionResolver.
    """

    # Prediction database file suffixes in lookup-priority order: Parquet
    # metadata is the current format, JSON the legacy fallback. Shared by
    # find_prediction_by_id and find_best_for_config so both agree.
    _PREDICTION_FILE_SUFFIXES = (".meta.parquet", ".json")

    # Run directory naming scheme "YYYY-MM-DD_<dataset>"; group 1 is the dataset.
    _DATED_RUN_DIR_RE = re.compile(r"^\d{4}-\d{2}-\d{2}_(.+)$")

    def __init__(self, workspace_path: Union[str, Path]):
        """Initialize resolver for a workspace.

        Args:
            workspace_path: Root workspace directory (coerced to ``Path``).
        """
        self.workspace_path = Path(workspace_path)

    def resolve_target(
        self, prediction_obj: Union[Dict[str, Any], str]
    ) -> tuple[str, Optional[Dict[str, Any]]]:
        """Resolve a prediction object to a config path and model metadata.

        Args:
            prediction_obj: Either
                - a dict with a ``'config_path'`` key; the dict itself is
                  returned as the model metadata only if it also has a
                  ``'model_name'`` key;
                - a string naming an existing filesystem path (checked
                  relative to the current working directory, not the
                  workspace), returned as-is with no metadata; or
                - any other non-empty string, treated as a prediction ID and
                  looked up with ``find_prediction_by_id``.

        Returns:
            Tuple of ``(config_path, target_model_metadata)``; the metadata is
            ``None`` when no model is identified.

        Raises:
            KeyError: If the dict, or the prediction found by ID, lacks
                ``'config_path'``.
            ValueError: If the string is empty, the prediction ID is not
                found, or ``prediction_obj`` is neither a dict nor a string.
        """
        if isinstance(prediction_obj, dict):
            target_model = prediction_obj if 'model_name' in prediction_obj else None
            return prediction_obj['config_path'], target_model

        if isinstance(prediction_obj, str):
            if not prediction_obj:
                # Path("") is the current directory and always "exists", which
                # would silently yield an empty config path.
                raise ValueError("Invalid prediction_obj: empty string")
            if Path(prediction_obj).exists():
                return prediction_obj, None
            # Otherwise, treat as prediction ID.
            target_model = self.find_prediction_by_id(prediction_obj)
            if not target_model:
                raise ValueError(f"Prediction ID not found: {prediction_obj}")
            return target_model['config_path'], target_model

        raise ValueError(f"Invalid prediction_obj type: {type(prediction_obj)}")

    def find_prediction_by_id(self, prediction_id: str) -> Optional[Dict[str, Any]]:
        """Search the workspace prediction databases for a prediction by ID.

        Looks in the workspace root and, if present, its ``runs/``
        subdirectory. Within each directory, ``*.meta.parquet`` files are
        tried before legacy ``*.json`` files, each group in sorted filename
        order so the result is deterministic when an ID appears in several
        files. Each candidate file is loaded with
        ``Predictions.load_from_file_cls`` and queried with
        ``get_prediction_by_id``.

        Args:
            prediction_id: Unique prediction identifier.

        Returns:
            The first matching prediction dict, or ``None`` if the workspace
            does not exist or no file contains the ID.
        """
        if not self.workspace_path.exists():
            return None

        search_paths = [self.workspace_path]
        runs_dir = self.workspace_path / "runs"
        if runs_dir.exists():
            search_paths.append(runs_dir)

        for directory in search_paths:
            for suffix in self._PREDICTION_FILE_SUFFIXES:
                # glob order is filesystem-dependent; sort for determinism.
                for predictions_file in sorted(directory.glob(f"*{suffix}")):
                    pred = self._lookup_in_file(predictions_file, prediction_id)
                    if pred is not None:
                        return pred
        return None

    @staticmethod
    def _lookup_in_file(
        predictions_file: Path, prediction_id: str
    ) -> Optional[Dict[str, Any]]:
        """Return the prediction with ``prediction_id`` from one file, or None."""
        if not predictions_file.is_file():
            return None
        try:
            predictions = Predictions.load_from_file_cls(str(predictions_file))
            return predictions.get_prediction_by_id(prediction_id)
        except Exception:
            # Deliberately best-effort: the "*.json" pattern may also match
            # files that are not prediction databases; any file that fails to
            # load or query is skipped rather than aborting the whole search.
            return None

    def find_best_for_config(self, config_path: str) -> Optional[Dict[str, Any]]:
        """Find the best-scoring prediction recorded for a pipeline config.

        The dataset name is derived from ``config_path``, then the
        workspace-root database ``<dataset>.meta.parquet`` (or legacy
        ``<dataset>.json``) is searched for predictions whose ``config_path``
        matches. The prediction with the lowest ``test_score`` wins (missing,
        ``None``, non-numeric or NaN scores rank last) and is re-fetched with
        arrays loaded.

        NOTE(review): "lower is better" is assumed, as in the original code;
        for higher-is-better metrics this would pick the worst model. Confirm
        how ``test_score`` is defined.

        Args:
            config_path: Path to pipeline configuration, typically
                ``runs/YYYY-MM-DD_dataset/0001_hash/pipeline.json``.

        Returns:
            Best prediction metadata, or ``None`` if no dataset name can be
            derived, no database file exists, nothing matches, or loading
            fails.
        """
        dataset_name = self._dataset_name_from_config_path(config_path)
        if dataset_name is None:
            return None

        predictions_file = self._predictions_file_for_dataset(dataset_name)
        if predictions_file is None:
            return None

        try:
            predictions = Predictions.load_from_file_cls(str(predictions_file))
            # Filter without arrays first so only the winner's arrays get loaded.
            matching = predictions.filter_predictions(
                config_path=config_path, load_arrays=False
            )
            if not matching:
                return None

            best = min(matching, key=self._score_key)
            best_id = best.get('id')
            if not best_id:
                return best
            full = predictions.get_prediction_by_id(best_id, load_arrays=True)
            # Fall back to the array-less metadata rather than losing the result.
            return full if full is not None else best
        except Exception:
            # Best-effort lookup: unreadable databases yield "not found".
            return None

    @classmethod
    def _dataset_name_from_config_path(cls, config_path: str) -> Optional[str]:
        """Derive the dataset name from a run config path, or None."""
        parts = Path(config_path).parts
        # Prefer the documented "YYYY-MM-DD_<dataset>" run directory so that
        # underscores earlier in an absolute path (e.g. "/home/john_doe/...")
        # are not mistaken for it.
        for part in parts:
            match = cls._DATED_RUN_DIR_RE.match(part)
            if match:
                return match.group(1)
        # Legacy heuristic: first component with a non-leading underscore,
        # dropping everything up to and including the first underscore.
        for part in parts:
            if '_' in part and not part.startswith('_'):
                return part.split('_', 1)[1] or None
        return None

    def _predictions_file_for_dataset(self, dataset_name: str) -> Optional[Path]:
        """Return the workspace-root database file for ``dataset_name``, if any."""
        for suffix in self._PREDICTION_FILE_SUFFIXES:
            candidate = self.workspace_path / f"{dataset_name}{suffix}"
            if candidate.is_file():
                return candidate
        return None

    @staticmethod
    def _score_key(prediction: Dict[str, Any]) -> float:
        """Sort key ranking predictions by ``test_score`` (lower is better).

        Missing, ``None``, non-numeric and NaN scores map to ``inf`` so a
        single incomplete record cannot make ``min`` raise or misorder.
        """
        score = prediction.get('test_score')
        if score is None:
            return float('inf')
        try:
            value = float(score)
        except (TypeError, ValueError):
            return float('inf')
        # NaN is the only float that is not equal to itself.
        return float('inf') if value != value else value
# Backward compatibility alias # Deprecated: Use nirs4all.pipeline.resolver.PredictionResolver for Phase 3+ features PredictionResolver = TargetResolver