nirs4all.visualization.charts package


Module contents

Chart classes for prediction visualization.

class nirs4all.visualization.charts.BaseChart(predictions, dataset_name_override: str | None = None, config: ChartConfig | Dict[str, Any] | None = None, analyzer: PredictionAnalyzer | None = None)

Bases: ABC

Abstract base class for all prediction visualization charts.

Provides a common interface and shared functionality for chart implementations. Each chart type should inherit from this class and implement the required methods, render() and validate_inputs().

Charts can be initialized in two modes:

  1. With predictions only (legacy): direct access to the Predictions object.

  2. With an analyzer (recommended): access through a PredictionAnalyzer with caching.

When an analyzer is provided, charts use analyzer.get_cached_predictions() to benefit from caching of expensive aggregation operations.

Designed to be operator-ready for future integration with the controller/operator pattern (see SpectraChartController for the reference pattern).

predictions

Predictions object containing prediction data.

analyzer

Optional PredictionAnalyzer for cached data access.

dataset_name_override

Optional dataset name override for display.

config

ChartConfig instance for customization.

abstractmethod render(**kwargs) → Figure

Render the chart and return matplotlib Figure.

This method must be implemented by all chart subclasses.

Parameters:

**kwargs – Chart-specific rendering parameters.

Returns:

matplotlib Figure object.

Raises:

NotImplementedError – If not implemented by subclass.

abstractmethod validate_inputs(**kwargs) → None

Validate input parameters for the chart.

This method should be called before rendering to ensure all required parameters are present and valid.

Parameters:

**kwargs – Chart-specific parameters to validate.

Raises:

ValueError – If validation fails.
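A minimal, dependency-free sketch of the contract described above: an ABC with abstract render() and validate_inputs(), where render() validates first and reads data through the analyzer's cache when one is given, otherwise from the raw predictions. StubAnalyzer, SketchChart, CountChart and the list-of-dicts data shape are hypothetical stand-ins rather than nirs4all classes, and render() returns a dict instead of a matplotlib Figure so the example runs without plotting libraries.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class StubAnalyzer:
    """Hypothetical stand-in for PredictionAnalyzer: computes once, then serves a cache."""

    def __init__(self, rows: List[Dict[str, Any]]) -> None:
        self._rows = rows
        self._cache: Optional[List[Dict[str, Any]]] = None
        self.calls = 0  # how many times the "expensive" step actually ran

    def get_cached_predictions(self) -> List[Dict[str, Any]]:
        if self._cache is None:
            self.calls += 1
            self._cache = list(self._rows)
        return self._cache


class SketchChart(ABC):
    """Hypothetical base class mirroring the BaseChart contract."""

    def __init__(self, predictions: Optional[List[Dict[str, Any]]] = None,
                 analyzer: Optional[StubAnalyzer] = None) -> None:
        self.predictions = predictions
        self.analyzer = analyzer

    def _data(self) -> List[Dict[str, Any]]:
        # Recommended mode: go through the analyzer's cache.
        if self.analyzer is not None:
            return self.analyzer.get_cached_predictions()
        # Legacy mode: read the predictions directly.
        return self.predictions or []

    @abstractmethod
    def validate_inputs(self, **kwargs: Any) -> None: ...

    @abstractmethod
    def render(self, **kwargs: Any) -> Any: ...


class CountChart(SketchChart):
    """Toy subclass: counts predictions per value of a grouping variable."""

    def validate_inputs(self, **kwargs: Any) -> None:
        if not kwargs.get("variable"):
            raise ValueError("variable must be a non-empty string")

    def render(self, **kwargs: Any) -> Dict[Any, int]:
        self.validate_inputs(**kwargs)  # validate before building anything
        variable = kwargs["variable"]
        counts: Dict[Any, int] = {}
        for row in self._data():
            counts[row[variable]] = counts.get(row[variable], 0) + 1
        return counts  # a real chart would build and return a matplotlib Figure
```

Because both methods are abstract, instantiating SketchChart directly raises TypeError, which is how an ABC enforces that each chart type implements them.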

class nirs4all.visualization.charts.CandlestickChart(predictions, dataset_name_override: str | None = None, config=None, analyzer: PredictionAnalyzer | None = None)

Bases: BaseChart

Candlestick/box plot for score distributions by variable.

Shows score distribution statistics (min, Q25, mean, Q75, max) for each value of a grouping variable.
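The per-group statistics can be sketched in plain Python as follows. The (group, score) input pairs and the use of statistics.quantiles with method="inclusive" for Q25/Q75 are assumptions; the chart's own Polars-based implementation may use a different quantile convention.

```python
import statistics
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def candlestick_stats(rows: Iterable[Tuple[str, float]]) -> Dict[str, Dict[str, float]]:
    """Summarize the scores of each group as min, Q25, mean, Q75 and max."""
    groups: Dict[str, List[float]] = defaultdict(list)
    for key, score in rows:
        groups[key].append(score)
    summary: Dict[str, Dict[str, float]] = {}
    for key, scores in groups.items():
        if len(scores) >= 2:
            q25, _, q75 = statistics.quantiles(scores, n=4, method="inclusive")
        else:
            # A single score is its own quartile (quantiles() rejects one point on older Pythons).
            q25 = q75 = scores[0]
        summary[key] = {
            "min": min(scores),
            "q25": q25,
            "mean": statistics.fmean(scores),
            "q75": q75,
            "max": max(scores),
        }
    return summary
```

Note that the centre line here is the mean, as the description above states, not the median used by a classic box plot.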

render(variable: str, display_metric: str | None = None, display_partition: str = 'test', dataset_name: str | None = None, figsize: tuple | None = None, aggregate: str | None = None, clip_outliers: bool = True, iqr_factor: float = 1.5, **filters) → Figure

Render candlestick chart showing metric distribution by variable (Optimized with Polars).

Parameters:
  • variable – Variable to group by (e.g., ‘model_name’, ‘preprocessings’).

  • display_metric – Metric to analyze (default: auto-detect from task type).

  • display_partition – Partition to display scores from (default: ‘test’).

  • dataset_name – Optional dataset filter.

  • figsize – Figure size tuple (default: from config).

  • aggregate – If provided, aggregate predictions by this metadata column or ‘y’. When ‘y’, groups by y_true values. When a column name (e.g., ‘ID’), groups by that metadata column. Aggregated predictions have recalculated metrics.

  • clip_outliers – If True, constrain the y-axis to show the main distribution and let extreme outliers go off-frame (default: True).

  • iqr_factor – Factor to multiply IQR for determining outlier bounds. Higher values show more of the tails (default: 1.5).

  • **filters – Additional filters (config_name, etc.).

Returns:

matplotlib Figure object.
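clip_outliers and iqr_factor describe Tukey-style fences: axis limits at Q1 - iqr_factor * IQR and Q3 + iqr_factor * IQR, with anything beyond them falling off-frame. The sketch below assumes exactly that and additionally trims the limits to the observed data range; the chart's real bound computation may differ. The same idea applies to the x-axis of ScoreHistogramChart.

```python
import statistics
from typing import Sequence, Tuple


def iqr_axis_limits(scores: Sequence[float], iqr_factor: float = 1.5) -> Tuple[float, float]:
    """Axis limits that keep the bulk of the distribution and drop extreme tails."""
    if len(scores) < 2:
        raise ValueError("need at least two scores")
    q1, _, q3 = statistics.quantiles(scores, n=4, method="inclusive")
    iqr = q3 - q1
    lower = q1 - iqr_factor * iqr
    upper = q3 + iqr_factor * iqr
    # Assumption: never extend the axis beyond the observed data.
    return max(lower, min(scores)), min(upper, max(scores))
```

For scores [1, 2, 3, 4, 5, 100] this yields (1.0, 8.5), so the value 100 is left off-frame; raising iqr_factor to 3.0 widens the upper limit to 12.25. With matplotlib, such a pair would be passed to ax.set_ylim(...) or ax.set_xlim(...).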

validate_inputs(variable: str, display_metric: str | None, **kwargs) → None

Validate candlestick inputs.

Parameters:
  • variable – Variable name to group by.

  • display_metric – Metric name to analyze.

  • **kwargs – Additional parameters (ignored).

Raises:

ValueError – If variable or display_metric is invalid.

class nirs4all.visualization.charts.ChartConfig(colormap: str = 'viridis', heatmap_colormap: str = 'RdYlGn', partition_colors: Dict[str, str] | None = None, font_family: str | None = None, title_fontsize: int = 14, label_fontsize: int = 10, tick_fontsize: int = 9, legend_fontsize: int = 9, annotation_fontsize: int = 9, figsize_small: tuple = (10, 4), figsize_medium: tuple = (12, 8), figsize_large: tuple = (16, 10), dpi: int = 300, alpha: float = 0.7)

Bases: object

Configuration for chart appearance and behavior.

Provides customization options for colors, fonts, and figure sizes. Every parameter has a default, so ChartConfig() can be used without arguments.

colormap

Matplotlib colormap name for gradients (default: ‘viridis’).

Type:

str

heatmap_colormap

Colormap for heatmaps (default: ‘RdYlGn’).

Type:

str

partition_colors

Dict mapping partition names to colors (default: None, in which case __post_init__ fills in default partition colors).

Type:

Dict[str, str] | None

font_family

Font family for all text (default: matplotlib default).

Type:

str | None

title_fontsize

Font size for titles (default: 14).

Type:

int

label_fontsize

Font size for axis labels (default: 10).

Type:

int

tick_fontsize

Font size for tick labels (default: 9).

Type:

int

legend_fontsize

Font size for legend text (default: 9).

Type:

int

annotation_fontsize

Font size for text annotations inside charts (default: 9).

Type:

int

figsize_small

Small figure size (default: (10, 4)).

Type:

tuple

figsize_medium

Medium figure size (default: (12, 8)).

Type:

tuple

figsize_large

Large figure size (default: (16, 10)).

Type:

tuple

dpi

Output resolution (default: 300).

Type:

int

alpha

Default alpha for plot elements (default: 0.7).

Type:

float

__post_init__()

Initialize default partition colors if not provided.

alpha: float = 0.7
annotation_fontsize: int = 9
apply_font_settings() → None

Apply font settings to matplotlib rcParams.

colormap: str = 'viridis'
dpi: int = 300
figsize_large: tuple = (16, 10)
figsize_medium: tuple = (12, 8)
figsize_small: tuple = (10, 4)
font_family: str | None = None
get_figsize(size: str = 'medium') → tuple

Get figure size by name.

Parameters:

size – Size name (‘small’, ‘medium’, ‘large’).

Returns:

Tuple of (width, height).
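A hypothetical mirror of a subset of ChartConfig, written as a dataclass with the defaults copied from the signature above. The partition colors filled in by __post_init__, the ValueError for unknown size names, and the rcParams mapping in font_rcparams() are all assumptions for illustration; nirs4all's real defaults and apply_font_settings() behavior are not documented here.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class ChartConfigSketch:
    """Hypothetical mirror of a subset of ChartConfig fields."""

    font_family: Optional[str] = None
    title_fontsize: int = 14
    label_fontsize: int = 10
    tick_fontsize: int = 9
    legend_fontsize: int = 9
    figsize_small: tuple = (10, 4)
    figsize_medium: tuple = (12, 8)
    figsize_large: tuple = (16, 10)
    partition_colors: Optional[Dict[str, str]] = None

    def __post_init__(self) -> None:
        # Fill in partition colors only when the caller did not supply any.
        # Placeholder values: the library's actual defaults are not documented here.
        if self.partition_colors is None:
            self.partition_colors = {"train": "tab:blue", "val": "tab:orange", "test": "tab:green"}

    def get_figsize(self, size: str = "medium") -> tuple:
        sizes = {"small": self.figsize_small, "medium": self.figsize_medium, "large": self.figsize_large}
        if size not in sizes:
            # Assumption: unknown names raise; the real method may fall back instead.
            raise ValueError(f"unknown size {size!r}; expected one of {sorted(sizes)}")
        return sizes[size]

    def font_rcparams(self) -> Dict[str, object]:
        """rcParams a font-settings step could apply (hypothetical mapping)."""
        params: Dict[str, object] = {
            "axes.titlesize": self.title_fontsize,
            "axes.labelsize": self.label_fontsize,
            "xtick.labelsize": self.tick_fontsize,
            "ytick.labelsize": self.tick_fontsize,
            "legend.fontsize": self.legend_fontsize,
        }
        if self.font_family is not None:
            params["font.family"] = self.font_family
        return params
```

With matplotlib installed, matplotlib.rcParams.update(cfg.font_rcparams()) would apply these settings globally; font_family is only set when given, so matplotlib's default font stays in place otherwise.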

heatmap_colormap: str = 'RdYlGn'
label_fontsize: int = 10
legend_fontsize: int = 9
partition_colors: Dict[str, str] | None = None
tick_fontsize: int = 9
title_fontsize: int = 14
class nirs4all.visualization.charts.ConfusionMatrixChart(predictions, dataset_name_override: str | None = None, config=None, analyzer: PredictionAnalyzer | None = None)

Bases: BaseChart

Confusion matrix visualizations for classification models.

Displays confusion matrices for top K classification models, with proper handling of multi-class predictions.
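A sketch of one common way to build a multi-class confusion matrix (rows are true classes, columns are predicted classes). Converting per-class score vectors to labels by argmax is an assumption about how multi-class predictions may be stored; the function names are illustrative, not part of nirs4all.

```python
from typing import Hashable, List, Sequence, Tuple, Union

Prediction = Union[Hashable, Sequence[float]]


def to_labels(y_pred: Sequence[Prediction]) -> List[Hashable]:
    """Turn per-class score vectors into class indices (argmax); pass plain labels through."""
    labels: List[Hashable] = []
    for p in y_pred:
        if isinstance(p, (list, tuple)):
            labels.append(max(range(len(p)), key=lambda i: p[i]))
        else:
            labels.append(p)
    return labels


def confusion_matrix(y_true: Sequence[Hashable],
                     y_pred: Sequence[Hashable]) -> Tuple[List[Hashable], List[List[int]]]:
    """Count (true, predicted) pairs, with classes in sorted label order."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    labels = sorted(set(y_true) | set(y_pred))
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return labels, matrix
```

Taking the union of true and predicted labels keeps the matrix square even when a class is never predicted, or predicted but never observed.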

render(k: int = 5, rank_metric: str | None = None, rank_partition: str = 'val', display_metric: str | List[str] = '', display_partition: str = 'test', show_scores: bool = True, dataset_name: str | None = None, figsize: tuple | None = None, aggregate: str | None = None, **filters) → Figure | List[Figure]

Plot confusion matrices for top K classification models per dataset.

Models are ranked by the metric on rank_partition, then confusion matrices are displayed using predictions from display_partition. Returns one figure per dataset to avoid mixing predictions from different datasets.
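The rank-on-one-partition, display-another logic, applied separately per dataset, can be sketched as below. The record layout ({"dataset", "model", "scores": {partition: {metric: value}}}) and the higher_is_better flag are hypothetical; the real chart obtains rankings from the Predictions/PredictionAnalyzer machinery.

```python
from collections import defaultdict
from typing import Dict, List, Optional


def top_k_per_dataset(records: List[dict], k: int, rank_metric: str, rank_partition: str = "val",
                      display_partition: str = "test", display_metric: Optional[str] = None,
                      higher_is_better: bool = True) -> Dict[str, List[dict]]:
    """Rank on one partition, report another, never mixing datasets."""
    display_metric = display_metric or rank_metric  # default: same as rank_metric
    by_dataset: Dict[str, List[dict]] = defaultdict(list)
    for rec in records:
        by_dataset[rec["dataset"]].append(rec)
    result: Dict[str, List[dict]] = {}
    for dataset, recs in by_dataset.items():
        ranked = sorted(recs, key=lambda r: r["scores"][rank_partition][rank_metric],
                        reverse=higher_is_better)
        result[dataset] = [
            {
                "model": r["model"],
                "rank_score": r["scores"][rank_partition][rank_metric],
                "display_score": r["scores"][display_partition][display_metric],
            }
            for r in ranked[:k]
        ]
    return result
```

Ranking on 'val' and displaying 'test' keeps the displayed scores out of the selection step, so they are not optimistically biased by it; each dataset's list would become one figure.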

Parameters:
  • k – Number of top models to show per dataset (default: 5).

  • rank_metric – Metric for ranking (default: auto-detect from task type).

  • rank_partition – Partition used for ranking models (default: ‘val’).

  • display_metric – Metric(s) to display in titles. Can be a string for single metric or a list of strings for multiple metrics (e.g., [‘balanced_accuracy’, ‘accuracy’]). Default: same as rank_metric.

  • display_partition – Partition to display confusion matrix from (default: ‘test’).

  • show_scores – If True, show scores in chart titles (default: True).

  • dataset_name – Optional dataset filter. If provided, only shows that dataset.

  • figsize – Figure size tuple (default: from config).

  • aggregate – If provided, aggregate predictions by this metadata column or ‘y’.

  • **filters – Additional filters (e.g., config_name=“config1”).

Returns:

Single Figure if one dataset, List[Figure] if multiple datasets.

validate_inputs(k: int, rank_metric: str | None, **kwargs) → None

Validate confusion matrix inputs.

Parameters:
  • k – Number of top models.

  • rank_metric – Metric name.

  • **kwargs – Additional parameters (ignored).

Raises:

ValueError – If inputs are invalid.

class nirs4all.visualization.charts.HeatmapChart(predictions, dataset_name_override: str | None = None, config=None, analyzer: PredictionAnalyzer | None = None)

Bases: BaseChart

Heatmap visualization of performance across two variables.

Supports flexible ranking and display configurations with multiple aggregation strategies.
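When several predictions fall into the same (x_var, y_var) cell, an aggregation strategy such as rank_agg decides which single score represents the cell. A sketch of the four documented options ('best', 'worst', 'mean', 'median'), assuming rows are plain dicts and that a higher_is_better flag tells 'best' whether to take the maximum (e.g. accuracy) or the minimum (e.g. RMSE):

```python
import statistics
from collections import defaultdict
from typing import Callable, Dict, List, Sequence, Tuple


def aggregate_scores(scores: Sequence[float], how: str, higher_is_better: bool) -> float:
    """Collapse the scores that fall in one heatmap cell into a single value."""
    funcs: Dict[str, Callable[[Sequence[float]], float]] = {
        "best": max if higher_is_better else min,
        "worst": min if higher_is_better else max,
        "mean": statistics.fmean,
        "median": statistics.median,
    }
    if how not in funcs:
        raise ValueError(f"unknown aggregation {how!r}; expected one of {sorted(funcs)}")
    return funcs[how](scores)


def heatmap_cells(rows: List[dict], x_var: str, y_var: str, metric: str,
                  how: str = "best", higher_is_better: bool = False) -> Dict[Tuple[str, str], float]:
    """Group rows by (y_var, x_var) and aggregate the metric per cell."""
    cells: Dict[Tuple[str, str], List[float]] = defaultdict(list)
    for row in rows:
        cells[(row[y_var], row[x_var])].append(row[metric])
    return {key: aggregate_scores(vals, how, higher_is_better) for key, vals in cells.items()}
```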

render(x_var: str, y_var: str, rank_metric: str | None = None, rank_partition: str = 'val', display_metric: str = '', display_partition: str = 'test', figsize: tuple | None = None, normalize: bool = False, rank_agg: str = 'best', display_agg: str = 'mean', show_counts: bool = True, local_scale: bool = False, column_scale: bool = False, aggregate: str | None = None, top_k: int | None = None, sort_by_value: bool = False, sort_by: str | None = None, **filters) → Figure

Render performance heatmap (Optimized with Polars).

Uses vectorized operations for a 20x+ speedup. When aggregate is provided, it uses the slower but accurate aggregation path instead.

Parameters:
  • x_var – Variable for X-axis (columns).

  • y_var – Variable for Y-axis (rows).

  • rank_metric – Metric used for ranking models.

  • rank_partition – Partition used for ranking (‘val’, ‘test’, ‘train’).

  • display_metric – Metric displayed in cells.

  • display_partition – Partition for display metric.

  • figsize – Figure size (auto-computed if None).

  • normalize – Whether to normalize displayed values.

  • rank_agg – Ranking aggregation (‘best’, ‘worst’, ‘mean’, ‘median’).

  • display_agg – Aggregation strategy for the value displayed in each cell (default: ‘mean’).

  • show_counts – Whether to show sample counts.

  • local_scale – If True, use local scale for colors.

  • column_scale – If True, normalize colors per column (best in column = 1.0). Automatically sets local_scale=False when enabled.

  • aggregate – Aggregation column for sample-level aggregation.

  • top_k – If provided, show only the top K models. Selection uses Borda count: it first keeps the top-1 model of each column, then ranks models by Borda count.

  • sort_by_value – If True, sort Y-axis by ranking score (best first) instead of alphabetically. Uses rank_metric on rank_partition. Deprecated: use sort_by=’value’ instead.

  • sort_by – Sorting method for Y-axis (rows). Options:

      - None: alphabetical sorting (default).
      - ‘value’: sort by ranking score on the rank_partition column.
      - ‘mean’: sort by mean score across all columns.
      - ‘median’: sort by median score across all columns.
      - ‘borda’: sort by Borda count (sum of ranks across columns).
      - ‘condorcet’: sort by pairwise wins (Copeland method).
      - ‘consensus’: sort by consensus (geometric mean of normalized ranks).

  • **filters – Additional filters for predictions.

Returns:

Matplotlib Figure with the heatmap.
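The rank-based sort_by options can be sketched on a complete rows × columns score matrix as below. Assumptions: every row has a score in every column, ties within a column are broken by row order, ranks are normalized by dividing by the number of rows, and higher_is_better selects the metric direction. The 'value' and alphabetical options are omitted, and none of this is taken from the nirs4all source.

```python
import statistics
from typing import Dict, List


def column_ranks(column: List[float], higher_is_better: bool) -> List[int]:
    """Rank 1 = best in the column; ties keep row order (a simplifying assumption)."""
    order = sorted(range(len(column)), key=lambda i: column[i], reverse=higher_is_better)
    ranks = [0] * len(column)
    for rank, i in enumerate(order, start=1):
        ranks[i] = rank
    return ranks


def sort_rows(matrix: Dict[str, List[float]], method: str, higher_is_better: bool = False) -> List[str]:
    """Order heatmap rows; assumes every row has a score in every column."""
    names = list(matrix)
    n_cols = len(matrix[names[0]])
    ranks: Dict[str, List[int]] = {name: [] for name in names}
    for c in range(n_cols):
        col_ranks = column_ranks([matrix[n][c] for n in names], higher_is_better)
        for name, r in zip(names, col_ranks):
            ranks[name].append(r)

    if method in ("mean", "median"):
        agg = statistics.fmean if method == "mean" else statistics.median
        score = {n: agg(matrix[n]) for n in names}
        return sorted(names, key=lambda n: score[n], reverse=higher_is_better)
    if method == "borda":
        # Sum of per-column ranks: smaller total is better.
        return sorted(names, key=lambda n: sum(ranks[n]))
    if method == "condorcet":
        # Copeland: +1 per pairwise win, -1 per loss; columns decide each duel.
        copeland = {n: 0 for n in names}
        for i, a in enumerate(names):
            for b in names[i + 1:]:
                a_wins = sum(ra < rb for ra, rb in zip(ranks[a], ranks[b]))
                b_wins = sum(rb < ra for ra, rb in zip(ranks[a], ranks[b]))
                if a_wins != b_wins:
                    winner, loser = (a, b) if a_wins > b_wins else (b, a)
                    copeland[winner] += 1
                    copeland[loser] -= 1
        return sorted(names, key=lambda n: -copeland[n])
    if method == "consensus":
        # Geometric mean of ranks normalized to (0, 1]; smaller is better.
        n_rows = len(names)
        geo = {n: statistics.geometric_mean([r / n_rows for r in ranks[n]]) for n in names}
        return sorted(names, key=lambda n: geo[n])
    raise ValueError(f"unknown sort method {method!r}")
```

Unlike 'mean', the rank-based methods ignore how large the score gaps are, so a single column with an extreme value cannot dominate the ordering.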

validate_inputs(x_var: str, y_var: str, rank_metric: str, **kwargs) → None

Validate heatmap inputs.

class nirs4all.visualization.charts.ScoreHistogramChart(predictions, dataset_name_override: str | None = None, config=None, analyzer: PredictionAnalyzer | None = None)

Bases: BaseChart

Histogram of score distributions.

Displays distribution of a metric across predictions with statistical annotations.

render(display_metric: str | None = None, display_partition: str = 'test', dataset_name: str | None = None, bins: int = 20, figsize: tuple | None = None, aggregate: str | None = None, clip_outliers: bool = True, iqr_factor: float = 1.5, layout: Literal['standard', 'stacked', 'staggered'] = 'standard', **filters) → Figure

Render score distribution histogram (Optimized with Polars).

Parameters:
  • display_metric – Metric to plot (default: auto-detect from task type).

  • display_partition – Partition to display scores from (default: ‘test’).

  • bins – Number of histogram bins (default: 20).

  • figsize – Figure size tuple (default: from config).

  • aggregate – If provided, aggregate predictions by this metadata column or ‘y’. When ‘y’, groups by y_true values. When a column name (e.g., ‘ID’), groups by that metadata column. Aggregated predictions have recalculated metrics.

  • clip_outliers – If True, constrain the x-axis to show the main distribution and let extreme outliers go off-frame (default: True).

  • iqr_factor – Factor to multiply IQR for determining outlier bounds. Higher values show more of the tails (default: 1.5).

  • layout – Histogram layout style:

      - ‘standard’: overlapping histograms (default).
      - ‘stacked’: bars stacked on top of each other.
      - ‘staggered’: bars placed side by side.

  • dataset_name – Optional dataset filter.

  • **filters – Additional filters (model_name, config_name, etc.).

Returns:

matplotlib Figure object.
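The three layout values differ only in where each series' bar sits inside a bin. The sketch below computes that bar geometry, (left, width, bottom, height), without plotting; treating each series as one group of scores (for example one per partition) is an assumption. In matplotlib the same effects are usually obtained by passing several datasets to Axes.hist with stacked=True (stacked) or with the default histtype='bar' (side by side), or by drawing each series separately with transparency (overlapping).

```python
from typing import Dict, List, Sequence, Tuple

Bar = Tuple[float, float, float, int]  # (left, width, bottom, height)


def histogram_counts(values: Sequence[float], edges: Sequence[float]) -> List[int]:
    """Count values per bin; the last bin also includes its right edge."""
    counts = [0] * (len(edges) - 1)
    for v in values:
        for i in range(len(edges) - 1):
            is_last = i == len(edges) - 2
            if edges[i] <= v < edges[i + 1] or (is_last and v == edges[-1]):
                counts[i] += 1
                break
    return counts


def layout_bars(series: Dict[str, Sequence[float]], edges: Sequence[float],
                layout: str = "standard") -> Dict[str, List[Bar]]:
    if layout not in ("standard", "stacked", "staggered"):
        raise ValueError(f"unknown layout {layout!r}")
    names = list(series)
    counts = {name: histogram_counts(series[name], edges) for name in names}
    bars: Dict[str, List[Bar]] = {name: [] for name in names}
    for i in range(len(edges) - 1):
        left, width = edges[i], edges[i + 1] - edges[i]
        bottom = 0.0
        for j, name in enumerate(names):
            height = counts[name][i]
            if layout == "standard":
                # Every series spans the full bin; overlap is handled by drawing with alpha.
                bars[name].append((left, width, 0.0, height))
            elif layout == "stacked":
                # Each series starts where the previous one ended.
                bars[name].append((left, width, bottom, height))
                bottom += height
            else:  # "staggered"
                # Split the bin into equal slots, one per series, side by side.
                slot = width / len(names)
                bars[name].append((left + j * slot, slot, 0.0, height))
    return bars
```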

validate_inputs(display_metric: str | None, **kwargs) → None

Validate histogram inputs.

Parameters:
  • display_metric – Metric name to plot.

  • **kwargs – Additional parameters (ignored).

Raises:

ValueError – If display_metric is invalid.

class nirs4all.visualization.charts.TopKComparisonChart(predictions, dataset_name_override: str | None = None, config=None, analyzer: PredictionAnalyzer | None = None)

Bases: BaseChart

Scatter plots comparing predicted vs observed values for top K models.

Displays predicted vs true scatter plots alongside residual plots for the best performing models according to a ranking metric.
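The quantities behind a predicted-vs-true scatter and its residual plot can be sketched as follows. The residual sign convention (predicted minus observed) and the choice of RMSE and R² as summary scores are assumptions, not taken from nirs4all.

```python
import math
from typing import Dict, List, Sequence


def residual_summary(y_true: Sequence[float], y_pred: Sequence[float]) -> Dict[str, object]:
    """Residuals plus RMSE and R² for one model's predictions."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and the same length")
    residuals: List[float] = [p - t for t, p in zip(y_true, y_pred)]  # predicted minus observed
    ss_res = sum(r * r for r in residuals)
    mean_true = sum(y_true) / len(y_true)
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return {
        "residuals": residuals,
        "rmse": math.sqrt(ss_res / len(residuals)),
        # R² is undefined when all observed values are identical.
        "r2": 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan"),
    }
```

A residual plot puts y_true (or y_pred) on the x-axis and these residuals on the y-axis; a structure-free band around zero suggests no systematic bias.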

render(k: int = 5, rank_metric: str | None = None, rank_partition: str = 'val', display_metric: str = '', display_partition: str = 'all', show_scores: bool = True, dataset_name: str | None = None, figsize: tuple | None = None, aggregate: str | None = None, **filters) → Figure

Plot top K models with predicted vs true and residuals.

Uses the top() method to rank models by a metric on rank_partition, then displays predictions from display_partition(s).

Parameters:
  • k – Number of top models to show (default: 5).

  • rank_metric – Metric for ranking models (default: auto-detect from task type).

  • rank_partition – Partition used for ranking (default: ‘val’).

  • display_metric – Metric to display in titles (default: same as rank_metric).

  • display_partition – Partition(s) to display: ‘all’ (default) for train/val/test, or one of ‘test’, ‘val’, ‘train’.

  • show_scores – If True, show scores in chart titles (default: True).

  • dataset_name – Optional dataset filter.

  • figsize – Figure size tuple (default: from config).

  • aggregate – If provided, aggregate predictions by this metadata column or ‘y’.

  • **filters – Additional filters.

Returns:

matplotlib Figure object.
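The aggregate option groups predictions (for example repeated scans of the same sample) before metrics are recalculated. A sketch assuming mean aggregation within each group, with aggregate='y' grouping on the y_true value and any other string read as a metadata column such as ‘ID’; the per-prediction metadata dicts and the averaging rule are assumptions.

```python
import math
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Tuple


def aggregate_predictions(y_true: Sequence[float], y_pred: Sequence[float], by: str,
                          metadata: Optional[List[Dict[str, object]]] = None) -> Tuple[List[float], List[float]]:
    """Average y_true and y_pred within each group (one metadata dict per prediction)."""
    if by == "y":
        keys: List[object] = list(y_true)
    else:
        if metadata is None:
            raise ValueError("metadata is required when grouping by a column")
        keys = [row[by] for row in metadata]
    groups: Dict[object, Tuple[List[float], List[float]]] = defaultdict(lambda: ([], []))
    for key, t, p in zip(keys, y_true, y_pred):
        groups[key][0].append(t)
        groups[key][1].append(p)
    agg_true = [sum(ts) / len(ts) for ts, _ in groups.values()]
    agg_pred = [sum(ps) / len(ps) for _, ps in groups.values()]
    return agg_true, agg_pred


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    return math.sqrt(sum((p - t) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))
```

With three scans per sample whose predictions scatter symmetrically around the reference value, the per-scan RMSE is about 0.82 while the aggregated RMSE is 0.0, which is why aggregated predictions need recalculated metrics rather than reused ones.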

validate_inputs(k: int, rank_metric: str | None, **kwargs) → None

Validate top K comparison inputs.

Parameters:
  • k – Number of top models.

  • rank_metric – Metric name for ranking.

  • **kwargs – Additional parameters (ignored).

Raises:

ValueError – If inputs are invalid.