Source code for nirs4all.pipeline.run

"""
Run entity for nirs4all pipeline execution.

A Run represents a complete experiment session that combines:
- One or more Pipeline Templates (or concrete Pipelines)
- One or more Datasets

The Run generates Results for every combination of expanded pipeline
configurations and datasets.

Formula:
    Run = [Pipeline Templates] × [Datasets]
        = [Σ Expanded Pipelines from all Templates] × [All Datasets]
        = Results
"""

import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union


class RunStatus(Enum):
    """Lifecycle states of a run.

    Each member's value is its lowercase string name, which is what gets
    serialized. The allowed moves between states are listed in
    ``VALID_TRANSITIONS``.
    """

    QUEUED = "queued"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    PAUSED = "paused"
    CANCELLED = "cancelled"
# Valid state transitions for the run state machine: each status maps to the
# statuses it may move to next. FAILED -> QUEUED is the retry path, and
# COMPLETED / CANCELLED are terminal (no outgoing moves).
VALID_TRANSITIONS = {
    RunStatus.QUEUED: [RunStatus.RUNNING, RunStatus.CANCELLED],
    RunStatus.RUNNING: [RunStatus.COMPLETED, RunStatus.FAILED, RunStatus.PAUSED],
    RunStatus.PAUSED: [RunStatus.RUNNING, RunStatus.CANCELLED],
    RunStatus.FAILED: [RunStatus.QUEUED],  # retry
    RunStatus.COMPLETED: [],  # terminal
    RunStatus.CANCELLED: [],  # terminal
}

# Metric metadata for proper score comparison. Each entry records whether
# larger values are better, the ideal value, and the (low, high) value range.
METRIC_METADATA = {
    # Regression metrics
    "r2": {"higher_is_better": True, "optimal": 1.0, "range": (-float('inf'), 1.0)},
    # Non-negative error metrics: 0 is perfect, smaller is better.
    **{
        metric_key: {"higher_is_better": False, "optimal": 0.0, "range": (0.0, float('inf'))}
        for metric_key in ("rmse", "rmsecv", "rmsep", "mae", "mse", "mape")
    },
    "rpd": {"higher_is_better": True, "optimal": float('inf'), "range": (0.0, float('inf'))},
    "bias": {"higher_is_better": False, "optimal": 0.0, "range": (-float('inf'), float('inf'))},
    "sep": {"higher_is_better": False, "optimal": 0.0, "range": (0.0, float('inf'))},
    # Classification metrics: scores in [0, 1], where 1 is perfect.
    **{
        metric_key: {"higher_is_better": True, "optimal": 1.0, "range": (0.0, 1.0)}
        for metric_key in ("accuracy", "precision", "recall", "f1", "f1_score", "auc", "roc_auc")
    },
    # Default for unknown metrics
    "default": {"higher_is_better": True, "optimal": 1.0, "range": (-float('inf'), float('inf'))},
}
def get_metric_info(metric_name: str) -> Dict[str, Any]:
    """
    Get metadata for a metric.

    The lookup is case-insensitive. Names not present in ``METRIC_METADATA``
    fall back to its ``"default"`` entry.

    Args:
        metric_name: Name of the metric (e.g., 'r2', 'rmse', 'accuracy')

    Returns:
        Dict with 'higher_is_better', 'optimal', and 'range' keys. A fresh
        copy is returned, so callers may modify it without altering the
        shared ``METRIC_METADATA`` table.
    """
    info = METRIC_METADATA.get(metric_name.lower(), METRIC_METADATA["default"])
    # Return a copy: handing out the table's own dict would let any caller
    # mutate the module-wide metadata for every later lookup.
    return dict(info)
def is_better_score(score: float, best_score: float, metric: str) -> bool:
    """
    Compare two scores and determine if the new score is better.

    Most metrics are compared by direction: larger wins when the metric is
    higher-is-better, smaller wins otherwise. Lower-is-better metrics whose
    optimum lies strictly inside their value range (e.g. ``bias``: optimum
    0.0, range (-inf, inf)) are signed. For those, a value far below the
    optimum is just as bad as one far above it, so the two scores are
    compared by their distance from the optimum instead.

    Args:
        score: New score to compare
        best_score: Current best score
        metric: Metric name to determine comparison direction

    Returns:
        True if score is strictly better than best_score
    """
    info = get_metric_info(metric)
    optimal = info["optimal"]
    low, high = info["range"]

    if info["higher_is_better"]:
        return score > best_score

    if low < optimal < high:
        # Signed metric: closeness to the optimum matters, not raw size.
        # Without this, a bias of -5 would beat a bias of 0.1.
        return abs(score - optimal) < abs(best_score - optimal)

    return score < best_score
@dataclass
class TemplateInfo:
    """Describe one pipeline template attached to a run.

    ``expansion_count`` is the number of concrete pipeline configurations
    this template contributes. ``Run.total_pipeline_configs`` sums it across
    all templates.
    """

    id: str
    name: str
    file_path: Optional[str] = None
    expansion_count: int = 1
    description: Optional[str] = None
