nirs4all.visualization.predictions module
PredictionAnalyzer - Orchestrator for prediction analysis and visualization.
This module provides a unified interface for creating various prediction visualizations. Delegates to specialized chart classes for rendering.
Leverages the refactored Predictions API (predictions.top(), PredictionResult, etc.) for efficient data access and avoids redundant calculations.
Includes a caching layer (PredictionCache) to avoid recomputing expensive aggregations when multiple charts use the same parameters.
- class nirs4all.visualization.predictions.PredictionAnalyzer(predictions_obj: Predictions, dataset_name_override: str | None = None, config: ChartConfig | None = None, output_dir: str | None = None, cache_size: int = 50, default_aggregate: str | None = None, default_aggregate_method: str | None = None, default_aggregate_exclude_outliers: bool = False)
Bases: object
Orchestrator for prediction analysis and visualization.
Includes a caching layer (PredictionCache) to avoid recomputing expensive aggregations when multiple charts use the same parameters. The cache is keyed by (aggregate, rank_metric, rank_partition, display_partition, group_by, filters) and stores the results of predictions.top() calls.
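The caching behaviour described above can be pictured as a small least-recently-used (LRU) store keyed by the query parameters. The sketch below is an illustration only, assuming LRU eviction once cache_size entries are held and a key built from the parameters listed above; QueryCacheSketch and its methods are hypothetical names, not the PredictionCache API.

```python
from collections import OrderedDict
from typing import Any, Callable, Dict, Hashable, Tuple


class QueryCacheSketch:
    """Hypothetical LRU cache for top() results; not nirs4all's PredictionCache."""

    def __init__(self, max_size: int = 50) -> None:
        self.max_size = max_size
        self._store: OrderedDict = OrderedDict()
        self.hits = 0
        self.misses = 0

    @staticmethod
    def make_key(aggregate, rank_metric, rank_partition, display_partition,
                 group_by=None, **filters) -> Tuple:
        # A list is unhashable, so turn group_by into a tuple.
        if isinstance(group_by, list):
            group_by = tuple(group_by)
        # Sort filters so that keyword order does not produce a different key.
        # Filter values are assumed to be hashable (str, int, tuple, ...).
        return (aggregate, rank_metric, rank_partition, display_partition,
                group_by, tuple(sorted(filters.items())))

    def get_or_compute(self, key: Hashable, compute: Callable[[], Any]) -> Any:
        if key in self._store:
            self.hits += 1
            self._store.move_to_end(key)  # mark as most recently used
            return self._store[key]
        self.misses += 1
        value = compute()  # e.g. a call to predictions.top(...)
        self._store[key] = value
        if len(self._store) > self.max_size:
            self._store.popitem(last=False)  # evict the least recently used entry
        return value

    def stats(self) -> Dict[str, int]:
        return {"hits": self.hits, "misses": self.misses, "size": len(self._store)}
```

Because the sketch sorts the filters, two charts that pass the same filters in a different keyword order share a single cache entry.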
- predictions
Predictions object containing prediction data.
- dataset_name_override
Optional dataset name override for display.
- config
ChartConfig for customization across all charts.
- output_dir
Directory to save generated charts.
- cache
PredictionCache for caching aggregated results.
- default_aggregate
Default aggregation column for all visualization methods.
Example
>>> from nirs4all.data.predictions import Predictions
>>> from nirs4all.visualization.predictions import PredictionAnalyzer
>>> predictions = Predictions.load('predictions.json')
>>> analyzer = PredictionAnalyzer(predictions)
>>>
>>> # Plot top 5 models - first call computes aggregation
>>> fig = analyzer.plot_top_k(k=5, aggregate='ID')
>>>
>>> # Plot heatmap - uses cached aggregation (fast!)
>>> fig = analyzer.plot_heatmap('model_name', 'preprocessings', aggregate='ID')
>>>
>>> # Check cache stats
>>> print(analyzer.get_cache_stats())
>>>
>>> # With default aggregation from dataset config
>>> # (PipelineRunner, DatasetConfigs, pipeline and path are assumed to be
>>> # imported/defined elsewhere)
>>> runner = PipelineRunner()
>>> predictions, _ = runner.run(pipeline, DatasetConfigs(path, aggregate='sample_id'))
>>> analyzer = PredictionAnalyzer(predictions, default_aggregate=runner.last_aggregate)
>>> # All plots now use sample_id aggregation by default
>>> fig = analyzer.plot_top_k(k=5)  # Aggregated automatically
- branch_summary(metrics: List[str] | None = None, display_partition: str = 'test', aggregate: str | None = None, as_dataframe: bool = True, **filters) → DataFrame | Dict[str, Dict[str, Any]]
Generate summary statistics comparing branch performance.
Computes mean, std, min, max for each metric across branches.
- Parameters:
metrics – List of metrics to compute (default: [‘rmse’, ‘r2’] for regression, [‘balanced_accuracy’, ‘f1’] for classification).
display_partition – Partition to compute metrics from (default: ‘test’).
aggregate – If provided, aggregate predictions by this metadata column (e.g., ‘ID’) before computing statistics.
as_dataframe – If True, return pandas DataFrame. If False, return dict.
**filters – Additional filter criteria.
- Returns:
Branch summary statistics, one entry per branch, with the fields:
branch_name: Branch identifier
branch_id: Numeric branch ID
count: Number of predictions
{metric}_mean: Mean value
{metric}_std: Standard deviation
{metric}_min: Minimum value
{metric}_max: Maximum value
- Return type:
DataFrame (if as_dataframe=True) or dict (if as_dataframe=False)
Examples
>>> summary = analyzer.branch_summary(metrics=['rmse', 'r2'])
>>> print(summary.to_markdown())
>>> summary = analyzer.branch_summary(
...     metrics=['balanced_accuracy', 'f1'],
...     aggregate='ID'
... )
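The per-branch statistics listed above (count, mean, std, min, max) amount to a group-by over the prediction records. A stdlib-only sketch, assuming each record is a dict carrying a branch_name and one score per metric; branch_summary_sketch is a hypothetical helper, and since this page does not say whether nirs4all reports the sample or the population standard deviation, the sketch uses the sample one.

```python
import statistics
from collections import defaultdict
from typing import Dict, List


def branch_summary_sketch(rows: List[dict], metrics: List[str]) -> Dict[str, dict]:
    """Group per-prediction scores by branch_name and compute count/mean/std/min/max."""
    by_branch: Dict[str, List[dict]] = defaultdict(list)
    for row in rows:
        by_branch[row["branch_name"]].append(row)

    summary: Dict[str, dict] = {}
    for branch, items in by_branch.items():
        stats = {"count": len(items)}
        for metric in metrics:
            values = [item[metric] for item in items]
            stats[f"{metric}_mean"] = statistics.mean(values)
            # Sample standard deviation; a single value has no spread, so report 0.0.
            stats[f"{metric}_std"] = statistics.stdev(values) if len(values) > 1 else 0.0
            stats[f"{metric}_min"] = min(values)
            stats[f"{metric}_max"] = max(values)
        summary[branch] = stats
    return summary
```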
- clear_cache() → None
Clear all caches.
Call this if the underlying predictions data has been modified, to ensure fresh results are computed. Clears both the analyzer’s query result cache and the ranker’s aggregation and score caches.
- generate_report(output_path: str, branch_comparison: bool = True, include_diagrams: bool = True, include_tables: bool = True, metrics: List[str] | None = None, partition: str = 'test', title: str | None = None) → str
Generate HTML report with branch analysis.
Creates a comprehensive HTML report with branch comparisons, visualizations, and statistical tables.
- Parameters:
output_path – Path for the output HTML file.
branch_comparison – If True, include branch comparison section.
include_diagrams – If True, include branch diagram visualization.
include_tables – If True, include summary statistics tables.
metrics – List of metrics to include (default: [‘rmse’, ‘r2’]).
partition – Partition for metrics (default: ‘test’).
title – Report title (default: ‘Branch Comparison Report’).
- Returns:
Path to the generated HTML file.
Examples
>>> path = analyzer.generate_report(
...     'reports/branch_comparison.html',
...     branch_comparison=True,
...     metrics=['rmse', 'r2', 'mae']
... )
- get_branch_ids() → List[int]
Get list of unique branch IDs in predictions.
- Returns:
List of branch IDs (empty list if no branches)
Examples
>>> branch_ids = analyzer.get_branch_ids()
>>> print(branch_ids)  # [0, 1, 2]
- get_branches() → List[str]
Get list of unique branch names in predictions.
- Returns:
List of branch names (empty list if no branches)
Examples
>>> branches = analyzer.get_branches()
>>> print(branches)  # ['snv_pca', 'msc_detrend', 'derivative']
- get_cache_stats() → Dict[str, Any]
Get cache performance statistics.
- Returns:
Dictionary with stats for both caches:
analyzer_cache: Query result cache stats
ranker_cache: Aggregation and score cache stats
- Return type:
Dict[str, Any]
- get_cached_predictions(n: int, rank_metric: str, rank_partition: str = 'val', display_partition: str = 'test', display_metrics: List[str] | None = None, aggregate: str | None = None, aggregate_method: str | None = None, aggregate_exclude_outliers: bool | None = None, group_by: str | List[str] | None = None, aggregate_partitions: bool = True, **filters)
Get predictions with caching support.
This method wraps predictions.top() with a caching layer. Charts should call this method instead of directly calling predictions.top() to benefit from caching.
The cache key includes: aggregate, rank_metric, rank_partition, display_partition, group_by, and all filters.
- Parameters:
n – Number of top predictions to return.
rank_metric – Metric for ranking.
rank_partition – Partition for ranking (default: ‘val’).
display_partition – Partition for display (default: ‘test’).
display_metrics – List of metrics to compute for display.
aggregate – Aggregation column (e.g., ‘ID’) or None.
aggregate_method – Aggregation method (‘mean’, ‘median’, ‘vote’). If None, uses default_aggregate_method from constructor.
aggregate_exclude_outliers – If True, exclude outliers using T² before aggregation. If None, uses default_aggregate_exclude_outliers from constructor.
group_by – Grouping column(s) for deduplication.
aggregate_partitions – If True, include all partition data.
**filters – Additional filter criteria.
- Returns:
PredictionResultsList from cache or fresh computation.
Example
>>> # First call computes and caches
>>> preds = analyzer.get_cached_predictions(
...     n=5, rank_metric='rmse', aggregate='ID'
... )
>>> # Second call with same params is instant
>>> preds = analyzer.get_cached_predictions(
...     n=5, rank_metric='rmse', aggregate='ID'
... )
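The aggregate_method options (‘mean’, ‘median’, ‘vote’) combine several predictions that share the same aggregate value, for example repeated scans of one sample ID, into a single prediction. A minimal sketch, assuming mean/median apply to regression outputs and majority vote to class labels; aggregate_predictions is a hypothetical helper, and the T² outlier exclusion step is not shown.

```python
import statistics
from collections import Counter, defaultdict
from typing import Dict, Hashable, List, Sequence


def aggregate_predictions(ids: Sequence[Hashable], y_pred: Sequence,
                          method: str = "mean") -> Dict[Hashable, object]:
    """Collapse predictions that share an ID into one value per ID."""
    groups: Dict[Hashable, List] = defaultdict(list)
    for sample_id, pred in zip(ids, y_pred):
        groups[sample_id].append(pred)
    if method == "mean":
        return {k: statistics.mean(v) for k, v in groups.items()}
    if method == "median":
        return {k: statistics.median(v) for k, v in groups.items()}
    if method == "vote":
        # Majority vote; Counter.most_common keeps first-seen order on ties.
        return {k: Counter(v).most_common(1)[0][0] for k, v in groups.items()}
    raise ValueError(f"unknown aggregate_method: {method!r}")
```

The median is less sensitive than the mean to a single aberrant scan within a group, which is one reason to offer both.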
- plot_branch_boxplot(rank_metric: str | None = None, display_metric: str | None = None, display_partition: str = 'test', aggregate: str | None = None, figsize: tuple | None = None, config: ChartConfig | None = None, **filters) → Figure
Plot boxplot comparing score distributions across branches.
Creates a boxplot showing the distribution of metric values for each branch.
- Parameters:
rank_metric – Metric for ranking models (default: auto-detect).
display_metric – Metric to display (default: same as rank_metric).
display_partition – Partition to display results from (default: ‘test’).
aggregate – If provided, aggregate predictions by this metadata column.
figsize – Figure size tuple (default: auto-computed).
config – Optional ChartConfig to override defaults.
**filters – Additional filter criteria.
- Returns:
matplotlib Figure with branch comparison boxplot.
Examples
>>> fig = analyzer.plot_branch_boxplot(display_metric='rmse')
>>> fig = analyzer.plot_branch_boxplot(
...     display_metric='r2',
...     aggregate='ID'
... )
- plot_branch_comparison(rank_metric: str | None = None, display_metric: str | None = None, display_partition: str = 'test', aggregate: str | None = None, show_ci: bool = True, ci_level: float = 0.95, figsize: tuple | None = None, config: ChartConfig | None = None, **filters) → Figure
Plot bar chart comparing branch performance with confidence intervals.
Creates a grouped bar chart showing mean metric values for each branch with optional confidence intervals.
- Parameters:
rank_metric – Metric for ranking models (default: auto-detect).
display_metric – Metric to display (default: same as rank_metric).
display_partition – Partition to display results from (default: ‘test’).
aggregate – If provided, aggregate predictions by this metadata column.
show_ci – If True, show confidence intervals (default: True).
ci_level – Confidence level for intervals (default: 0.95).
figsize – Figure size tuple (default: auto-computed).
config – Optional ChartConfig to override defaults.
**filters – Additional filter criteria.
- Returns:
matplotlib Figure with branch comparison bar chart.
Examples
>>> fig = analyzer.plot_branch_comparison(display_metric='rmse')
>>> fig = analyzer.plot_branch_comparison(
...     display_metric='r2',
...     aggregate='ID',
...     show_ci=True
... )
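plot_branch_comparison draws each branch’s mean score with an optional interval at ci_level. One common way to compute such an interval is a normal approximation around the mean of the per-model scores; the sketch below assumes that approach, while nirs4all may instead use a Student-t or bootstrap interval, which would be wider for branches with few models. mean_ci is a hypothetical helper.

```python
import math
import statistics
from typing import Sequence, Tuple


def mean_ci(scores: Sequence[float], ci_level: float = 0.95) -> Tuple[float, float, float]:
    """Return (mean, lower, upper) using a normal approximation."""
    n = len(scores)
    mean = statistics.mean(scores)
    if n < 2:
        return mean, mean, mean  # no spread can be estimated from a single score
    sem = statistics.stdev(scores) / math.sqrt(n)  # standard error of the mean
    z = statistics.NormalDist().inv_cdf(0.5 + ci_level / 2)  # about 1.96 for 0.95
    return mean, mean - z * sem, mean + z * sem
```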
- plot_branch_diagram(show_metrics: bool = True, metric: str | None = None, partition: str = 'test', figsize: tuple | None = None, title: str | None = None, config: Dict[str, Any] | None = None) → Figure
Plot DAG diagram showing the branching structure of the pipeline.
Creates a visual diagram showing shared steps, branch nodes, and post-branch models in a hierarchical layout.
- Parameters:
show_metrics – If True, show metrics in branch nodes (default: True).
metric – Metric to display (default: auto-detect).
partition – Partition for metrics (default: ‘test’).
figsize – Figure size tuple (default: auto-computed).
title – Optional title for the diagram.
config – Additional configuration dict for BranchDiagram.
- Returns:
matplotlib Figure with branch DAG diagram.
Examples
>>> fig = analyzer.plot_branch_diagram(metric='rmse')
>>> fig = analyzer.plot_branch_diagram(
...     show_metrics=True,
...     metric='r2',
...     partition='val'
... )
- plot_branch_heatmap(y_var: str = 'fold_id', rank_metric: str | None = None, display_metric: str | None = None, display_partition: str = 'test', aggregate: str | None = None, config: ChartConfig | None = None, **kwargs) → Figure
Plot heatmap of branch performance across folds or another variable.
Creates a heatmap with branches on the x-axis and another variable (e.g., fold_id) on the y-axis.
- Parameters:
y_var – Variable for y-axis (default: ‘fold_id’).
rank_metric – Metric for ranking (default: auto-detect).
display_metric – Metric to display (default: same as rank_metric).
display_partition – Partition to display (default: ‘test’).
aggregate – If provided, aggregate predictions by this metadata column.
config – Optional ChartConfig to override defaults.
**kwargs – Additional parameters passed to plot_heatmap.
- Returns:
matplotlib Figure with branch heatmap.
Examples
>>> fig = analyzer.plot_branch_heatmap(display_metric='rmse')
>>> fig = analyzer.plot_branch_heatmap(
...     y_var='model_name',
...     display_metric='r2'
... )
- plot_candlestick(variable: str, display_metric: str | None = None, display_partition: str = 'test', aggregate: str | None = None, config: ChartConfig | None = None, **kwargs) → Figure
Plot candlestick chart for score distribution by variable.
- Parameters:
variable – Variable to group by (e.g., ‘model_name’, ‘preprocessings’).
display_metric – Metric to analyze (default: auto-detect from task type).
display_partition – Partition to display scores from (default: ‘test’).
aggregate – If provided, aggregate predictions by this metadata column or ‘y’. When ‘y’, groups by y_true values. When a column name (e.g., ‘ID’), groups by that metadata column. Aggregated predictions have recalculated metrics.
config – Optional ChartConfig to override analyzer’s default config for this chart.
**kwargs – Additional parameters (dataset_name, figsize, filters).
- Returns:
matplotlib Figure object.
Example
>>> fig = analyzer.plot_candlestick('model_name', display_metric='rmse')
>>> fig = analyzer.plot_candlestick('model_name', display_metric='rmse', aggregate='ID')
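When aggregate names a metadata column such as ‘ID’, predictions sharing a value are combined before scoring, so the reported metric is recomputed on the aggregated values rather than averaged from per-scan scores. The sketch below assumes mean aggregation and RMSE; rmse and aggregated_rmse are hypothetical helpers.

```python
import math
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence, Tuple


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def aggregated_rmse(ids: Sequence[Hashable], y_true: Sequence[float],
                    y_pred: Sequence[float]) -> float:
    """Average repeated measurements per ID, then recompute RMSE on the averages."""
    groups: Dict[Hashable, List[Tuple[float, float]]] = defaultdict(list)
    for sample_id, t, p in zip(ids, y_true, y_pred):
        groups[sample_id].append((t, p))
    agg_true = [sum(t for t, _ in g) / len(g) for g in groups.values()]
    agg_pred = [sum(p for _, p in g) / len(g) for g in groups.values()]
    return rmse(agg_true, agg_pred)
```

With two scans per sample whose errors have opposite signs, the per-scan RMSE is non-zero while the aggregated RMSE can drop to zero, which is why aggregated and non-aggregated scores should not be compared directly.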
- plot_confusion_matrix(k: int = 5, rank_metric: str | None = None, rank_partition: str = 'val', display_metric: str | List[str] = '', display_partition: str = 'test', show_scores: bool = True, aggregate: str | None = None, config: ChartConfig | None = None, **kwargs) → Figure | List[Figure]
Plot confusion matrices for top K classification models.
When multiple datasets are present and no dataset_name is specified, creates one figure per dataset.
- Parameters:
k – Number of top models to show (default: 5).
rank_metric – Metric for ranking (default: auto-detect from task type).
rank_partition – Partition used for ranking models (default: ‘val’).
display_metric – Metric(s) to display in titles. Can be a single string (e.g., ‘accuracy’) or a list of strings for multiple metrics (e.g., [‘balanced_accuracy’, ‘accuracy’]). Metric names are shown in abbreviated form (default: same as rank_metric).
display_partition – Partition to display confusion matrix from (default: ‘test’).
show_scores – If True, show scores in chart titles (default: True).
aggregate – If provided, aggregate predictions by this metadata column or ‘y’.
config – Optional ChartConfig to override analyzer’s default config for this chart.
**kwargs – Additional parameters (dataset_name, figsize, filters).
- Returns:
matplotlib Figure object or list of Figure objects (one per dataset).
Example
>>> fig = analyzer.plot_confusion_matrix(k=3, rank_metric='f1')
>>> fig = analyzer.plot_confusion_matrix(k=3, aggregate='ID')
>>> # Multiple metrics displayed with abbreviated names
>>> fig = analyzer.plot_confusion_matrix(
...     k=3,
...     display_metric=['balanced_accuracy', 'accuracy']
... )
- plot_heatmap(x_var: str, y_var: str, rank_metric: str | None = None, rank_partition: str = 'val', display_metric: str = '', display_partition: str = 'test', normalize: bool = False, rank_agg: str = 'best', display_agg: str = 'best', show_counts: bool = True, local_scale: bool = False, column_scale: bool = False, aggregate: str | None = None, top_k: int | None = None, sort_by_value: bool = False, sort_by: str | None = None, config: ChartConfig | None = None, **kwargs) → Figure
Plot performance heatmap across two variables.
For each (x_var, y_var) cell:
1. Rank predictions by rank_metric on rank_partition using rank_agg.
2. Display display_metric from display_partition using display_agg.
3. Normalize per dataset if requested.
4. Show counts if requested.
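The per-cell procedure above can be sketched as follows. This assumes each prediction record already carries its rank_partition score and its display_partition score, that ‘best’/‘worst’ mean min/max for a lower-is-better metric such as RMSE, and that rank_agg and display_agg are applied independently to their own score lists (the library may instead show the display score of the best-ranked prediction). heatmap_cells is a hypothetical helper; normalization (step 3) is left out.

```python
import statistics
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def _agg(values: Sequence[float], how: str, lower_is_better: bool) -> float:
    if how == "mean":
        return statistics.mean(values)
    if how == "median":
        return statistics.median(values)
    best, worst = (min, max) if lower_is_better else (max, min)
    if how == "best":
        return best(values)
    if how == "worst":
        return worst(values)
    raise ValueError(f"unknown aggregation: {how!r}")


def heatmap_cells(rows: List[dict], x_var: str, y_var: str, rank_agg: str = "best",
                  display_agg: str = "best", lower_is_better: bool = True) -> Dict[Tuple, dict]:
    """rows: dicts with x_var, y_var, 'rank_score' and 'display_score' keys."""
    cells: Dict[Tuple, List[dict]] = defaultdict(list)
    for r in rows:
        cells[(r[x_var], r[y_var])].append(r)
    out = {}
    for key, items in cells.items():
        out[key] = {
            "rank_value": _agg([r["rank_score"] for r in items], rank_agg, lower_is_better),
            "display_value": _agg([r["display_score"] for r in items], display_agg, lower_is_better),
            "count": len(items),  # what show_counts would print in the cell
        }
    return out
```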
- Parameters:
x_var – Variable for x-axis (e.g., ‘model_name’, ‘preprocessings’).
y_var – Variable for y-axis (e.g., ‘dataset_name’, ‘partition’).
rank_metric – Metric used to rank/select models (default: auto-detect from task type).
rank_partition – Partition used for ranking models (default: ‘val’).
display_metric – Metric to display in heatmap (default: same as rank_metric).
display_partition – Partition to display scores from (default: ‘test’).
normalize – If True, show normalized scores in cells; cell colors always use normalized scores (default: False).
rank_agg – Aggregation for ranking (‘best’, ‘worst’, ‘mean’, ‘median’) (default: ‘best’).
display_agg – Aggregation for display scores (‘best’, ‘worst’, ‘mean’, ‘median’) (default: ‘best’).
show_counts – Show prediction counts in cells (default: True).
local_scale – If True, the colorbar shows actual metric values; if False, it shows 0-1 normalized values (default: False).
column_scale – If True, normalize colors per column (best in column = 1.0). Automatically sets local_scale=False when enabled (default: False).
aggregate – If provided, aggregate predictions by this metadata column (e.g., ‘ID’).
top_k – If provided, show only the top K models. Selection first keeps the top-1 model of each column, then ranks models by Borda count.
sort_by_value – If True, sort the Y-axis by ranking score (best first) instead of alphabetically. Uses rank_metric on rank_partition. Deprecated: use sort_by=‘value’ instead.
sort_by – Sorting method for the Y-axis (rows). Options:
- None: Alphabetical sorting (default).
- ‘value’: Sort by ranking score on the rank_partition column.
- ‘mean’: Sort by mean score across all columns.
- ‘median’: Sort by median score across all columns.
- ‘borda’: Sort by Borda count (sum of ranks across columns).
- ‘condorcet’: Sort by pairwise wins (Copeland method).
- ‘consensus’: Sort by consensus (geometric mean of normalized ranks).
config – Optional ChartConfig to override analyzer’s default config for this chart.
**kwargs – Additional filters (dataset_name, model_name, etc.).
- Returns:
matplotlib Figure object.
Example
>>> # Rank on best val RMSE, display best test RMSE (defaults)
>>> fig = analyzer.plot_heatmap('model_name', 'dataset_name')
>>>
>>> # Rank on mean val R2, display best test F1
>>> fig = analyzer.plot_heatmap(
...     'model_name', 'dataset_name',
...     rank_metric='r2',
...     rank_agg='mean',
...     display_metric='f1',
...     display_agg='best'
... )
>>>
>>> # Use column normalization for comparing across partitions
>>> fig = analyzer.plot_heatmap(
...     'partition', 'model_name',
...     column_scale=True
... )
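The row orderings offered by sort_by (‘borda’, ‘condorcet’, ‘consensus’) and the top_k selection rule can be illustrated on a small score matrix where rows are models and columns are x-axis values. This is a sketch under stated assumptions: scores are lower-is-better, rank 1 is the best row in a column, Copeland awards a pairwise win to the row that is better in more columns, and the consensus "normalized rank" is rank divided by the number of rows. All function names are hypothetical.

```python
import math
from typing import Dict, List

Scores = Dict[str, List[float]]  # row label -> one score per column (lower is better)


def column_ranks(scores: Scores) -> Dict[str, List[int]]:
    """Rank rows within each column: 1 = best; tied scores share the lowest rank."""
    rows = list(scores)
    n_cols = len(next(iter(scores.values())))
    ranks = {r: [0] * n_cols for r in rows}
    for c in range(n_cols):
        col = sorted(scores[r][c] for r in rows)
        for r in rows:
            ranks[r][c] = col.index(scores[r][c]) + 1
    return ranks


def borda_order(scores: Scores) -> List[str]:
    ranks = column_ranks(scores)
    return sorted(scores, key=lambda r: sum(ranks[r]))  # smaller rank sum is better


def copeland_order(scores: Scores) -> List[str]:
    rows = list(scores)
    net = {r: 0 for r in rows}  # pairwise wins minus pairwise losses
    for i, a in enumerate(rows):
        for b in rows[i + 1:]:
            a_cols = sum(x < y for x, y in zip(scores[a], scores[b]))
            b_cols = sum(y < x for x, y in zip(scores[a], scores[b]))
            if a_cols > b_cols:
                net[a] += 1
                net[b] -= 1
            elif b_cols > a_cols:
                net[b] += 1
                net[a] -= 1
    return sorted(rows, key=lambda r: -net[r])


def consensus_order(scores: Scores) -> List[str]:
    ranks = column_ranks(scores)
    n = len(scores)

    def geo_mean(row: str) -> float:
        normed = [k / n for k in ranks[row]]  # normalized rank in (0, 1]
        return math.exp(sum(math.log(v) for v in normed) / len(normed))

    return sorted(scores, key=geo_mean)


def top_k_rows(scores: Scores, k: int) -> List[str]:
    """Keep every column's best row first, then fill remaining slots in Borda order."""
    n_cols = len(next(iter(scores.values())))
    winners = {min(scores, key=lambda r: scores[r][c]) for c in range(n_cols)}
    ordered = borda_order(scores)
    picked = [r for r in ordered if r in winners] + [r for r in ordered if r not in winners]
    return picked[:k]
```

A row that is never best in any column can still win on Borda count (a consistent runner-up), which is why top_k reserves slots for the per-column winners before falling back to the aggregate ranking.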
- plot_histogram(display_metric: str | None = None, display_partition: str = 'test', aggregate: str | None = None, config: ChartConfig | None = None, **kwargs) → Figure | List[Figure]
Plot score distribution histogram.
When multiple datasets are present and no dataset_name is specified, creates one figure per dataset.
- Parameters:
display_metric – Metric to plot (default: auto-detect from task type).
display_partition – Partition to display scores from (default: ‘test’).
aggregate – If provided, aggregate predictions by this metadata column or ‘y’. When ‘y’, groups by y_true values. When a column name (e.g., ‘ID’), groups by that metadata column. Aggregated predictions have recalculated metrics.
config – Optional ChartConfig to override analyzer’s default config for this chart.
**kwargs – Additional parameters (dataset_name, bins, figsize, filters).
- Returns:
matplotlib Figure object or list of Figure objects (one per dataset).
Example
>>> fig = analyzer.plot_histogram(display_metric='r2', display_partition='val')
>>> fig = analyzer.plot_histogram(display_metric='rmse', aggregate='ID')
- plot_nested_branches(level1_var: str = 'branch_path_level1', level2_var: str = 'branch_path_level2', metric: str | None = None, partition: str = 'test', plot_type: str = 'grouped_bar', figsize: tuple | None = None, config: ChartConfig | None = None, **filters) → Figure
Plot nested branch comparison for hierarchical experiments.
Creates grouped bar charts or faceted plots for nested branch structures.
- Parameters:
level1_var – Variable for first-level (outer) grouping (default: ‘branch_path_level1’).
level2_var – Variable for second-level (inner) grouping, shown on the x-axis (default: ‘branch_path_level2’).
metric – Metric to display (default: auto-detect).
partition – Partition for metrics (default: ‘test’).
plot_type – Type of plot: ‘grouped_bar’ or ‘facet’ (default: ‘grouped_bar’).
figsize – Figure size tuple.
config – Optional ChartConfig to override defaults.
**filters – Additional filter criteria.
- Returns:
matplotlib Figure with nested branch visualization.
Examples
>>> # Compare outlier strategies × preprocessing
>>> fig = analyzer.plot_nested_branches(
...     level1_var='outlier_strategy',
...     level2_var='preprocessing',
...     metric='rmse'
... )
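The grouped-bar layout for nested branches can be thought of as a two-level table: one group of bars per level1_var value and one bar per level2_var value inside it. A sketch assuming each bar shows the mean metric over the matching predictions (this page does not document which reduction plot_nested_branches uses); nested_means is a hypothetical helper.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, List


def nested_means(rows: List[dict], level1_var: str, level2_var: str,
                 metric: str) -> Dict[str, Dict[str, float]]:
    """Outer dict: one bar group per level1 value; inner dict: one bar per level2 value."""
    buckets = defaultdict(lambda: defaultdict(list))
    for r in rows:
        buckets[r[level1_var]][r[level2_var]].append(r[metric])
    return {outer: {inner: mean(vals) for inner, vals in inners.items()}
            for outer, inners in buckets.items()}
```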
- plot_top_k(k: int = 5, rank_metric: str | None = None, rank_partition: str = 'val', display_metric: str = '', display_partition: str = 'all', show_scores: bool = True, aggregate: str | None = None, config: ChartConfig | None = None, **kwargs) → Figure | List[Figure]
Plot top K model comparison (scatter + residuals).
Models are ranked by rank_metric on rank_partition, then predictions from display_partition(s) are shown.
When multiple datasets are present and no dataset_name is specified, creates one figure per dataset.
- Parameters:
k – Number of top models to show (default: 5).
rank_metric – Metric for ranking models (default: auto-detect from task type).
rank_partition – Partition used for ranking (default: ‘val’).
display_metric – Metric to display in titles (default: same as rank_metric).
display_partition – Partition(s) to display: ‘all’ or a specific partition name (default: ‘all’).
show_scores – If True, show scores in chart titles (default: True).
aggregate – If provided, aggregate predictions by this metadata column or ‘y’. When ‘y’, groups by y_true values. When a column name (e.g., ‘ID’), groups by that metadata column. Aggregated predictions have recalculated metrics.
config – Optional ChartConfig to override analyzer’s default config for this chart.
**kwargs – Additional parameters (dataset_name, figsize, filters).
- Returns:
matplotlib Figure object or list of Figure objects (one per dataset).
Example
>>> fig = analyzer.plot_top_k(k=3, rank_metric='r2')
>>> fig = analyzer.plot_top_k(k=3, aggregate='ID')  # Aggregated by ID
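The ranking rule described for plot_top_k (select on rank_partition, then show scores from display_partition) keeps test scores out of model selection. A minimal sketch with a hypothetical top_k_models helper over a plain dict of per-partition scores; lower_is_better is an assumed flag standing in for the metric direction (lower for RMSE, higher for R²).

```python
from typing import Dict, List, Tuple


def top_k_models(scores: Dict[str, Dict[str, float]], k: int,
                 rank_partition: str = "val", display_partition: str = "test",
                 lower_is_better: bool = True) -> List[Tuple[str, float]]:
    """scores: model name -> {partition: metric value}. Returns (model, display score)."""
    ranked = sorted(scores, key=lambda m: scores[m][rank_partition],
                    reverse=not lower_is_better)
    return [(m, scores[m][display_partition]) for m in ranked[:k]]
```

A model with the best test score but a weak validation score is not selected, so the displayed test scores stay an honest estimate of held-out performance.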