nirs4all.visualization package
Subpackages
- nirs4all.visualization.analysis package
- Submodules
- Module contents
- BranchAnalyzer
- BranchSummary
- PreprocPCAEvaluator
  - PreprocPCAEvaluator.fit()
  - PreprocPCAEvaluator.get_cross_dataset_summary()
  - PreprocPCAEvaluator.get_quality_metric_convergence()
  - PreprocPCAEvaluator.plot_all()
  - PreprocPCAEvaluator.plot_all_datasets_pca()
  - PreprocPCAEvaluator.plot_cross_dataset_distances()
  - PreprocPCAEvaluator.plot_cross_dataset_heatmap()
  - PreprocPCAEvaluator.plot_distance_matrices()
  - PreprocPCAEvaluator.plot_distance_reduction_ranking()
  - PreprocPCAEvaluator.plot_pair()
  - PreprocPCAEvaluator.plot_preservation_summary()
  - PreprocPCAEvaluator.plot_quality_metric_convergence()
- ShapAnalyzer
  - ShapAnalyzer.explain_model()
  - ShapAnalyzer.get_feature_importance()
  - ShapAnalyzer.load_results()
  - ShapAnalyzer.plot_beeswarm()
  - ShapAnalyzer.plot_beeswarm_binned()
  - ShapAnalyzer.plot_dependence()
  - ShapAnalyzer.plot_force()
  - ShapAnalyzer.plot_spectral_importance()
  - ShapAnalyzer.plot_summary()
  - ShapAnalyzer.plot_waterfall()
  - ShapAnalyzer.plot_waterfall_binned()
  - ShapAnalyzer.save_results()
- nirs4all.visualization.chart_utils package
- Submodules
- Module contents
- nirs4all.visualization.charts package
- Submodules
- nirs4all.visualization.charts.base module
- nirs4all.visualization.charts.candlestick module
- nirs4all.visualization.charts.config module
- nirs4all.visualization.charts.confusion_matrix module
- nirs4all.visualization.charts.heatmap module
- nirs4all.visualization.charts.histogram module
- nirs4all.visualization.charts.top_k_comparison module
- Module contents
- BaseChart
- CandlestickChart
- ChartConfig
  - ChartConfig.colormap
  - ChartConfig.heatmap_colormap
  - ChartConfig.partition_colors
  - ChartConfig.font_family
  - ChartConfig.title_fontsize
  - ChartConfig.label_fontsize
  - ChartConfig.tick_fontsize
  - ChartConfig.legend_fontsize
  - ChartConfig.annotation_fontsize
  - ChartConfig.figsize_small
  - ChartConfig.figsize_medium
  - ChartConfig.figsize_large
  - ChartConfig.dpi
  - ChartConfig.alpha
  - ChartConfig.__post_init__()
  - ChartConfig.apply_font_settings()
  - ChartConfig.get_figsize()
- ConfusionMatrixChart
- HeatmapChart
- ScoreHistogramChart
- TopKComparisonChart
- Submodules
Submodules
- nirs4all.visualization.branch_diagram module
- nirs4all.visualization.pipeline_diagram module
  - BranchDiagram
  - PipelineDiagram
  - PipelineNode
    - PipelineNode.id
    - PipelineNode.step_index
    - PipelineNode.label
    - PipelineNode.node_type
    - PipelineNode.shape_before
    - PipelineNode.shape_after
    - PipelineNode.input_layout_shape
    - PipelineNode.output_layout_shape
    - PipelineNode.features_shape
    - PipelineNode.branch_id
    - PipelineNode.branch_name
    - PipelineNode.substep_index
    - PipelineNode.parent_ids
    - PipelineNode.children_ids
    - PipelineNode.duration_ms
    - PipelineNode.metadata
  - plot_branch_diagram()
  - plot_pipeline_diagram()
- nirs4all.visualization.prediction_cache module
- nirs4all.visualization.predictions module
  - PredictionAnalyzer
    - PredictionAnalyzer.predictions
    - PredictionAnalyzer.dataset_name_override
    - PredictionAnalyzer.config
    - PredictionAnalyzer.output_dir
    - PredictionAnalyzer.cache
    - PredictionAnalyzer.default_aggregate
    - PredictionAnalyzer.branch_summary()
    - PredictionAnalyzer.clear_cache()
    - PredictionAnalyzer.generate_report()
    - PredictionAnalyzer.get_branch_ids()
    - PredictionAnalyzer.get_branches()
    - PredictionAnalyzer.get_cache_stats()
    - PredictionAnalyzer.get_cached_predictions()
    - PredictionAnalyzer.plot_branch_boxplot()
    - PredictionAnalyzer.plot_branch_comparison()
    - PredictionAnalyzer.plot_branch_diagram()
    - PredictionAnalyzer.plot_branch_heatmap()
    - PredictionAnalyzer.plot_candlestick()
    - PredictionAnalyzer.plot_confusion_matrix()
    - PredictionAnalyzer.plot_heatmap()
    - PredictionAnalyzer.plot_histogram()
    - PredictionAnalyzer.plot_nested_branches()
    - PredictionAnalyzer.plot_top_k()
- nirs4all.visualization.reports module
Module contents
Visualization tools for NIRS data analysis.
- class nirs4all.visualization.BranchAnalyzer(predictions)
Bases: object
Analyze and compare performance across pipeline branches.
Provides statistical analysis, hypothesis testing, and comparison tools for branched pipeline results.
- predictions
Predictions object containing prediction data.
- compare(branch1: str | int, branch2: str | int, metric: str = 'rmse', partition: str = 'test', test: str = 'ttest') -> Dict[str, Any]
Statistical comparison between two branches.
Performs hypothesis testing to determine if there’s a significant difference between two branches.
- Parameters:
branch1 – First branch name or ID.
branch2 – Second branch name or ID.
metric – Metric to compare (default: ‘rmse’).
partition – Partition for scores (default: ‘test’).
test – Statistical test (‘ttest’, ‘wilcoxon’, ‘mannwhitney’).
- Returns:
Dictionary with:
statistic: Test statistic
p_value: P-value
significant: Boolean at alpha=0.05
branch1_mean: Mean of branch1
branch2_mean: Mean of branch2
effect_size: Cohen’s d effect size
- Return type:
Dict[str, Any]
- Raises:
ImportError – If scipy is not available.
ValueError – If branches not found or insufficient data.
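The effect_size entry is described as Cohen's d. As a rough illustration of what that number means, the sketch below computes Cohen's d from two lists of per-fold scores using the pooled sample standard deviation. The function name cohens_d, the pooled-variance variant, and the sample scores are assumptions made for this example; the source does not say which variant nirs4all uses.

```python
import math
import statistics


def cohens_d(scores_a, scores_b):
    """Cohen's d with the pooled sample standard deviation (assumed variant)."""
    n_a, n_b = len(scores_a), len(scores_b)
    if n_a < 2 or n_b < 2:
        raise ValueError("each branch needs at least two scores")
    var_a = statistics.variance(scores_a)  # sample variance (n - 1 denominator)
    var_b = statistics.variance(scores_b)
    pooled_sd = math.sqrt(((n_a - 1) * var_a + (n_b - 1) * var_b) / (n_a + n_b - 2))
    if pooled_sd == 0:
        raise ValueError("zero variance in both branches: effect size undefined")
    return (statistics.mean(scores_a) - statistics.mean(scores_b)) / pooled_sd


# Hypothetical per-fold RMSE values for two branches.
branch1_rmse = [0.50, 0.52, 0.48, 0.51, 0.49]
branch2_rmse = [0.60, 0.62, 0.58, 0.61, 0.59]
print(f"Cohen's d: {cohens_d(branch1_rmse, branch2_rmse):.2f}")
```

A negative value here means branch1 has the lower mean RMSE; the magnitude expresses the gap in units of the pooled standard deviation.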
- get_branch_names() -> List[str]
Get list of unique branch names.
- Returns:
List of branch names.
- pairwise_comparison(metric: str = 'rmse', partition: str = 'test', test: str = 'ttest') -> DataFrame
Compute pairwise statistical comparisons between all branches.
- Parameters:
metric – Metric to compare (default: ‘rmse’).
partition – Partition for scores (default: ‘test’).
test – Statistical test to use.
- Returns:
DataFrame with p-values for all branch pairs.
- Raises:
ImportError – If pandas or scipy not available.
- rank_branches(metric: str = 'rmse', partition: str = 'test', ascending: bool | None = None) -> List[Dict[str, Any]]
Rank branches by mean performance.
- Parameters:
metric – Metric to rank by (default: ‘rmse’).
partition – Partition for scores (default: ‘test’).
ascending – Sort order. If None, auto-detect based on metric.
- Returns:
List of dicts with branch_name, mean, std, rank.
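To make the auto-detected sort order concrete, here is a minimal standalone sketch of ranking by mean score. It is not the library's implementation: rank_branches_sketch, the LOWER_IS_BETTER set, and the input layout (branch name mapped to a list of scores) are hypothetical, and the rule "error metrics sort ascending, everything else descending" is an assumption.

```python
import statistics

# Assumption: error-style metrics are "lower is better"; anything else is
# treated as "higher is better". The library's own rule may differ.
LOWER_IS_BETTER = {"rmse", "mse", "mae"}


def rank_branches_sketch(scores_by_branch, metric="rmse", ascending=None):
    """Return one dict per branch with branch_name, mean, std and rank."""
    if ascending is None:
        ascending = metric.lower() in LOWER_IS_BETTER
    rows = [
        {
            "branch_name": name,
            "mean": statistics.mean(scores),
            "std": statistics.stdev(scores) if len(scores) > 1 else 0.0,
        }
        for name, scores in scores_by_branch.items()
    ]
    # Best branch first: smallest mean if ascending, largest otherwise.
    rows.sort(key=lambda row: row["mean"], reverse=not ascending)
    for rank, row in enumerate(rows, start=1):
        row["rank"] = rank
    return rows


# Hypothetical scores for three branches.
scores = {"snv": [0.5, 0.7], "msc": [0.4, 0.4], "raw": [0.9, 1.1]}
for row in rank_branches_sketch(scores, metric="rmse"):
    print(row["rank"], row["branch_name"], round(row["mean"], 3))
```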
- summary(metrics: List[str] | None = None, partition: str = 'test', aggregate: str | None = None) -> BranchSummary
Generate summary statistics for each branch.
Computes mean, std, min, max for each metric across branches.
- Parameters:
metrics – List of metrics to compute (default: [‘rmse’, ‘r2’]).
partition – Partition to compute metrics from (default: ‘test’).
aggregate – If provided, aggregate predictions by this column before computing statistics.
- Returns:
BranchSummary object with statistics.
- class nirs4all.visualization.BranchDiagram(predictions: Any = None, config: Dict[str, Any] | None = None)
Bases: PipelineDiagram
DEPRECATED: Use PipelineDiagram instead.
This class is kept for backward compatibility only. It wraps PipelineDiagram with the old API.
- render(show_metrics: bool | None = None, metric: str | None = None, partition: str | None = None, figsize: Tuple[float, float] | None = None, title: str | None = None) -> Figure
Render the branch diagram (deprecated API).
- Parameters:
show_metrics – Override config’s show_metrics setting.
metric – Override metric to display.
partition – Override partition for metrics.
figsize – Override figure size.
title – Optional title for the diagram.
- Returns:
matplotlib Figure object.
- class nirs4all.visualization.BranchSummary(data: List[Dict[str, Any]], metrics: List[str])
Bases: object
Branch summary statistics container with export capabilities.
Provides DataFrame-like access and export to markdown, LaTeX, and CSV.
- data
List of dictionaries with branch statistics.
- metrics
List of metrics computed.
- columns
Column names in order.
- __getitem__(key: int | str) -> Dict[str, Any]
Get branch by index or name.
- Parameters:
key – Integer index or branch name string.
- Returns:
Dictionary with branch statistics.
- to_csv(path: str, precision: int = 6) -> None
Export to CSV file.
- Parameters:
path – Output file path.
precision – Decimal places for floating point values.
- to_dataframe() -> DataFrame
Convert to pandas DataFrame.
- Returns:
pandas DataFrame with branch statistics.
- Raises:
ImportError – If pandas is not installed.
- to_dict() -> Dict[str, Dict[str, Any]]
Convert to dictionary keyed by branch name.
- Returns:
Dictionary mapping branch_name to statistics.
- to_latex(caption: str = 'Branch Performance Comparison', label: str = 'tab:branch_comparison', precision: int = 3, include_std: bool = True, mean_std_combined: bool = True) -> str
Export as LaTeX table for publications.
- Parameters:
caption – Table caption.
label – LaTeX label for referencing.
precision – Decimal places for floating point values.
include_std – If True, include std columns.
mean_std_combined – If True, format as “mean ± std”.
- Returns:
LaTeX-formatted table string.
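For readers who want to see what a combined "mean ± std" table can look like, the following sketch builds one by hand. It is only an illustration: branch_table_latex, the row keys ({metric}_mean, {metric}_std), and the exact table layout are assumptions, and the output of the real to_latex method may differ.

```python
def branch_table_latex(rows, metrics, caption, label, precision=3):
    """Build a LaTeX table with one 'mean ± std' cell per metric."""
    header = " & ".join(["Branch"] + [m.upper() for m in metrics]) + r" \\"
    lines = [
        r"\begin{table}[ht]",
        r"\centering",
        r"\begin{tabular}{l" + "c" * len(metrics) + "}",
        r"\hline",
        header,
        r"\hline",
    ]
    for row in rows:
        # Underscores must be escaped in LaTeX text mode.
        cells = [row["branch_name"].replace("_", r"\_")]
        for m in metrics:
            mean, std = row[f"{m}_mean"], row[f"{m}_std"]
            cells.append(f"${mean:.{precision}f} \\pm {std:.{precision}f}$")
        lines.append(" & ".join(cells) + r" \\")
    lines += [
        r"\hline",
        r"\end{tabular}",
        rf"\caption{{{caption}}}",
        rf"\label{{{label}}}",
        r"\end{table}",
    ]
    return "\n".join(lines)


# Hypothetical summary rows for two branches.
rows = [
    {"branch_name": "snv_pca", "rmse_mean": 0.5123, "rmse_std": 0.0211},
    {"branch_name": "msc_detrend", "rmse_mean": 0.4876, "rmse_std": 0.0154},
]
print(branch_table_latex(rows, ["rmse"], "Branch Performance Comparison",
                         "tab:branch_comparison"))
```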
- class nirs4all.visualization.PipelineDiagram(pipeline_steps: List[Any] | None = None, predictions: Any = None, execution_trace: Any = None, config: Dict[str, Any] | None = None)
Bases: object
Create DAG visualization for pipeline execution structure.
Renders a visual diagram showing the complete pipeline topology, including all steps, shapes, branches, and models.
- pipeline_steps
List of pipeline step definitions
- predictions
Optional Predictions object with execution data
- execution_trace
Optional ExecutionTrace with actual runtime shapes
- config
Optional dict for customization
- NODE_STYLES = {'branch': ('#E0F2F1', '#00796B'), 'concat_transform': ('#F3E5F5', '#7B1FA2'), 'default': ('#ECEFF1', '#455A64'), 'feature_augmentation': ('#E0F2F1', '#00796B'), 'input': ('#FAFAFA', '#616161'), 'merge': ('#E0F2F1', '#00796B'), 'merge_sources': ('#E0F2F1', '#00796B'), 'model': ('#FFEBEE', '#D32F2F'), 'output': ('#FAFAFA', '#616161'), 'preprocessing': ('#E3F2FD', '#1976D2'), 'sample_augmentation': ('#E8F5E9', '#388E3C'), 'source_branch': ('#E0F2F1', '#00796B'), 'splitter': ('#F3E5F5', '#7B1FA2'), 'y_processing': ('#FFF8E1', '#FFA000')}
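NODE_STYLES maps a node type to a pair of hex colours. The sketch below shows one plausible way such a table could be consulted, using a subset of the entries listed above. The helper style_for, the fallback to the 'default' entry for unknown types, and the reading of each pair as (fill colour, edge colour) are assumptions that the source does not confirm.

```python
# Subset of the NODE_STYLES mapping listed above.
NODE_STYLES = {
    "default": ("#ECEFF1", "#455A64"),
    "model": ("#FFEBEE", "#D32F2F"),
    "preprocessing": ("#E3F2FD", "#1976D2"),
}


def style_for(node_type):
    """Return the colour pair for a node type, falling back to 'default'."""
    return NODE_STYLES.get(node_type, NODE_STYLES["default"])


# Assumed interpretation of the pair: (fill colour, edge colour).
fill, edge = style_for("model")
print(fill, edge)
print(style_for("some_unknown_type"))
```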
- classmethod from_trace(execution_trace: Any, predictions: Any = None, config: Dict[str, Any] | None = None) -> PipelineDiagram
Create a PipelineDiagram from an ExecutionTrace.
This builds the diagram using actual runtime data including measured shapes at each step.
- Parameters:
execution_trace – ExecutionTrace object from pipeline execution
predictions – Optional Predictions object to enrich nodes with scores
config – Optional configuration dict
- Returns:
PipelineDiagram instance ready for rendering
Example
>>> from nirs4all.visualization import PipelineDiagram
>>> diagram = PipelineDiagram.from_trace(trace)
>>> fig = diagram.render(title="Execution Trace")
- render(show_shapes: bool | None = None, figsize: Tuple[float, float] | None = None, title: str | None = None, initial_shape: Tuple[int, int, int] | None = None) -> Figure
Render the pipeline diagram.
- Parameters:
show_shapes – Override config’s show_shapes setting
figsize – Override figure size
title – Optional title for the diagram
initial_shape – Initial dataset shape (samples, processings, features)
- Returns:
matplotlib Figure object
- class nirs4all.visualization.PredictionAnalyzer(predictions_obj: Predictions, dataset_name_override: str | None = None, config: ChartConfig | None = None, output_dir: str | None = None, cache_size: int = 50, default_aggregate: str | None = None, default_aggregate_method: str | None = None, default_aggregate_exclude_outliers: bool = False)
Bases: object
Orchestrator for prediction analysis and visualization.
Provides a unified interface for creating various prediction visualizations. Delegates to specialized chart classes for rendering.
Includes a caching layer (PredictionCache) to avoid recomputing expensive aggregations when multiple charts use the same parameters. The cache is keyed by (aggregate, rank_metric, rank_partition, display_partition, group_by, filters) and stores the results of predictions.top() calls.
Leverages the refactored Predictions API (predictions.top(), PredictionResult, etc.) for efficient data access and avoids redundant calculations.
- predictions
Predictions object containing prediction data.
- dataset_name_override
Optional dataset name override for display.
- config
ChartConfig for customization across all charts.
- output_dir
Directory to save generated charts.
- cache
PredictionCache for caching aggregated results.
- default_aggregate
Default aggregation column for all visualization methods.
Example
>>> from nirs4all.data.predictions import Predictions
>>> predictions = Predictions.load('predictions.json')
>>> analyzer = PredictionAnalyzer(predictions)
>>>
>>> # Plot top 5 models - first call computes aggregation
>>> fig = analyzer.plot_top_k(k=5, aggregate='ID')
>>>
>>> # Plot heatmap - uses cached aggregation (fast!)
>>> fig = analyzer.plot_heatmap('model_name', 'preprocessings', aggregate='ID')
>>>
>>> # Check cache stats
>>> print(analyzer.get_cache_stats())
>>>
>>> # With default aggregation from dataset config
>>> runner = PipelineRunner()
>>> predictions, _ = runner.run(pipeline, DatasetConfigs(path, aggregate='sample_id'))
>>> analyzer = PredictionAnalyzer(predictions, default_aggregate=runner.last_aggregate)
>>> # All plots now use sample_id aggregation by default
>>> fig = analyzer.plot_top_k(k=5)  # Aggregated automatically
- branch_summary(metrics: List[str] | None = None, display_partition: str = 'test', aggregate: str | None = None, as_dataframe: bool = True, **filters) -> DataFrame | Dict[str, Dict[str, Any]]
Generate summary statistics comparing branch performance.
Computes mean, std, min, max for each metric across branches.
- Parameters:
metrics – List of metrics to compute (default: [‘rmse’, ‘r2’] or [‘balanced_accuracy’, ‘f1’] for classification).
display_partition – Partition to compute metrics from (default: ‘test’).
aggregate – If provided, aggregate predictions by this metadata column (e.g., ‘ID’) before computing statistics.
as_dataframe – If True, return pandas DataFrame. If False, return dict.
**filters – Additional filter criteria.
- Returns:
DataFrame or dict with branch summary statistics:
branch_name: Branch identifier
branch_id: Numeric branch ID
count: Number of predictions
{metric}_mean: Mean value
{metric}_std: Standard deviation
{metric}_min: Minimum value
{metric}_max: Maximum value
- Return type:
DataFrame | Dict[str, Dict[str, Any]]
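The column layout listed above can be illustrated with a small standalone sketch that computes the same statistics from raw scores. This is not the library's code: summarize_branches and its input layout (branch name mapped to metric mapped to values) are hypothetical, and assigning branch_id by enumeration order is an assumption made only so the example has a value to show.

```python
import statistics


def summarize_branches(scores, metrics):
    """scores maps branch_name -> metric -> list of per-prediction values."""
    rows = []
    for branch_id, (branch_name, by_metric) in enumerate(scores.items()):
        row = {
            "branch_name": branch_name,
            "branch_id": branch_id,  # assumption: IDs follow insertion order
            "count": len(by_metric[metrics[0]]),
        }
        for metric in metrics:
            values = by_metric[metric]
            row[f"{metric}_mean"] = statistics.mean(values)
            row[f"{metric}_std"] = statistics.stdev(values) if len(values) > 1 else 0.0
            row[f"{metric}_min"] = min(values)
            row[f"{metric}_max"] = max(values)
        rows.append(row)
    return rows


# Hypothetical scores for two branches.
scores = {
    "snv_pca": {"rmse": [0.4, 0.6], "r2": [0.90, 0.86]},
    "msc_detrend": {"rmse": [0.5, 0.5], "r2": [0.88, 0.88]},
}
for row in summarize_branches(scores, ["rmse", "r2"]):
    print(row)
```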
Examples
>>> summary = analyzer.branch_summary(metrics=['rmse', 'r2'])
>>> print(summary.to_markdown())

>>> summary = analyzer.branch_summary(
...     metrics=['balanced_accuracy', 'f1'],
...     aggregate='ID'
... )
- clear_cache() -> None
Clear all caches.
Call this if the underlying predictions data has been modified to ensure fresh results are computed. Clears both:
- Analyzer’s query result cache
- Ranker’s aggregation and score caches
- generate_report(output_path: str, branch_comparison: bool = True, include_diagrams: bool = True, include_tables: bool = True, metrics: List[str] | None = None, partition: str = 'test', title: str | None = None) -> str
Generate HTML report with branch analysis.
Creates a comprehensive HTML report with branch comparisons, visualizations, and statistical tables.
- Parameters:
output_path – Path for the output HTML file.
branch_comparison – If True, include branch comparison section.
include_diagrams – If True, include branch diagram visualization.
include_tables – If True, include summary statistics tables.
metrics – List of metrics to include (default: [‘rmse’, ‘r2’]).
partition – Partition for metrics (default: ‘test’).
title – Report title (default: ‘Branch Comparison Report’).
- Returns:
Path to the generated HTML file.
Examples
>>> path = analyzer.generate_report(
...     'reports/branch_comparison.html',
...     branch_comparison=True,
...     metrics=['rmse', 'r2', 'mae']
... )
- get_branch_ids() -> List[int]
Get list of unique branch IDs in predictions.
- Returns:
List of branch IDs (empty list if no branches)
Examples
>>> branch_ids = analyzer.get_branch_ids()
>>> print(branch_ids)  # [0, 1, 2]
- get_branches() -> List[str]
Get list of unique branch names in predictions.
- Returns:
List of branch names (empty list if no branches)
Examples
>>> branches = analyzer.get_branches()
>>> print(branches)  # ['snv_pca', 'msc_detrend', 'derivative']
- get_cache_stats() -> Dict[str, Any]
Get cache performance statistics.
- Returns:
Dictionary with stats for both analyzer and ranker caches:
analyzer_cache: Query result cache stats
ranker_cache: Aggregation and score cache stats
- Return type:
Dict[str, Any]
- get_cached_predictions(n: int, rank_metric: str, rank_partition: str = 'val', display_partition: str = 'test', display_metrics: List[str] | None = None, aggregate: str | None = None, aggregate_method: str | None = None, aggregate_exclude_outliers: bool | None = None, group_by: str | List[str] | None = None, aggregate_partitions: bool = True, **filters)
Get predictions with caching support.
This method wraps predictions.top() with a caching layer. Charts should call this method instead of directly calling predictions.top() to benefit from caching.
The cache key includes: aggregate, rank_metric, rank_partition, display_partition, group_by, and all filters.
- Parameters:
n – Number of top predictions to return.
rank_metric – Metric for ranking.
rank_partition – Partition for ranking (default: ‘val’).
display_partition – Partition for display (default: ‘test’).
display_metrics – List of metrics to compute for display.
aggregate – Aggregation column (e.g., ‘ID’) or None.
aggregate_method – Aggregation method (‘mean’, ‘median’, ‘vote’). If None, uses default_aggregate_method from constructor.
aggregate_exclude_outliers – If True, exclude outliers using T² before aggregation. If None, uses default_aggregate_exclude_outliers from constructor.
group_by – Grouping column(s) for deduplication.
aggregate_partitions – If True, include all partition data.
**filters – Additional filter criteria.
- Returns:
PredictionResultsList from cache or fresh computation.
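The caching behaviour described above (a key built from aggregate, rank_metric, rank_partition, display_partition, group_by, and the filters) can be sketched with a small dictionary-backed cache. This is an illustrative stand-in, not PredictionCache itself: the class QueryCache, its method names, and the least-recently-used eviction policy are assumptions; the source only states that a cache_size parameter exists.

```python
from collections import OrderedDict


class QueryCache:
    """Tiny LRU cache keyed by query parameters (illustrative only)."""

    def __init__(self, max_size=50):
        self.max_size = max_size
        self._store = OrderedDict()
        self.hits = 0
        self.misses = 0

    @staticmethod
    def make_key(aggregate, rank_metric, rank_partition, display_partition,
                 group_by, filters):
        # Lists and dicts are not hashable, so normalise them to tuples.
        # Filter values are assumed to be hashable (strings, numbers).
        if isinstance(group_by, list):
            group_by = tuple(group_by)
        return (aggregate, rank_metric, rank_partition, display_partition,
                group_by, tuple(sorted(filters.items())))

    def get_or_compute(self, key, compute):
        if key in self._store:
            self.hits += 1
            self._store.move_to_end(key)  # mark as most recently used
            return self._store[key]
        self.misses += 1
        value = compute()
        self._store[key] = value
        if len(self._store) > self.max_size:
            self._store.popitem(last=False)  # evict least recently used
        return value


cache = QueryCache(max_size=50)
key = QueryCache.make_key("ID", "rmse", "val", "test", None,
                          {"dataset_name": "corn"})
first = cache.get_or_compute(key, lambda: ["expensive result"])   # computed
second = cache.get_or_compute(key, lambda: ["expensive result"])  # cached
print(cache.hits, cache.misses)
```

Sorting the filter items makes the key independent of keyword-argument order, so two calls that differ only in how the filters are written still share one cache entry.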
Example
>>> # First call computes and caches
>>> preds = analyzer.get_cached_predictions(
...     n=5, rank_metric='rmse', aggregate='ID'
... )
>>> # Second call with same params is instant
>>> preds = analyzer.get_cached_predictions(
...     n=5, rank_metric='rmse', aggregate='ID'
... )
- plot_branch_boxplot(rank_metric: str | None = None, display_metric: str | None = None, display_partition: str = 'test', aggregate: str | None = None, figsize: tuple | None = None, config: ChartConfig | None = None, **filters) -> Figure
Plot boxplot comparing score distributions across branches.
Creates a boxplot showing the distribution of metric values for each branch.
- Parameters:
rank_metric – Metric for ranking models (default: auto-detect).
display_metric – Metric to display (default: same as rank_metric).
display_partition – Partition to display results from (default: ‘test’).
aggregate – If provided, aggregate predictions by this metadata column.
figsize – Figure size tuple (default: auto-computed).
config – Optional ChartConfig to override defaults.
**filters – Additional filter criteria.
- Returns:
matplotlib Figure with branch comparison boxplot.
Examples
>>> fig = analyzer.plot_branch_boxplot(display_metric='rmse')
>>> fig = analyzer.plot_branch_boxplot(
...     display_metric='r2',
...     aggregate='ID'
... )
- plot_branch_comparison(rank_metric: str | None = None, display_metric: str | None = None, display_partition: str = 'test', aggregate: str | None = None, show_ci: bool = True, ci_level: float = 0.95, figsize: tuple | None = None, config: ChartConfig | None = None, **filters) -> Figure
Plot bar chart comparing branch performance with confidence intervals.
Creates a grouped bar chart showing mean metric values for each branch with optional confidence intervals.
- Parameters:
rank_metric – Metric for ranking models (default: auto-detect).
display_metric – Metric to display (default: same as rank_metric).
display_partition – Partition to display results from (default: ‘test’).
aggregate – If provided, aggregate predictions by this metadata column.
show_ci – If True, show confidence intervals (default: True).
ci_level – Confidence level for intervals (default: 0.95).
figsize – Figure size tuple (default: auto-computed).
config – Optional ChartConfig to override defaults.
**filters – Additional filter criteria.
- Returns:
matplotlib Figure with branch comparison bar chart.
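As a rough guide to what an error bar at ci_level=0.95 represents, the sketch below computes a confidence interval for a branch's mean score. It uses the normal approximation (mean ± z · std / √n) from the standard library; this is an assumption, and the library may well use a different method such as a Student's t interval. The function name and the sample scores are hypothetical.

```python
import math
import statistics


def mean_confidence_interval(values, ci_level=0.95):
    """Normal-approximation CI for the mean: mean ± z * std / sqrt(n)."""
    n = len(values)
    if n < 2:
        raise ValueError("need at least two values")
    mean = statistics.mean(values)
    sem = statistics.stdev(values) / math.sqrt(n)  # standard error of the mean
    z = statistics.NormalDist().inv_cdf(0.5 + ci_level / 2)  # two-sided quantile
    return mean - z * sem, mean + z * sem


# Hypothetical per-fold RMSE values for one branch.
low, high = mean_confidence_interval([0.50, 0.52, 0.48, 0.51, 0.49])
print(f"95% CI: [{low:.4f}, {high:.4f}]")
```

With only a handful of folds the normal approximation gives intervals that are somewhat too narrow, which is one reason a t-based interval is often preferred for small samples.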
Examples
>>> fig = analyzer.plot_branch_comparison(display_metric='rmse')
>>> fig = analyzer.plot_branch_comparison(
...     display_metric='r2',
...     aggregate='ID',
...     show_ci=True
... )
- plot_branch_diagram(show_metrics: bool = True, metric: str | None = None, partition: str = 'test', figsize: tuple | None = None, title: str | None = None, config: Dict[str, Any] | None = None) -> Figure
Plot DAG diagram showing the branching structure of the pipeline.
Creates a visual diagram showing shared steps, branch nodes, and post-branch models in a hierarchical layout.
- Parameters:
show_metrics – If True, show metrics in branch nodes (default: True).
metric – Metric to display (default: auto-detect).
partition – Partition for metrics (default: ‘test’).
figsize – Figure size tuple (default: auto-computed).
title – Optional title for the diagram.
config – Additional configuration dict for BranchDiagram.
- Returns:
matplotlib Figure with branch DAG diagram.
Examples
>>> fig = analyzer.plot_branch_diagram(metric='rmse')
>>> fig = analyzer.plot_branch_diagram(
...     show_metrics=True,
...     metric='r2',
...     partition='val'
... )
- plot_branch_heatmap(y_var: str = 'fold_id', rank_metric: str | None = None, display_metric: str | None = None, display_partition: str = 'test', aggregate: str | None = None, config: ChartConfig | None = None, **kwargs) -> Figure
Plot heatmap of branch performance across folds or other variable.
Creates a heatmap with branches on x-axis and another variable (e.g., fold_id) on y-axis.
- Parameters:
y_var – Variable for y-axis (default: ‘fold_id’).
rank_metric – Metric for ranking (default: auto-detect).
display_metric – Metric to display (default: same as rank_metric).
display_partition – Partition to display (default: ‘test’).
aggregate – If provided, aggregate predictions by this metadata column.
config – Optional ChartConfig to override defaults.
**kwargs – Additional parameters passed to plot_heatmap.
- Returns:
matplotlib Figure with branch heatmap.
Examples
>>> fig = analyzer.plot_branch_heatmap(display_metric='rmse')
>>> fig = analyzer.plot_branch_heatmap(
...     y_var='model_name',
...     display_metric='r2'
... )
- plot_candlestick(variable: str, display_metric: str | None = None, display_partition: str = 'test', aggregate: str | None = None, config: ChartConfig | None = None, **kwargs) -> Figure
Plot candlestick chart for score distribution by variable.
- Parameters:
variable – Variable to group by (e.g., ‘model_name’, ‘preprocessings’).
display_metric – Metric to analyze (default: auto-detect from task type).
display_partition – Partition to display scores from (default: ‘test’).
aggregate – If provided, aggregate predictions by this metadata column or ‘y’. When ‘y’, groups by y_true values. When a column name (e.g., ‘ID’), groups by that metadata column. Aggregated predictions have recalculated metrics.
config – Optional ChartConfig to override analyzer’s default config for this chart.
**kwargs – Additional parameters (dataset_name, figsize, filters).
- Returns:
matplotlib Figure object.
Example
>>> fig = analyzer.plot_candlestick('model_name', display_metric='rmse')
>>> fig = analyzer.plot_candlestick('model_name', display_metric='rmse', aggregate='ID')
- plot_confusion_matrix(k: int = 5, rank_metric: str | None = None, rank_partition: str = 'val', display_metric: str | List[str] = '', display_partition: str = 'test', show_scores: bool = True, aggregate: str | None = None, config: ChartConfig | None = None, **kwargs) -> Figure | List[Figure]
Plot confusion matrices for top K classification models.
When multiple datasets are present and no dataset_name is specified, creates one figure per dataset.
- Parameters:
k – Number of top models to show (default: 5).
rank_metric – Metric for ranking (default: auto-detect from task type).
rank_partition – Partition used for ranking models (default: ‘val’).
display_metric – Metric(s) to display in titles. Can be a single string (e.g., ‘accuracy’) or a list of strings for multiple metrics (e.g., [‘balanced_accuracy’, ‘accuracy’]). Metric names are shown in abbreviated form (default: same as rank_metric).
display_partition – Partition to display confusion matrix from (default: ‘test’).
show_scores – If True, show scores in chart titles (default: True).
aggregate – If provided, aggregate predictions by this metadata column or ‘y’.
config – Optional ChartConfig to override analyzer’s default config for this chart.
**kwargs – Additional parameters (dataset_name, figsize, filters).
- Returns:
matplotlib Figure object or list of Figure objects (one per dataset).
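For orientation, the quantity each panel visualises can be computed in a few lines. The sketch below builds a count matrix and derives balanced accuracy from it. The function names are hypothetical, and the layout (true classes as rows, predicted classes as columns) is an assumption that follows the common scikit-learn convention; the source does not state which orientation the chart uses.

```python
def confusion_matrix(y_true, y_pred, labels=None):
    """Count matrix with true classes as rows and predicted classes as columns."""
    if labels is None:
        labels = sorted(set(y_true) | set(y_pred))
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for true, pred in zip(y_true, y_pred):
        matrix[index[true]][index[pred]] += 1
    return labels, matrix


def balanced_accuracy(matrix):
    """Mean per-class recall; classes with no true samples are skipped."""
    recalls = [row[i] / sum(row) for i, row in enumerate(matrix) if sum(row) > 0]
    return sum(recalls) / len(recalls)


# Hypothetical binary classification results.
labels, matrix = confusion_matrix([0, 0, 1, 1, 1], [0, 1, 1, 1, 0])
print(labels, matrix)
print(round(balanced_accuracy(matrix), 4))
```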
Example
>>> fig = analyzer.plot_confusion_matrix(k=3, rank_metric='f1')
>>> fig = analyzer.plot_confusion_matrix(k=3, aggregate='ID')
>>> # Multiple metrics displayed with abbreviated names
>>> fig = analyzer.plot_confusion_matrix(
...     k=3,
...     display_metric=['balanced_accuracy', 'accuracy']
... )
- plot_heatmap(x_var: str, y_var: str, rank_metric: str | None = None, rank_partition: str = 'val', display_metric: str = '', display_partition: str = 'test', normalize: bool = False, rank_agg: str = 'best', display_agg: str = 'best', show_counts: bool = True, local_scale: bool = False, column_scale: bool = False, aggregate: str | None = None, top_k: int | None = None, sort_by_value: bool = False, sort_by: str | None = None, config: ChartConfig | None = None, **kwargs) -> Figure
Plot performance heatmap across two variables.
For each (x_var, y_var) cell:
1. Rank predictions by rank_metric on rank_partition using rank_agg
2. Display display_metric from display_partition using display_agg
3. Normalize per dataset if requested
4. Show counts if requested
- Parameters:
x_var – Variable for x-axis (e.g., ‘model_name’, ‘preprocessings’).
y_var – Variable for y-axis (e.g., ‘dataset_name’, ‘partition’).
rank_metric – Metric used to rank/select models (default: auto-detect from task type).
rank_partition – Partition used for ranking models (default: ‘val’).
display_metric – Metric to display in heatmap (default: same as rank_metric).
display_partition – Partition to display scores from (default: ‘test’).
normalize – If True, show normalized scores in cells. Colors always use normalized (default: False).
rank_agg – Aggregation for ranking (‘best’, ‘worst’, ‘mean’, ‘median’) (default: ‘best’).
display_agg – Aggregation for display scores (‘best’, ‘worst’, ‘mean’, ‘median’) (default: ‘best’, as in the signature).
show_counts – Show prediction counts in cells (default: True).
local_scale – If True, colorbar shows actual metric values; if False, shows 0-1 normalized (default: False).
column_scale – If True, normalize colors per column (best in column = 1.0). Automatically sets local_scale=False when enabled (default: False).
aggregate – If provided, aggregate predictions by this metadata column (e.g., ‘ID’).
top_k – If provided, show only top K models. Selection uses Borda count: first keeps top-1 per column, then ranks by Borda count.
sort_by_value – If True, sort Y-axis by ranking score (best first) instead of alphabetically. Uses rank_metric on rank_partition. Deprecated: use sort_by=’value’ instead.
sort_by – Sorting method for Y-axis (rows). Options:
- None: Alphabetical sorting (default).
- ‘value’: Sort by ranking score on rank_partition column.
- ‘mean’: Sort by mean score across all columns.
- ‘median’: Sort by median score across all columns.
- ‘borda’: Sort by Borda count (sum of ranks across columns).
- ‘condorcet’: Sort by pairwise wins (Copeland method).
- ‘consensus’: Sort by consensus (geometric mean of normalized ranks).
config – Optional ChartConfig to override analyzer’s default config for this chart.
**kwargs – Additional filters (dataset_name, model_name, etc.).
- Returns:
matplotlib Figure object.
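The ‘borda’ option is described as a sum of ranks across columns. A minimal standalone sketch of that idea follows. The function borda_order, its input layout, and the alphabetical tie-break are hypothetical, and it assumes every row has a score in every column; how the library handles missing cells or ties is not stated in the source.

```python
def borda_order(scores, lower_is_better=True):
    """Order rows by the sum of their per-column ranks (smaller sum first).

    scores maps row label -> {column label: score}. Every row is assumed to
    have a score in every column; ties in the rank sum are broken by label.
    """
    rows = sorted(scores)
    columns = sorted({col for by_col in scores.values() for col in by_col})
    rank_sum = {row: 0 for row in rows}
    for col in columns:
        # Rank 1 is the best score in this column.
        ordered = sorted(rows, key=lambda r: scores[r][col],
                         reverse=not lower_is_better)
        for rank, row in enumerate(ordered, start=1):
            rank_sum[row] += rank
    return sorted(rows, key=lambda r: (rank_sum[r], r)), rank_sum


# Hypothetical RMSE per model (rows) and dataset (columns).
scores = {
    "PLS": {"corn": 0.30, "wheat": 0.50},
    "SVR": {"corn": 0.40, "wheat": 0.45},
    "RF": {"corn": 0.50, "wheat": 0.60},
}
order, rank_sum = borda_order(scores)
print(order, rank_sum)
```

Because it only uses ranks, this ordering ignores how large the score gaps are, which makes it robust when columns are on different scales.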
Example
>>> # Rank on best val RMSE, display best test RMSE (signature defaults)
>>> fig = analyzer.plot_heatmap('model_name', 'dataset_name')
>>>
>>> # Rank on mean val R2, display best test F1
>>> fig = analyzer.plot_heatmap(
...     'model_name', 'dataset_name',
...     rank_metric='r2',
...     rank_agg='mean',
...     display_metric='f1',
...     display_agg='best'
... )
>>>
>>> # Use column normalization for comparing across partitions
>>> fig = analyzer.plot_heatmap(
...     'partition', 'model_name',
...     column_scale=True
... )
- plot_histogram(display_metric: str | None = None, display_partition: str = 'test', aggregate: str | None = None, config: ChartConfig | None = None, **kwargs) -> Figure | List[Figure]
Plot score distribution histogram.
When multiple datasets are present and no dataset_name is specified, creates one figure per dataset.
- Parameters:
display_metric – Metric to plot (default: auto-detect from task type).
display_partition – Partition to display scores from (default: ‘test’).
aggregate – If provided, aggregate predictions by this metadata column or ‘y’. When ‘y’, groups by y_true values. When a column name (e.g., ‘ID’), groups by that metadata column. Aggregated predictions have recalculated metrics.
config – Optional ChartConfig to override analyzer’s default config for this chart.
**kwargs – Additional parameters (dataset_name, bins, figsize, filters).
- Returns:
matplotlib Figure object or list of Figure objects (one per dataset).
Example
>>> fig = analyzer.plot_histogram(display_metric='r2', display_partition='val')
>>> fig = analyzer.plot_histogram(display_metric='rmse', aggregate='ID')
- plot_nested_branches(level1_var: str = 'branch_path_level1', level2_var: str = 'branch_path_level2', metric: str | None = None, partition: str = 'test', plot_type: str = 'grouped_bar', figsize: tuple | None = None, config: ChartConfig | None = None, **filters) -> Figure
Plot nested branch comparison for hierarchical experiments.
Creates grouped bar charts or faceted plots for nested branch structures.
- Parameters:
level1_var – Variable for first level grouping (outer group).
level2_var – Variable for second level grouping (inner group/x-axis).
metric – Metric to display (default: auto-detect).
partition – Partition for metrics (default: ‘test’).
plot_type – Type of plot (‘grouped_bar’, ‘facet’).
figsize – Figure size tuple.
config – Optional ChartConfig to override defaults.
**filters – Additional filter criteria.
- Returns:
matplotlib Figure with nested branch visualization.
Examples
>>> # Compare outlier strategies × preprocessing
>>> fig = analyzer.plot_nested_branches(
...     level1_var='outlier_strategy',
...     level2_var='preprocessing',
...     metric='rmse'
... )
- plot_top_k(k: int = 5, rank_metric: str | None = None, rank_partition: str = 'val', display_metric: str = '', display_partition: str = 'all', show_scores: bool = True, aggregate: str | None = None, config: ChartConfig | None = None, **kwargs) -> Figure | List[Figure]
Plot top K model comparison (scatter + residuals).
Models are ranked by rank_metric on rank_partition, then predictions from display_partition(s) are shown.
When multiple datasets are present and no dataset_name is specified, creates one figure per dataset.
- Parameters:
k – Number of top models to show (default: 5).
rank_metric – Metric for ranking models (default: auto-detect from task type).
rank_partition – Partition used for ranking (default: ‘val’).
display_metric – Metric to display in titles (default: same as rank_metric).
display_partition – Partition(s) to display (‘all’ or specific partition).
show_scores – If True, show scores in chart titles (default: True).
aggregate – If provided, aggregate predictions by this metadata column or ‘y’. When ‘y’, groups by y_true values. When a column name (e.g., ‘ID’), groups by that metadata column. Aggregated predictions have recalculated metrics.
config – Optional ChartConfig to override analyzer’s default config for this chart.
**kwargs – Additional parameters (dataset_name, figsize, filters).
- Returns:
matplotlib Figure object or list of Figure objects (one per dataset).
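The aggregate option groups predictions that share a metadata value (for example repeated scans of one sample) and recalculates metrics on the grouped values. The sketch below illustrates the effect with mean aggregation and RMSE. It is a simplified stand-in: aggregate_by_id and rmse are hypothetical helpers, and the library also lists ‘median’ and ‘vote’ methods that are not shown here.

```python
import math
from collections import defaultdict


def aggregate_by_id(sample_ids, y_true, y_pred):
    """Average repeated measurements that share a sample ID."""
    grouped = defaultdict(lambda: {"true": [], "pred": []})
    for sid, true, pred in zip(sample_ids, y_true, y_pred):
        grouped[sid]["true"].append(true)
        grouped[sid]["pred"].append(pred)
    ids = sorted(grouped)
    agg_true = [sum(grouped[i]["true"]) / len(grouped[i]["true"]) for i in ids]
    agg_pred = [sum(grouped[i]["pred"]) / len(grouped[i]["pred"]) for i in ids]
    return ids, agg_true, agg_pred


def rmse(y_true, y_pred):
    """Root mean squared error."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


# Hypothetical data: two samples, each measured twice.
sample_ids = ["a", "a", "b", "b"]
y_true = [1.0, 1.0, 2.0, 2.0]
y_pred = [0.8, 1.2, 2.5, 2.1]

ids, agg_true, agg_pred = aggregate_by_id(sample_ids, y_true, y_pred)
print("per-measurement RMSE:", round(rmse(y_true, y_pred), 4))
print("per-sample RMSE:", round(rmse(agg_true, agg_pred), 4))
```

Averaging replicates cancels part of the measurement noise, so the per-sample error is typically lower than the per-measurement error, as it is for this toy data.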
Example
>>> fig = analyzer.plot_top_k(k=3, rank_metric='r2')
>>> fig = analyzer.plot_top_k(k=3, aggregate='ID')  # Aggregated by ID
- nirs4all.visualization.plot_branch_diagram(predictions: Any = None, show_metrics: bool = True, metric: str = 'rmse', partition: str = 'test', figsize: Tuple[float, float] | None = None, title: str | None = None, config: Dict[str, Any] | None = None) -> Figure
DEPRECATED: Use plot_pipeline_diagram instead.
- Parameters:
predictions – Predictions object with branch metadata.
show_metrics – Whether to show metrics in nodes.
metric – Metric to display (default: ‘rmse’).
partition – Partition for metrics (default: ‘test’).
figsize – Figure size tuple.
title – Optional title for the diagram.
config – Additional configuration dict.
- Returns:
matplotlib Figure object.
- nirs4all.visualization.plot_pipeline_diagram(pipeline_steps: List[Any] | None = None, predictions: Any = None, show_shapes: bool = True, figsize: Tuple[float, float] | None = None, title: str | None = None, initial_shape: Tuple[int, int, int] | None = None, config: Dict[str, Any] | None = None, execution_trace: Any = None) -> Figure
Convenience function to create a pipeline diagram.
- Parameters:
pipeline_steps – List of pipeline step definitions
predictions – Optional Predictions object with execution data
show_shapes – Whether to show shape info in nodes
figsize – Figure size tuple
title – Optional title for the diagram
initial_shape – Initial dataset shape (samples, processings, features)
config – Additional configuration dict
execution_trace – Optional ExecutionTrace object
- Returns:
matplotlib Figure object
Example
>>> from nirs4all.visualization.pipeline_diagram import plot_pipeline_diagram
>>> fig = plot_pipeline_diagram(pipeline, initial_shape=(189, 1, 2151))
>>> fig.savefig('pipeline_diagram.png')