# Source code for nirs4all.visualization.chart_utils.annotator

"""
ChartAnnotator - Helper for adding annotations to charts.
"""
import numpy as np
from typing import Optional, List, Dict, Any, Union
from nirs4all.visualization.charts.config import ChartConfig


class ChartAnnotator:
    """Helper for adding annotations to charts.

    Centralizes text formatting, positioning, and color selection for chart
    annotations. Uses ChartConfig for styling.

    Attributes:
        config: ChartConfig instance for customization.
    """

    # Axes-fraction anchor coordinates for the statistics box, keyed by the
    # position strings accepted by ``add_statistics_box``.
    _STATS_BOX_POSITIONS = {
        'upper right': (0.98, 0.98),
        'upper left': (0.02, 0.98),
        'lower right': (0.98, 0.02),
        'lower left': (0.02, 0.02),
    }

    def __init__(self, config: Optional[Union[ChartConfig, Dict[str, Any]]] = None):
        """Initialize annotator with config.

        Args:
            config: Optional ChartConfig or dict for customization. A dict is
                unpacked as keyword arguments to build a ChartConfig; None
                yields a default ``ChartConfig()``; anything else is used as-is.
        """
        if isinstance(config, dict):
            self.config = ChartConfig(**config)
        elif config is None:
            self.config = ChartConfig()
        else:
            # Explicit None check (rather than `config or ...`) so a supplied
            # config object is never silently replaced based on truthiness.
            self.config = config

    def add_heatmap_annotations(
        self,
        ax,
        matrix: np.ndarray,
        normalized_matrix: np.ndarray,
        count_matrix: Optional[np.ndarray],
        x_labels: List,
        y_labels: List,
        show_counts: bool = True,
        precision: int = 3
    ) -> None:
        """Add text annotations to heatmap cells.

        Cells whose score is NaN are left unannotated. Each remaining cell gets
        its score formatted with ``precision`` decimals, plus a ``(n=...)``
        line when ``show_counts`` is set and the cell's count is greater than 1.

        Args:
            ax: Matplotlib axes object.
            matrix: Original score matrix, indexed as ``matrix[row, col]``.
            normalized_matrix: Normalized matrix for color selection.
            count_matrix: Matrix of sample counts. May be None, in which case
                counts are not shown.
            x_labels: List of x-axis labels (its length bounds the columns).
            y_labels: List of y-axis labels (its length bounds the rows).
            show_counts: Whether to show sample counts.
            precision: Number of decimal places for scores.
        """
        # Without a count matrix there is nothing to show; avoid indexing None.
        show_counts = show_counts and count_matrix is not None
        fontsize = self.config.annotation_fontsize

        for i in range(len(y_labels)):
            for j in range(len(x_labels)):
                value = matrix[i, j]
                if np.isnan(value):
                    continue  # Missing score: leave the cell blank.

                text_color = self.get_text_color(normalized_matrix[i, j])
                score_text = f'{value:.{precision}f}'

                if show_counts:
                    count = count_matrix[i, j]
                    # Only aggregated cells (more than one sample) get a count
                    # line; a NaN count compares False and is skipped.
                    if count > 1:
                        score_text += f'\n(n={int(count)})'

                ax.text(j, i, score_text, ha='center', va='center',
                        color=text_color, fontsize=fontsize)

    @staticmethod
    def get_text_color(background_value: float, threshold: float = 0.5) -> str:
        """Determine text color for an annotation drawn over a background.

        Args:
            background_value: Normalized background value (0-1). Currently
                ignored.
            threshold: Intended switch point between white and black text.
                Currently ignored.

        Returns:
            Color string; always 'black' for consistency.
        """
        # Always return black for consistent, professional appearance.
        return 'black'

    def add_statistics_box(
        self,
        ax,
        values: List[float],
        position: str = 'upper right',
        precision: int = 4
    ) -> None:
        """Add a statistics text box (n, mean, std, min, max) to a plot.

        Does nothing when ``values`` is None or empty.

        Args:
            ax: Matplotlib axes object.
            values: Values to compute statistics from. Any array-like is
                accepted; it is flattened before computing statistics.
            position: One of 'upper right', 'upper left', 'lower right' or
                'lower left'. Unknown strings fall back to 'upper right'.
            precision: Number of decimal places.
        """
        if values is None:
            return
        # `not values` raises ValueError for multi-element numpy arrays, so
        # normalize to a flat float array and test its size instead.
        data = np.asarray(values, dtype=float).ravel()
        if data.size == 0:
            return

        stats_text = (
            f'n={data.size}\n'
            f'μ={float(np.mean(data)):.{precision}f}\n'
            f'σ={float(np.std(data)):.{precision}f}\n'
            f'min={float(np.min(data)):.{precision}f}\n'
            f'max={float(np.max(data)):.{precision}f}'
        )

        x, y = self._STATS_BOX_POSITIONS.get(position, (0.98, 0.98))
        # Anchor the box's corner nearest to the chosen axes corner.
        ha = 'right' if x > 0.5 else 'left'
        va = 'top' if y > 0.5 else 'bottom'

        ax.text(x, y, stats_text, transform=ax.transAxes,
                verticalalignment=va, horizontalalignment=ha,
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
                fontsize=self.config.annotation_fontsize)