@dataclass
class DatasetInfo:
    """Metadata about one dataset that takes part in a run.

    Only ``name`` and ``path`` are required. Every other field is optional
    and defaults to ``None``.
    """

    # Identity and location
    name: str
    path: str
    # Optional identification / size details
    hash: Optional[str] = None
    file_size: Optional[int] = None
    # Dimensions and task description
    n_samples: Optional[int] = None
    n_features: Optional[int] = None
    task_type: Optional[str] = None
    # Target columns and nested per-target statistics
    # (presumably keyed by column name -- TODO confirm against producers)
    y_columns: Optional[List[str]] = None
    y_stats: Optional[Dict[str, Dict[str, float]]] = None
    # Wavelength axis description
    wavelength_range: Optional[List[float]] = None
    wavelength_unit: Optional[str] = None
    # Free-form extras
    metadata: Optional[Dict[str, Any]] = None
    version: Optional[str] = None
@dataclass
class RunConfig:
    """Execution settings shared by every result produced in a run."""

    # Cross-validation setup
    cv_folds: int = 5
    cv_strategy: str = "kfold"
    random_state: Optional[int] = 42
    # Metric name used for scoring (looked up via get_metric_info)
    metric: str = "r2"
    # Output persistence switches
    save_predictions: bool = True
    save_models: bool = True
@dataclass
class RunSummary:
    """Aggregate counters and best result collected after a run executes."""

    # Result counters
    total_results: int = 0
    completed_results: int = 0
    failed_results: int = 0
    # Description of the best result found, if any
    best_result: Optional[Dict[str, Any]] = None
