"""
Result classes for nirs4all API.
These dataclasses wrap the outputs from pipeline execution, prediction,
and explanation operations, providing convenient accessor methods.
Classes:
RunResult: Result from nirs4all.run()
PredictResult: Result from nirs4all.predict()
ExplainResult: Result from nirs4all.explain()
Phase 1 Implementation (v0.6.0):
- RunResult: Full implementation with best, best_score, top(), export()
- PredictResult: Full implementation with values, to_dataframe()
- ExplainResult: Full implementation with values, feature attributions
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
from pathlib import Path
import numpy as np
if TYPE_CHECKING:
from nirs4all.pipeline import PipelineRunner
from nirs4all.data.predictions import Predictions
@dataclass
class RunResult:
    """Result from nirs4all.run().

    Provides convenient access to predictions, best model, and artifacts.
    Wraps the raw (predictions, per_dataset) tuple returned by PipelineRunner.run().

    Attributes:
        predictions: Predictions object containing all pipeline results.
        per_dataset: Dictionary with per-dataset execution details.

    Properties:
        best: Best prediction entry by default ranking.
        best_score: Best model's primary test score.
        best_rmse: Best model's RMSE (regression).
        best_r2: Best model's R² (regression).
        best_accuracy: Best model's accuracy (classification).
        artifacts_path: Path to run artifacts directory.
        num_predictions: Total number of predictions stored.

    Methods:
        top(n): Get top N predictions by ranking.
        export(path): Export best model to .n4a bundle.
        export_model(path): Export only the model artifact.
        filter(**kwargs): Filter predictions by criteria.
        get_datasets(): Get list of unique dataset names.
        get_models(): Get list of unique model names.
        summary(): Multi-line human-readable summary.
        validate(): Check for empty results and NaN metrics.

    Example:
        >>> result = nirs4all.run(pipeline, dataset)
        >>> print(f"Best RMSE: {result.best_rmse:.4f}")
        >>> print(f"Best R²: {result.best_r2:.4f}")
        >>> result.export("exports/best_model.n4a")
    """

    predictions: "Predictions"
    per_dataset: Dict[str, Any]
    _runner: Optional["PipelineRunner"] = field(default=None, repr=False)

    # --- Internal helpers ---

    @staticmethod
    def _is_nan(value: Any) -> bool:
        """Return True only for a scalar numeric value that is NaN.

        Non-numeric values (None, strings, containers) return False instead
        of raising, so one malformed metric entry cannot abort validation.
        NumPy scalars (e.g. ``np.float32``) are handled as well.
        """
        try:
            return bool(np.ndim(value) == 0 and np.isnan(value))
        except TypeError:
            return False

    @staticmethod
    def _test_metric(entry: Dict[str, Any], name: str) -> Optional[Any]:
        """Return ``entry['scores']['test'][name]``, or None if any level is missing.

        Non-dict values at either nesting level are treated as missing
        rather than raising.
        """
        scores = entry.get('scores')
        if not isinstance(scores, dict):
            return None
        test_scores = scores.get('test')
        if not isinstance(test_scores, dict):
            return None
        return test_scores.get(name)

    @classmethod
    def _has_nan_metric(cls, pred: Dict[str, Any]) -> bool:
        """Check whether a prediction entry carries any NaN metric.

        Looks at common top-level metric keys, every value in each
        partition of the ``scores`` dict, and finally ``test_score``.
        """
        if any(cls._is_nan(pred.get(key)) for key in ('rmse', 'r2', 'accuracy', 'mse', 'mae')):
            return True
        scores = pred.get('scores', {})
        if isinstance(scores, dict):
            for partition_scores in scores.values():
                if isinstance(partition_scores, dict) and any(
                    cls._is_nan(val) for val in partition_scores.values()
                ):
                    return True
        return cls._is_nan(pred.get('test_score'))

    # --- Primary accessors ---

    @property
    def best(self) -> Dict[str, Any]:
        """Get best prediction entry by default ranking.

        Returns:
            Dictionary containing best model's metrics, name, and configuration.
            Empty dict if no predictions available.
        """
        top = self.predictions.top(n=1)
        return top[0] if top else {}

    @property
    def best_score(self) -> float:
        """Get best model's primary test score.

        Returns:
            The test_score value from the best prediction, or NaN if it is
            missing or None (so formatting with ``:.4f`` never fails).
        """
        score = self.best.get('test_score')
        return float('nan') if score is None else score

    @property
    def best_rmse(self) -> float:
        """Get best model's RMSE score.

        Looks for 'rmse' in ``scores['test']`` first. Otherwise, if the
        record's metric is 'rmse', returns test_score. If the metric is
        'mse', returns the square root of test_score.

        Returns:
            RMSE value or NaN if unavailable.
        """
        best = self.best
        if not best:
            return float('nan')
        rmse = self._test_metric(best, 'rmse')
        if rmse is not None:
            return rmse
        test_score = best.get('test_score')
        if test_score is None:
            return float('nan')
        metric = best.get('metric', '')
        if metric == 'rmse':
            return test_score
        if metric == 'mse':
            # test_score holds the MSE here; RMSE is its square root.
            # A NaN or negative MSE yields NaN instead of a bogus value.
            return float(np.sqrt(test_score)) if test_score >= 0 else float('nan')
        return float('nan')

    @property
    def best_r2(self) -> float:
        """Get best model's R² score.

        Looks for 'r2' in ``scores['test']``. Otherwise, falls back to
        test_score when the record's metric is 'r2'.

        Returns:
            R² value or NaN if unavailable.
        """
        best = self.best
        if not best:
            return float('nan')
        r2 = self._test_metric(best, 'r2')
        if r2 is not None:
            return r2
        if best.get('metric', '') == 'r2':
            test_score = best.get('test_score')
            if test_score is not None:
                return test_score
        return float('nan')

    @property
    def best_accuracy(self) -> float:
        """Get best model's accuracy score (for classification).

        Looks for 'accuracy' in ``scores['test']``. Otherwise, falls back to
        test_score when the record's metric is 'accuracy'.

        Returns:
            Accuracy value or NaN if unavailable.
        """
        best = self.best
        if not best:
            return float('nan')
        accuracy = self._test_metric(best, 'accuracy')
        if accuracy is not None:
            return accuracy
        if best.get('metric', '') == 'accuracy':
            test_score = best.get('test_score')
            if test_score is not None:
                return test_score
        return float('nan')

    # --- Metadata accessors ---

    @property
    def artifacts_path(self) -> Optional[Path]:
        """Get path to run artifacts directory.

        Returns:
            The runner's ``current_run_dir``, or None if no runner is attached
            or it has no such attribute.
        """
        if self._runner and hasattr(self._runner, 'current_run_dir'):
            return self._runner.current_run_dir
        return None

    @property
    def num_predictions(self) -> int:
        """Get total number of predictions stored.

        Returns:
            Number of prediction entries.
        """
        return self.predictions.num_predictions

    # --- Query methods ---

    def top(self, n: int = 5, **kwargs) -> Union[List[Dict[str, Any]], Dict[tuple, List[Dict[str, Any]]]]:
        """Get top N predictions by ranking.

        Args:
            n: Number of top predictions to return. When group_by is used,
                this means top N **per group** (e.g., top 3 per dataset).
            **kwargs: Additional arguments passed to predictions.top().
                Supported kwargs include:
                - rank_metric: Metric to rank by (default: uses record's metric)
                - rank_partition: Partition to rank on (default: "val")
                - display_partition: Partition for display metrics (default: "test")
                - aggregate_partitions: If True, include train/val/test data
                - ascending: Sort order (None = infer from metric)
                - group_by: Group predictions by column(s). Returns top N per group.
                  Each result includes 'group_key' for easy filtering.
                - return_grouped: If True with group_by, return dict of group->results
                  instead of flat list. Default: False.

        Returns:
            - If return_grouped=False (default): List of prediction dicts,
              ranked by score. With group_by, returns top N per group as flat list.
            - If return_grouped=True: Dict mapping group keys to lists of predictions.

        Examples:
            >>> # Top 5 overall
            >>> result.top(5)
            >>>
            >>> # Top 3 per dataset (flat list)
            >>> top_per_ds = result.top(3, group_by='dataset_name')
            >>> ds1 = [r for r in top_per_ds if r['group_key'] == ('my_dataset',)]
            >>>
            >>> # Top 3 per dataset (grouped dict)
            >>> grouped = result.top(3, group_by='dataset_name', return_grouped=True)
            >>> for key, results in grouped.items():
            ...     print(f"{key}: {len(results)} results")
            >>>
            >>> # Multi-column grouping: top 2 per (dataset, model) combination
            >>> top_per_combo = result.top(2, group_by=['dataset_name', 'model_name'])
            >>> # Group keys are tuples: ('wheat', 'PLSRegression'), ('corn', 'RandomForest')
            >>> for r in top_per_combo:
            ...     dataset, model = r['group_key']
            ...     print(f"{dataset}/{model}: {r['test_score']:.4f}")
        """
        return self.predictions.top(n=n, **kwargs)

    def filter(self, **kwargs) -> List[Dict[str, Any]]:
        """Filter predictions by criteria.

        Args:
            **kwargs: Filter criteria passed to predictions.filter_predictions().
                Supported kwargs include:
                - dataset_name: Filter by dataset name
                - model_name: Filter by model name
                - partition: Filter by partition ('train', 'val', 'test')
                - fold_id: Filter by fold ID
                - step_idx: Filter by pipeline step index
                - branch_id: Filter by branch ID
                - load_arrays: If True, load actual arrays (default: True)

        Returns:
            List of matching prediction dictionaries.
        """
        return self.predictions.filter_predictions(**kwargs)

    def get_datasets(self) -> List[str]:
        """Get list of unique dataset names.

        Returns:
            List of dataset names in predictions.
        """
        return self.predictions.get_datasets()

    def get_models(self) -> List[str]:
        """Get list of unique model names.

        Returns:
            List of model names in predictions.
        """
        return self.predictions.get_models()

    # --- Export methods ---

    def export(
        self,
        output_path: Union[str, Path],
        format: str = "n4a",
        source: Optional[Dict[str, Any]] = None
    ) -> Path:
        """Export a model to bundle.

        Args:
            output_path: Path for the exported bundle file.
            format: Export format ('n4a' or 'n4a.py').
            source: Prediction dict to export. If None, exports best model.

        Returns:
            Path to the exported bundle file.

        Raises:
            RuntimeError: If runner reference is not available.
            ValueError: If no predictions available and source not provided.
        """
        if self._runner is None:
            raise RuntimeError("Cannot export: runner reference not available")
        if source is None:
            source = self.best
        if not source:
            raise ValueError("No predictions available to export")
        return self._runner.export(
            source=source,
            output_path=output_path,
            format=format
        )

    def export_model(
        self,
        output_path: Union[str, Path],
        source: Optional[Dict[str, Any]] = None,
        format: Optional[str] = None,
        fold: Optional[int] = None
    ) -> Path:
        """Export only the model artifact (lightweight).

        Unlike export() which creates a full bundle, this exports just the model.

        Args:
            output_path: Path for the output model file.
            source: Prediction dict to export. If None, exports best model.
            format: Model format (inferred from extension if None).
            fold: Fold index to export (default: fold 0).

        Returns:
            Path to the exported model file.

        Raises:
            RuntimeError: If runner reference is not available.
            ValueError: If no predictions available and source not provided.
        """
        if self._runner is None:
            raise RuntimeError("Cannot export: runner reference not available")
        if source is None:
            source = self.best
        if not source:
            raise ValueError("No predictions available to export")
        return self._runner.export_model(
            source=source,
            output_path=output_path,
            format=format,
            fold=fold
        )

    # --- Utility methods ---

    def summary(self) -> str:
        """Get a summary string of the run result.

        Returns:
            Multi-line summary string with key metrics.
        """
        lines = [f"RunResult: {self.num_predictions} predictions"]
        artifacts = self.artifacts_path
        if artifacts:
            lines.append(f" Artifacts: {artifacts}")
        datasets = self.get_datasets()
        if datasets:
            lines.append(f" Datasets: {', '.join(map(str, datasets))}")
        models = self.get_models()
        if models:
            # Show at most five model names, then a count of the rest.
            shown = ', '.join(map(str, models[:5]))
            more = f" (+{len(models) - 5} more)" if len(models) > 5 else ""
            lines.append(f" Models: {shown}{more}")
        best = self.best
        if best:
            lines.append(f" Best: {best.get('model_name', 'unknown')}")
            lines.append(f" test_score: {self.best_score:.4f}")
            rmse = self.best_rmse
            if not np.isnan(rmse):
                lines.append(f" rmse: {rmse:.4f}")
            r2 = self.best_r2
            if not np.isnan(r2):
                lines.append(f" r2: {r2:.4f}")
        return "\n".join(lines)

    def __repr__(self) -> str:
        """String representation."""
        return f"RunResult(predictions={self.num_predictions}, best_score={self.best_score:.4f})"

    def __str__(self) -> str:
        """User-friendly string representation."""
        return self.summary()

    def validate(
        self,
        check_nan_metrics: bool = True,
        check_empty: bool = True,
        raise_on_failure: bool = True,
        nan_threshold: float = 0.0
    ) -> Dict[str, Any]:
        """Validate the run result for common issues.

        Checks for NaN values in metrics, empty predictions, and other issues
        that might indicate problems with the pipeline execution.

        Args:
            check_nan_metrics: If True, check for NaN values in metrics.
            check_empty: If True, check for empty predictions.
            raise_on_failure: If True, raise ValueError on validation failure.
            nan_threshold: Maximum allowed ratio of predictions with NaN
                metrics (0.0 = none allowed). NaN predictions are reported
                as issues only when this ratio is exceeded.

        Returns:
            Dictionary with validation results:
            - valid: True if all checks passed.
            - issues: List of issue descriptions.
            - nan_count: Number of predictions with NaN metrics.
            - total_count: Total number of predictions.

        Raises:
            ValueError: If raise_on_failure=True and validation fails.

        Example:
            >>> result = nirs4all.run(pipeline, dataset)
            >>> result.validate()  # Raises if issues found
            >>> # Or check without raising
            >>> report = result.validate(raise_on_failure=False)
            >>> if not report['valid']:
            ...     print(f"Issues: {report['issues']}")
        """
        issues: List[str] = []
        nan_count = 0
        total_count = self.num_predictions

        if check_empty and total_count == 0:
            issues.append("No predictions found")

        if check_nan_metrics and total_count > 0:
            nan_models = [
                pred.get('model_name', 'unknown')
                for pred in self.predictions.top(n=total_count)
                if self._has_nan_metric(pred)
            ]
            nan_count = len(nan_models)
            nan_ratio = nan_count / total_count
            # NaNs within the tolerated ratio must not invalidate the result,
            # so per-model details are only reported once the threshold is exceeded.
            if nan_ratio > nan_threshold:
                issues.extend(
                    f"NaN metrics found in prediction: {name}" for name in nan_models[:5]
                )
                if nan_count > 5:
                    issues.append(f"... and {nan_count - 5} more predictions with NaN metrics")
                issues.append(
                    f"NaN ratio ({nan_ratio:.1%}) exceeds threshold ({nan_threshold:.1%})"
                )

        valid = not issues
        report = {
            'valid': valid,
            'issues': issues,
            'nan_count': nan_count,
            'total_count': total_count,
        }
        if raise_on_failure and not valid:
            raise ValueError(
                "RunResult validation failed:\n"
                + "\n".join(f" - {issue}" for issue in issues)
            )
        return report
@dataclass
class PredictResult:
    """Result from nirs4all.predict().

    Wraps prediction outputs with convenient accessors and conversion methods.

    Attributes:
        y_pred: Predicted values array (n_samples,) or (n_samples, n_outputs).
        metadata: Additional prediction metadata (uncertainty, timing, etc.).
        sample_indices: Optional indices of predicted samples.
        model_name: Name of the model used for prediction.
        preprocessing_steps: List of preprocessing steps applied.

    Properties:
        values: Alias for y_pred (for consistency).
        shape: Shape of prediction array.
        is_multioutput: True if predictions have multiple outputs.

    Methods:
        to_numpy(): Get predictions as numpy array.
        to_list(): Get predictions as Python list.
        to_dataframe(): Get predictions as pandas DataFrame.
        flatten(): Get flattened 1D predictions.

    Example:
        >>> result = nirs4all.predict(model, X_new)
        >>> print(f"Predictions shape: {result.shape}")
        >>> df = result.to_dataframe()
    """

    y_pred: np.ndarray
    metadata: Dict[str, Any] = field(default_factory=dict)
    sample_indices: Optional[np.ndarray] = None
    model_name: str = ""
    preprocessing_steps: List[str] = field(default_factory=list)

    def __post_init__(self):
        """Coerce a non-None, non-ndarray y_pred (e.g. a list) to a numpy array."""
        if self.y_pred is not None and not isinstance(self.y_pred, np.ndarray):
            self.y_pred = np.asarray(self.y_pred)

    @property
    def values(self) -> np.ndarray:
        """Get prediction values (alias for y_pred)."""
        return self.y_pred

    @property
    def shape(self) -> tuple:
        """Get shape of prediction array; ``(0,)`` when y_pred is None."""
        if self.y_pred is None:
            return (0,)
        return self.y_pred.shape

    @property
    def is_multioutput(self) -> bool:
        """Check if predictions have multiple outputs (second dimension > 1)."""
        return len(self.shape) > 1 and self.shape[1] > 1

    def __len__(self) -> int:
        """Return number of predictions."""
        if self.y_pred is None:
            return 0
        # A 0-d array (single scalar prediction) has no len(); count it as one.
        if self.y_pred.ndim == 0:
            return 1
        return len(self.y_pred)

    def to_numpy(self) -> np.ndarray:
        """Get predictions as numpy array.

        Returns:
            Numpy array of predictions (None if y_pred is None).
        """
        return self.y_pred

    def to_list(self) -> List[float]:
        """Get predictions as Python list.

        Returns:
            List of prediction values (flattened if 2D); empty if y_pred is None.
        """
        if self.y_pred is None:
            return []
        return self.y_pred.flatten().tolist()

    def to_dataframe(self, include_indices: bool = True):
        """Get predictions as pandas DataFrame.

        Multi-output predictions become one ``y_pred_<i>`` column per output.
        Any dimensions beyond the second are collapsed into extra columns.
        Single-output predictions become one ``y_pred`` column.

        Args:
            include_indices: If True and sample_indices available, include as column.

        Returns:
            pandas DataFrame with predictions (no rows if y_pred is None).

        Raises:
            ImportError: If pandas is not available.
        """
        try:
            import pandas as pd
        except ImportError as err:
            raise ImportError("pandas is required for to_dataframe()") from err
        data: Dict[str, Any] = {}
        if include_indices and self.sample_indices is not None:
            data['sample_index'] = self.sample_indices
        if self.is_multioutput:
            # Collapse trailing dimensions so that every column is 1-D.
            y2d = self.y_pred.reshape(self.shape[0], -1)
            for i in range(y2d.shape[1]):
                data[f'y_pred_{i}'] = y2d[:, i]
        else:
            # flatten() tolerates y_pred being None.
            data['y_pred'] = self.flatten()
        return pd.DataFrame(data)

    def flatten(self) -> np.ndarray:
        """Get flattened 1D predictions.

        Returns:
            1D numpy array of predictions (empty if y_pred is None).
        """
        if self.y_pred is None:
            return np.array([])
        return self.y_pred.flatten()

    def __repr__(self) -> str:
        """String representation."""
        return f"PredictResult(shape={self.shape}, model='{self.model_name}')"

    def __str__(self) -> str:
        """User-friendly string representation."""
        lines = [f"PredictResult: {len(self)} predictions"]
        if self.model_name:
            lines.append(f" Model: {self.model_name}")
        if self.preprocessing_steps:
            lines.append(f" Preprocessing: {' -> '.join(self.preprocessing_steps)}")
        lines.append(f" Shape: {self.shape}")
        return "\n".join(lines)
@dataclass
class ExplainResult:
    """Result from nirs4all.explain().

    Wraps SHAP explanation outputs with visualization helpers and accessors.

    Attributes:
        shap_values: SHAP values array or Explanation object.
        feature_names: Names/labels of features explained.
        base_value: Expected value (baseline prediction).
        visualizations: Paths to generated visualization files.
        explainer_type: Type of SHAP explainer used.
        model_name: Name of the explained model.
        n_samples: Number of samples explained.

    Properties:
        values: Raw SHAP values array.
        shape: Shape of SHAP values array.
        mean_abs_shap: Mean absolute SHAP values per feature.
        top_features: Feature names sorted by importance.

    Methods:
        get_feature_importance(): Get feature importance ranking.
        get_sample_explanation(idx): Get explanation for a single sample.
        to_dataframe(): Get SHAP values as DataFrame.

    Example:
        >>> result = nirs4all.explain(model, X_test)
        >>> print(f"Top features: {result.top_features[:5]}")
        >>> importance = result.get_feature_importance()
    """

    shap_values: Any  # shap.Explanation or np.ndarray
    feature_names: Optional[List[str]] = None
    base_value: Optional[Union[float, np.ndarray]] = None
    visualizations: Dict[str, Path] = field(default_factory=dict)
    explainer_type: str = "auto"
    model_name: str = ""
    n_samples: int = 0

    def __post_init__(self):
        """Fill feature_names, base_value and n_samples from shap_values if not given.

        An object with a ``values`` attribute is treated as an Explanation-like
        object and its metadata is copied. A plain ndarray only provides
        n_samples.
        """
        if hasattr(self.shap_values, 'values'):
            # The attribute may exist but be None, so test the value itself.
            names = getattr(self.shap_values, 'feature_names', None)
            if self.feature_names is None and names is not None:
                self.feature_names = list(names)
            if self.base_value is None and hasattr(self.shap_values, 'base_values'):
                self.base_value = self.shap_values.base_values
            if self.n_samples == 0:
                self.n_samples = self._count_samples(np.asarray(self.shap_values.values))
        elif self.n_samples == 0 and isinstance(self.shap_values, np.ndarray):
            self.n_samples = self._count_samples(self.shap_values)

    @staticmethod
    def _count_samples(vals: np.ndarray) -> int:
        """Count explained samples; a 1-D array is a single sample's explanation.

        This matches how mean_abs_shap and get_sample_explanation treat 1-D
        values.
        """
        if vals.ndim == 0:
            return 0
        if vals.ndim == 1:
            return 1 if vals.size else 0
        return int(vals.shape[0])

    @property
    def values(self) -> np.ndarray:
        """Get raw SHAP values array.

        Returns:
            Numpy array of SHAP values, typically (n_samples, n_features).
        """
        if hasattr(self.shap_values, 'values'):
            return self.shap_values.values
        return np.asarray(self.shap_values)

    @property
    def shape(self) -> tuple:
        """Get shape of SHAP values array."""
        return self.values.shape

    @property
    def mean_abs_shap(self) -> np.ndarray:
        """Get mean absolute SHAP values per feature.

        A 1-D array is treated as a single sample, so its absolute values are
        returned as-is. Otherwise the mean is taken over axis 0 (samples).
        NOTE(review): for 3-D multi-output values the result is not 1-D;
        confirm the expected layout before relying on it.

        Returns:
            Array of mean |SHAP| values, one per feature.
        """
        vals = self.values
        if vals.ndim == 1:
            return np.abs(vals)
        return np.mean(np.abs(vals), axis=0)

    @property
    def top_features(self) -> List[str]:
        """Get feature names sorted by importance (descending).

        Returns:
            List of feature names, most important first.
            Returns indices as strings if feature_names not available.
        """
        sorted_indices = np.argsort(self.mean_abs_shap)[::-1]
        if self.feature_names:
            return [self.feature_names[i] for i in sorted_indices]
        return [str(i) for i in sorted_indices]

    def get_feature_importance(
        self,
        top_n: Optional[int] = None,
        normalize: bool = False
    ) -> Dict[str, float]:
        """Get feature importance ranking.

        Args:
            top_n: If provided, return only top N features (0 returns none).
            normalize: If True, normalize values to sum to 1 (skipped when the
                total is 0).

        Returns:
            Dictionary mapping feature names to importance values, most
            important first.
        """
        importance = self.mean_abs_shap
        total = importance.sum()
        if normalize and total > 0:
            importance = importance / total
        sorted_indices = np.argsort(importance)[::-1]
        if top_n is not None:
            sorted_indices = sorted_indices[:top_n]
        return {
            (self.feature_names[idx] if self.feature_names else str(idx)): float(importance[idx])
            for idx in sorted_indices
        }

    def get_sample_explanation(
        self,
        idx: int
    ) -> Dict[str, float]:
        """Get SHAP explanation for a single sample.

        Args:
            idx: Sample index. Negative indices count from the end. A 1-D
                values array holds exactly one sample, so only index 0 (or -1)
                is valid for it.

        Returns:
            Dictionary mapping feature names to SHAP values for that sample.

        Raises:
            IndexError: If idx is out of range.
        """
        vals = self.values
        n = 1 if vals.ndim == 1 else len(vals)
        if not -n <= idx < n:
            raise IndexError(f"Sample index {idx} out of range (n_samples={n})")
        sample_shap = vals[idx] if vals.ndim > 1 else vals
        return {
            (self.feature_names[i] if self.feature_names else str(i)): float(val)
            for i, val in enumerate(sample_shap)
        }

    def to_dataframe(self, include_feature_names: bool = True):
        """Get SHAP values as pandas DataFrame.

        Args:
            include_feature_names: If True, use feature names as columns;
                otherwise columns are named ``feature_<i>``.

        Returns:
            pandas DataFrame with SHAP values (one row for 1-D values).

        Raises:
            ImportError: If pandas is not available.
        """
        try:
            import pandas as pd
        except ImportError as err:
            raise ImportError("pandas is required for to_dataframe()") from err
        vals = self.values
        if include_feature_names and self.feature_names:
            columns = self.feature_names
        else:
            columns = [f"feature_{i}" for i in range(vals.shape[-1])]
        if vals.ndim == 1:
            vals = vals.reshape(1, -1)
        return pd.DataFrame(vals, columns=columns)

    def __repr__(self) -> str:
        """String representation."""
        return f"ExplainResult(shape={self.shape}, explainer='{self.explainer_type}')"

    def __str__(self) -> str:
        """User-friendly string representation."""
        lines = [f"ExplainResult: {self.n_samples} samples explained"]
        if self.model_name:
            lines.append(f" Model: {self.model_name}")
        lines.append(f" Explainer: {self.explainer_type}")
        lines.append(f" Shape: {self.shape}")
        if self.feature_names:
            lines.append(f" Features: {len(self.feature_names)}")
        # Show the five most important features.
        top = self.top_features[:5]
        if top:
            lines.append(f" Top features: {', '.join(top)}")
        if self.visualizations:
            lines.append(f" Visualizations: {list(self.visualizations.keys())}")
        return "\n".join(lines)