# Source code for nirs4all.visualization.chart_utils.predictions_adapter

"""
PredictionsAdapter - Adapter for Predictions API with optimized data access.

Wraps the refactored Predictions API to provide convenient methods for charts.
"""
from typing import List, Optional
from nirs4all.data.predictions import PredictionResultsList


class PredictionsAdapter:
    """Adapter for Predictions API with optimized data access.

    Wraps the refactored Predictions API to provide convenient methods for
    charts. Leverages ``predictions.top()``, lazy loading, and structured
    results.

    Key Optimizations:
        - Uses ``predictions.top()`` for efficient ranking.
        - Supports lazy loading (``load_arrays=False``) for metadata-only
          queries.
        - Works with PredictionResult/PredictionResultsList classes.
        - Avoids redundant metric calculations.

    Attributes:
        predictions: Predictions object instance.
    """

    # Metric names (lower-case) for which a larger value means a better model.
    # Any metric not listed here is treated as lower-is-better by
    # ``is_higher_better``.
    _HIGHER_IS_BETTER_METRICS = frozenset({
        'r2', 'accuracy', 'balanced_accuracy', 'f1', 'precision', 'recall',
        'auc', 'roc_auc', 'explained_variance', 'rpd', 'rpiq',
    })

    def __init__(self, predictions):
        """Initialize adapter with predictions object.

        Args:
            predictions: Predictions object instance. The adapter calls its
                ``top()`` method and reads its ``num_predictions`` attribute.
        """
        self.predictions = predictions

    def get_top_models(
        self,
        n: int,
        rank_metric: str,
        rank_partition: str = 'val',
        ascending: Optional[bool] = None,
        load_arrays: bool = True,
        **filters
    ) -> 'PredictionResultsList':
        """Get top N models using ``predictions.top()``.

        Args:
            n: Number of top models to retrieve.
            rank_metric: Metric to rank by.
            rank_partition: Partition to rank on (default: 'val').
            ascending: Sort order. ``None`` (default) derives it from
                ``is_higher_better(rank_metric)``: ascending for
                lower-is-better metrics, descending otherwise.
            load_arrays: Whether to load prediction arrays (default: True).
            **filters: Additional filters forwarded to ``top()``
                (dataset_name, model_name, etc.).

        Returns:
            PredictionResultsList of top N models.
        """
        if ascending is None:
            ascending = not self.is_higher_better(rank_metric)
        return self.predictions.top(
            n=n,
            rank_metric=rank_metric,
            rank_partition=rank_partition,
            ascending=ascending,
            load_arrays=load_arrays,
            **filters
        )

    def get_all_predictions_metadata(
        self,
        rank_metric: str = 'rmse',
        rank_partition: str = 'test',
        ascending: Optional[bool] = None,
        **filters
    ) -> 'PredictionResultsList':
        """Get all predictions matching filters (metadata only, fast).

        Args:
            rank_metric: Metric for sorting (default: 'rmse').
            rank_partition: Partition for sorting (default: 'test').
            ascending: Sort order. ``None`` (default) derives it from
                ``is_higher_better(rank_metric)``, exactly like
                ``get_top_models``, so higher-is-better metrics (e.g. 'r2')
                are not sorted worst-first.
            **filters: Filters forwarded to ``top()`` (dataset_name,
                model_name, etc.).

        Returns:
            PredictionResultsList with all matching predictions (no arrays
            loaded).
        """
        if ascending is None:
            ascending = not self.is_higher_better(rank_metric)
        return self.predictions.top(
            # Request every stored prediction; filters may reduce the count.
            n=self.predictions.num_predictions,
            rank_metric=rank_metric,
            rank_partition=rank_partition,
            ascending=ascending,
            load_arrays=False,
            **filters
        )

    def extract_metric_values(
        self,
        predictions_list: 'PredictionResultsList',
        metric: str,
        partition: str = 'test'
    ) -> List[float]:
        """Extract metric values from prediction results.

        Args:
            predictions_list: Iterable of dict-like prediction results.
            metric: Metric name to extract.
            partition: Partition to extract from (default: 'test'); selects
                the ``'<partition>_score'`` field.

        Returns:
            List of metric values, one per prediction for which a usable value
            was found. Predictions without a value for ``metric``, or whose
            value cannot be converted to float, are skipped, so the result may
            be shorter than ``predictions_list``.
        """
        score_field = f'{partition}_score'
        values = []
        for pred in predictions_list:
            try:
                value = self._lookup_metric_value(pred, metric, score_field)
            except (KeyError, TypeError, ValueError):
                # Malformed entry (e.g. non-numeric score): skip it.
                continue
            if value is not None:
                values.append(value)
        return values

    @staticmethod
    def _lookup_metric_value(pred, metric: str, score_field: str) -> Optional[float]:
        """Return the value of ``metric`` for one prediction, or None if absent.

        Lookup order (``None`` entries are treated as missing):
            1. ``pred[score_field]`` (e.g. ``'test_score'``), but only when the
               prediction's recorded ``'metric'`` equals ``metric``
               (case-insensitive) or no metric is recorded. The partition score
               is assumed to hold the value of the recorded metric
               (NOTE(review): confirm against the Predictions schema), so it
               must not be reported under a different metric name.
            2. ``pred['metrics'][metric]``.
            3. ``pred[metric]``.

        Raises:
            TypeError, ValueError: if a found value cannot be converted to
                float.
        """
        recorded = pred['metric'] if 'metric' in pred else None
        metric_matches = recorded is None or str(recorded).lower() == metric.lower()
        if metric_matches and score_field in pred and pred[score_field] is not None:
            return float(pred[score_field])

        metrics = pred['metrics'] if 'metrics' in pred else None
        if metrics and metric in metrics and metrics[metric] is not None:
            return float(metrics[metric])

        if metric in pred and pred[metric] is not None:
            return float(pred[metric])
        return None

    @staticmethod
    def is_higher_better(metric: str) -> bool:
        """Check if metric is higher-is-better.

        Args:
            metric: Metric name (compared case-insensitively).

        Returns:
            True if the metric is in the known higher-is-better set, False
            otherwise (including empty or ``None`` names).
        """
        if not metric:
            return False
        return metric.lower() in PredictionsAdapter._HIGHER_IS_BETTER_METRICS