@dataclass
class Run:
    """
    Represents a complete experiment session.

    A Run combines pipeline templates with datasets and generates results
    for every combination of expanded pipeline configurations and datasets.

    Attributes:
        id: Unique identifier for the run (first 8 characters of a UUID4 by default)
        name: Human-readable name
        templates: List of pipeline templates
        datasets: List of datasets
        status: Current execution status
        config: Run configuration
        created_at: Creation timestamp (ISO 8601, UTC by default)
        started_at: Execution start timestamp; set on the first transition to
            RUNNING and kept across pause/resume
        completed_at: Completion timestamp; set on transition to COMPLETED,
            FAILED or CANCELLED and cleared when the run becomes active again
        summary: Post-execution summary
        checkpoints: Completed-result records appended by ``add_checkpoint``
    """

    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
    name: str = ""
    templates: List[TemplateInfo] = field(default_factory=list)
    datasets: List[DatasetInfo] = field(default_factory=list)
    status: RunStatus = RunStatus.QUEUED
    config: RunConfig = field(default_factory=RunConfig)
    created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    summary: RunSummary = field(default_factory=RunSummary)
    checkpoints: List[Dict[str, Any]] = field(default_factory=list)

    @property
    def total_pipeline_configs(self) -> int:
        """Total number of expanded pipeline configurations."""
        return sum(t.expansion_count for t in self.templates)

    @property
    def total_results_expected(self) -> int:
        """Expected number of results (configs × datasets)."""
        return self.total_pipeline_configs * len(self.datasets)

    def can_transition_to(self, new_status: RunStatus) -> bool:
        """Check if transition to new status is valid."""
        return new_status in VALID_TRANSITIONS.get(self.status, [])

    def transition_to(self, new_status: RunStatus) -> None:
        """
        Transition to a new status and update the lifecycle timestamps.

        Entering COMPLETED, FAILED or CANCELLED stamps ``completed_at``.
        Entering any other state clears ``completed_at``, because an active
        run (e.g. one retried via FAILED -> QUEUED) is not finished.
        ``started_at`` is stamped only on the first move to RUNNING.

        Raises:
            ValueError: If transition is not valid
        """
        if not self.can_transition_to(new_status):
            raise ValueError(
                f"Invalid transition from {self.status.value} to {new_status.value}. "
                f"Valid transitions: {[s.value for s in VALID_TRANSITIONS.get(self.status, [])]}"
            )
        self.status = new_status
        now = datetime.now(timezone.utc).isoformat()
        if new_status in (RunStatus.COMPLETED, RunStatus.FAILED, RunStatus.CANCELLED):
            self.completed_at = now
        else:
            # Without this, a retried run would keep the timestamp of its
            # previous failure while queued/running again.
            self.completed_at = None
            if new_status == RunStatus.RUNNING and self.started_at is None:
                self.started_at = now

    def add_checkpoint(self, result_id: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        """
        Record a completed result as a checkpoint.

        Keys in ``metadata`` are merged into the checkpoint record after
        ``result_id`` and ``completed_at``, so they override those keys if
        they share a name.
        """
        checkpoint = {
            "result_id": result_id,
            "completed_at": datetime.now(timezone.utc).isoformat(),
        }
        if metadata:
            checkpoint.update(metadata)
        self.checkpoints.append(checkpoint)

    def to_dict(self) -> Dict[str, Any]:
        """Convert run to dictionary for serialization."""
        return {
            "id": self.id,
            "name": self.name,
            "templates": [
                {
                    "id": t.id,
                    "name": t.name,
                    "file_path": t.file_path,
                    "expansion_count": t.expansion_count,
                    "description": t.description,
                }
                for t in self.templates
            ],
            "datasets": [
                {
                    "name": d.name,
                    "path": d.path,
                    "hash": d.hash,
                    "file_size": d.file_size,
                    "n_samples": d.n_samples,
                    "n_features": d.n_features,
                    "task_type": d.task_type,
                    "y_columns": d.y_columns,
                    "y_stats": d.y_stats,
                    "wavelength_range": d.wavelength_range,
                    "wavelength_unit": d.wavelength_unit,
                    "metadata": d.metadata,
                    "version": d.version,
                }
                for d in self.datasets
            ],
            "status": self.status.value,
            "config": {
                "cv_folds": self.config.cv_folds,
                "cv_strategy": self.config.cv_strategy,
                "random_state": self.config.random_state,
                "metric": self.config.metric,
                "save_predictions": self.config.save_predictions,
                "save_models": self.config.save_models,
            },
            "created_at": self.created_at,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "total_pipeline_configs": self.total_pipeline_configs,
            "summary": {
                "total_results": self.summary.total_results,
                "completed_results": self.summary.completed_results,
                "failed_results": self.summary.failed_results,
                "best_result": self.summary.best_result,
            },
            # Copy the list so appending to the serialized form cannot add
            # checkpoints to this run.
            "checkpoints": list(self.checkpoints),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Run":
        """
        Create run from dictionary.

        Missing keys fall back to defaults. Sections stored as ``None``
        (e.g. JSON ``null``) are treated as empty rather than raising.
        """
        run = cls(
            id=data.get("id", str(uuid.uuid4())[:8]),
            name=data.get("name", ""),
            status=RunStatus(data.get("status", "queued")),
            created_at=data.get("created_at", datetime.now(timezone.utc).isoformat()),
            started_at=data.get("started_at"),
            completed_at=data.get("completed_at"),
            # Copy so later add_checkpoint() calls do not mutate the caller's data.
            checkpoints=list(data.get("checkpoints") or []),
        )

        # Parse templates
        for t in data.get("templates") or []:
            run.templates.append(TemplateInfo(
                id=t.get("id", ""),
                name=t.get("name", ""),
                file_path=t.get("file_path"),
                expansion_count=t.get("expansion_count", 1),
                description=t.get("description"),
            ))

        # Parse datasets
        for d in data.get("datasets") or []:
            run.datasets.append(DatasetInfo(
                name=d.get("name", ""),
                path=d.get("path", ""),
                hash=d.get("hash"),
                file_size=d.get("file_size"),
                n_samples=d.get("n_samples"),
                n_features=d.get("n_features"),
                task_type=d.get("task_type"),
                y_columns=d.get("y_columns"),
                y_stats=d.get("y_stats"),
                wavelength_range=d.get("wavelength_range"),
                wavelength_unit=d.get("wavelength_unit"),
                metadata=d.get("metadata"),
                version=d.get("version"),
            ))

        # Parse config
        config_data = data.get("config") or {}
        run.config = RunConfig(
            cv_folds=config_data.get("cv_folds", 5),
            cv_strategy=config_data.get("cv_strategy", "kfold"),
            random_state=config_data.get("random_state", 42),
            metric=config_data.get("metric", "r2"),
            save_predictions=config_data.get("save_predictions", True),
            save_models=config_data.get("save_models", True),
        )

        # Parse summary
        summary_data = data.get("summary") or {}
        run.summary = RunSummary(
            total_results=summary_data.get("total_results", 0),
            completed_results=summary_data.get("completed_results", 0),
            failed_results=summary_data.get("failed_results", 0),
            best_result=summary_data.get("best_result"),
        )

        return run
def generate_run_id(name: str = "") -> str:
    """
    Generate a unique run ID.

    Format: ``YYYY-MM-DD_<Name>_<hash>``, or ``YYYY-MM-DD_<hash>`` when no
    name is given. The date is taken in UTC so it agrees with the UTC
    timestamps recorded on ``Run`` objects (``created_at`` etc.).

    Args:
        name: Optional descriptive name. Any character that is not
            alphanumeric, ``-`` or ``_`` is replaced with ``_``, and the
            result is truncated to 30 characters.

    Returns:
        Unique run ID string; ``<hash>`` is the first 6 characters of a
        random UUID4.
    """
    date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    hash_str = str(uuid.uuid4())[:6]
    if not name:
        return f"{date_str}_{hash_str}"
    # Sanitize name for use in ID (e.g. as a directory name), then limit length.
    safe_name = "".join(c if c.isalnum() or c in "-_" else "_" for c in name)[:30]
    return f"{date_str}_{safe_name}_{hash_str}"