nirs4all.visualization.charts.confusion_matrix module

ConfusionMatrixChart - Confusion matrix visualizations for classification models.

class nirs4all.visualization.charts.confusion_matrix.ConfusionMatrixChart(predictions, dataset_name_override: str | None = None, config=None, analyzer: PredictionAnalyzer | None = None)

Bases: BaseChart

Confusion matrix visualizations for classification models.

Displays confusion matrices for the top K classification models, including models with multi-class predictions.

render(k: int = 5, rank_metric: str | None = None, rank_partition: str = 'val', display_metric: str | List[str] = '', display_partition: str = 'test', show_scores: bool = True, dataset_name: str | None = None, figsize: tuple | None = None, aggregate: str | None = None, **filters) → Figure | List[Figure]

Plot confusion matrices for top K classification models per dataset.

Models are ranked by rank_metric as computed on rank_partition; the confusion matrices are then built from the predictions on display_partition. One figure is returned per dataset so that predictions from different datasets are never mixed.
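The rank-then-display flow described above can be sketched in plain Python. This is only an illustration, not the nirs4all implementation: the record layout, the helper names (top_k_models, confusion_matrix, matrices_for_top_k) and the "higher score is better" ordering are all assumptions made for the example.

```python
from collections import Counter

# Hypothetical prediction records, one per (model, partition).
# This layout is an assumption for the sketch, not the nirs4all format.
records = [
    {"model": "A", "partition": "val", "score": 0.90},
    {"model": "B", "partition": "val", "score": 0.70},
    {"model": "C", "partition": "val", "score": 0.85},
    {"model": "A", "partition": "test", "y_true": [0, 0, 1, 1], "y_pred": [0, 1, 1, 1]},
    {"model": "B", "partition": "test", "y_true": [0, 0, 1, 1], "y_pred": [1, 1, 0, 0]},
    {"model": "C", "partition": "test", "y_true": [0, 1, 1, 0], "y_pred": [0, 1, 0, 0]},
]


def top_k_models(records, k, rank_partition="val"):
    """Return the names of the k best models on the ranking partition.

    Assumes a higher score is better.
    """
    ranked = sorted(
        (r for r in records if r["partition"] == rank_partition),
        key=lambda r: r["score"],
        reverse=True,
    )
    return [r["model"] for r in ranked[:k]]


def confusion_matrix(y_true, y_pred, labels):
    """Rows are true labels, columns are predicted labels."""
    counts = Counter(zip(y_true, y_pred))
    return [[counts[(t, p)] for p in labels] for t in labels]


def matrices_for_top_k(records, k, rank_partition="val", display_partition="test"):
    """Rank on one partition, build confusion matrices from another."""
    by_key = {(r["model"], r["partition"]): r for r in records}
    result = {}
    for name in top_k_models(records, k, rank_partition):
        shown = by_key[(name, display_partition)]
        labels = sorted(set(shown["y_true"]) | set(shown["y_pred"]))
        result[name] = confusion_matrix(shown["y_true"], shown["y_pred"], labels)
    return result


matrices = matrices_for_top_k(records, k=2)
```

Ranking on a validation partition and displaying a test partition keeps model selection separate from the data used to report results.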

Parameters:
  • k – Number of top models to show per dataset (default: 5).

  • rank_metric – Metric for ranking (default: auto-detect from task type).

  • rank_partition – Partition used for ranking models (default: 'val').

  • display_metric – Metric(s) to display in titles. Either a string for a single metric or a list of strings for several metrics (e.g., ['balanced_accuracy', 'accuracy']). Default: '' (empty string), which means the same as rank_metric.

  • display_partition – Partition to display confusion matrix from (default: 'test').

  • show_scores – If True, show scores in chart titles (default: True).

  • dataset_name – Optional dataset filter. If provided, only shows that dataset.

  • figsize – Figure size tuple (default: from config).

  • aggregate – If provided, aggregate predictions by this metadata column or 'y'.

  • **filters – Additional filters (e.g., config_name="config1").
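Because display_metric accepts either a string or a list and falls back to rank_metric when left empty, it can help to see that resolution spelled out. The helper below, resolve_display_metrics, is a hypothetical name for this sketch and is not part of nirs4all; it only mirrors the behaviour described in the parameter list.

```python
from typing import List, Optional, Union


def resolve_display_metrics(
    display_metric: Union[str, List[str]],
    rank_metric: Optional[str],
) -> List[str]:
    """Normalise display_metric to a list, falling back to rank_metric.

    Hypothetical helper: an assumption about how the documented
    default could be resolved, not the library's own code.
    """
    if isinstance(display_metric, str):
        # An empty string means "no explicit display metric".
        metrics = [display_metric] if display_metric else []
    else:
        metrics = list(display_metric)
    if not metrics and rank_metric:
        metrics = [rank_metric]
    return metrics
```

Normalising to a list up front lets the title-building code treat the single-metric and multi-metric cases the same way.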

Returns:

A single Figure if there is one dataset; a List[Figure] if there are several.
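Since the return value is either one figure or a list of figures, calling code often benefits from normalising it before iterating. A minimal sketch, assuming nothing about the figure objects themselves; as_figure_list is a hypothetical helper name, and plain object() instances stand in for real figures here.

```python
def as_figure_list(result):
    """Wrap a single figure in a list; pass lists through as a new list.

    Hypothetical helper, not part of nirs4all.
    """
    if isinstance(result, (list, tuple)):
        return list(result)
    return [result]


# Stand-ins for what render() might return in each case.
single = object()
several = [object(), object()]

for fig in as_figure_list(single):
    pass  # e.g. save or show each figure here

n_single = len(as_figure_list(single))
n_several = len(as_figure_list(several))
```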

validate_inputs(k: int, rank_metric: str | None, **kwargs) → None

Validate confusion matrix inputs.

Parameters:
  • k – Number of top models.

  • rank_metric – Metric name.

  • **kwargs – Additional parameters (ignored).

Raises:

ValueError – If inputs are invalid.
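The documentation does not state which inputs count as invalid. The sketch below shows one plausible set of checks with the same signature shape; the function name check_inputs and both rules (k must be a positive integer, rank_metric must be a non-empty string when given) are assumptions for illustration, not the library's actual validation.

```python
from typing import Optional


def check_inputs(k: int, rank_metric: Optional[str], **kwargs) -> None:
    """Raise ValueError on implausible inputs; extra kwargs are ignored.

    The specific rules here are assumed, not taken from nirs4all.
    """
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(k, bool) or not isinstance(k, int) or k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")
    if rank_metric is not None and (
        not isinstance(rank_metric, str) or not rank_metric
    ):
        raise ValueError(
            f"rank_metric must be a non-empty string or None, got {rank_metric!r}"
        )


def is_valid(k, rank_metric=None, **kwargs) -> bool:
    """Convenience wrapper that reports validity instead of raising."""
    try:
        check_inputs(k, rank_metric, **kwargs)
    except ValueError:
        return False
    return True
```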