# Source code for nirs4all.controllers.charts.augmentation

"""AugmentationChartController - Visualizes augmentation effects on spectra."""

from typing import Any, Dict, List, Tuple, TYPE_CHECKING
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import numpy as np
import io
from nirs4all.controllers.controller import OperatorController
from nirs4all.controllers.registry import register_controller
from nirs4all.core.logging import get_logger
from nirs4all.utils.header_units import get_x_values_and_label, apply_x_axis_limits

# Module-level logger; per-source sample counts are emitted through it at
# debug level by the chart controller below.
logger = get_logger(__name__)

# These classes are referenced only in string type annotations, so they are
# imported for static type checkers only (presumably to avoid runtime import
# cycles with the pipeline/data packages — confirm before removing the guard).
if TYPE_CHECKING:
    from nirs4all.data.dataset import SpectroDataset
    from nirs4all.pipeline.config.context import ExecutionContext
    from nirs4all.pipeline.steps.parser import ParsedStep


@register_controller
class AugmentationChartController(OperatorController):
    """
    Controller for visualizing augmentation effects on spectra.

    Supports two visualization modes:
    1. augment_chart: Shows original vs augmented samples overlaid with different colors
    2. augment_details_chart: Shows a grid with raw data and each augmentation type separately

    When the step is a dict, options are read from the value under the step
    keyword: ``alpha_original`` (default 0.8), ``alpha_augmented`` (default 0.4)
    and ``max_samples`` (default 50, cap on plotted curves per group).
    """

    priority = 10

    # Keywords selecting the overlay chart vs. the per-transformer details grid.
    _OVERLAY_KEYWORDS = ("augment_chart", "augmentation_chart")
    _DETAILS_KEYWORDS = ("augment_details_chart", "augmentation_details_chart")

    # Fixed seed so the random subset of plotted curves is reproducible.
    _SUBSET_SEED = 42

    # Ordered (long, short) substitutions applied to processing names in titles.
    # Order matters: "raw_" is stripped before the other abbreviations.
    _PROCESSING_ABBREVIATIONS = (
        ("raw_", ""),
        ("SavitzkyGolay", "SG"),
        ("MultiplicativeScatterCorrection", "MSC"),
        ("StandardNormalVariate", "SNV"),
        ("FirstDerivative", "1stDer"),
        ("SecondDerivative", "2ndDer"),
        ("Detrend", "Detr"),
        ("Gaussian", "Gauss"),
        ("LogTransform", "Log"),
        ("MinMaxScaler", "MinMax"),
        ("RobustScaler", "Rbt"),
        ("StandardScaler", "Std"),
        ("QuantileTransformer", "Quant"),
        ("PowerTransformer", "Pow"),
    )

    @classmethod
    def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
        """Return True for any of the overlay or details augmentation-chart keywords."""
        return keyword in cls._OVERLAY_KEYWORDS + cls._DETAILS_KEYWORDS

    @classmethod
    def use_multi_source(cls) -> bool:
        return True

    @classmethod
    def supports_prediction_mode(cls) -> bool:
        """Chart controllers should skip execution during prediction mode."""
        return False

    def execute(
        self,
        step_info: 'ParsedStep',
        dataset: 'SpectroDataset',
        context: 'ExecutionContext',
        runtime_context: Any,
        source: int = -1,
        mode: str = "train",
        loaded_binaries: Any = None,
        prediction_store: Any = None,
    ) -> Tuple['ExecutionContext', Any]:
        """
        Execute augmentation visualization.

        Reads the train partition as 3D spectra and renders one PNG per feature
        source: an overlay chart for ``augment_chart``/``augmentation_chart``,
        or a per-transformer grid for the ``*_details_chart`` keywords.

        Returns:
            Tuple of (context, StepOutput). ``context`` is returned unchanged;
            ``StepOutput.outputs`` is a list of ``(png_bytes, image_name, "png")``
            tuples. In ``predict``/``explain`` mode an empty StepOutput is returned.
        """
        from nirs4all.pipeline.execution.result import StepOutput

        # Charts are only produced while training; skip prediction/explanation.
        if mode in ("predict", "explain"):
            return context, StepOutput()

        step = step_info.original_step
        keyword = context.metadata.keyword
        is_details = keyword in self._DETAILS_KEYWORDS

        config = self._read_config(step, keyword)
        alpha_original = config.get("alpha_original", 0.8)
        alpha_augmented = config.get("alpha_augmented", 0.4)
        max_samples = config.get("max_samples", 50)  # Limit samples for readability

        train_context = context.with_partition("train")
        selector = train_context.selector

        # One 3D array per source (a single array is wrapped into a list).
        spectra_data = dataset.x(selector, "3d", False)
        if not isinstance(spectra_data, list):
            spectra_data = [spectra_data]

        # The sample indices depend only on the selector, not on the source,
        # so they are computed once instead of once per source.
        base_indices = dataset._indexer.x_indices(selector, include_augmented=False)
        all_indices = dataset._indexer.x_indices(selector, include_augmented=True)
        base_set = set(self._as_list(base_indices))
        augmented_indices = [idx for idx in self._as_list(all_indices) if idx not in base_set]

        draw_chart = self._create_details_chart if is_details else self._create_overlay_chart
        base_image_name = "Augmentation_Details_Chart" if is_details else "Augmentation_Chart"
        step_runner = runtime_context.step_runner

        img_list = []
        for sd_idx, x in enumerate(spectra_data):
            if step_runner.verbose > 0:
                logger.debug(
                    f" Source {sd_idx}: {len(base_indices)} base samples, "
                    f"{len(augmented_indices)} augmented samples"
                )

            processing_ids = dataset.features_processings(sd_idx)
            fig = draw_chart(
                x, base_indices, augmented_indices, all_indices, processing_ids,
                dataset, sd_idx, alpha_original, alpha_augmented, max_samples,
            )
            image_name = base_image_name
            if dataset.is_multi_source():
                image_name += f"_src{sd_idx}"

            # Render to an in-memory PNG; close the figure if rendering fails so
            # it does not leak into pyplot's global figure registry.
            try:
                with io.BytesIO() as img_buffer:
                    fig.savefig(img_buffer, format='png', dpi=300, bbox_inches='tight')
                    img_png_binary = img_buffer.getvalue()
            except Exception:
                plt.close(fig)
                raise
            img_list.append((img_png_binary, image_name, "png"))

            if step_runner.plots_visible:
                # Keep a reference so the figure survives until it is shown.
                step_runner._figure_refs.append(fig)
                plt.show()
            else:
                plt.close(fig)

        return context, StepOutput(outputs=img_list)

    def _create_overlay_chart(
        self,
        x: np.ndarray,
        base_indices: np.ndarray,
        augmented_indices: List[int],
        all_indices: np.ndarray,
        processing_ids: List[str],
        dataset: 'SpectroDataset',
        source_idx: int,
        alpha_original: float,
        alpha_augmented: float,
        max_samples: int,
    ) -> Figure:
        """
        Create overlay chart showing original (blue) and augmented (orange) samples.

        One subplot per processing ``x[:, p, :]``, at most three per row. Rows of
        ``x`` are assumed to be aligned with ``all_indices``.
        """
        from matplotlib.lines import Line2D

        n_processings = x.shape[1]
        n_rows, n_cols = self._grid_shape(n_processings)
        fig = plt.figure(figsize=(6 * n_cols, 5 * n_rows))
        fig.suptitle(
            self._chart_title(dataset, "Augmentation Overlay", source_idx),
            fontsize=16, fontweight='bold', y=0.98,
        )

        idx_to_pos = self._position_map(all_indices)

        # Limit samples for readability (reproducible, without touching the global RNG).
        base_subset = self._subsample(self._as_list(base_indices), max_samples, self._SUBSET_SEED)
        aug_subset = self._subsample(list(augmented_indices), max_samples, self._SUBSET_SEED)

        # Every processing shares the same feature axis, so resolve it once,
        # using the same helpers as the details chart.
        x_values, x_label = self._resolve_x_axis(dataset, source_idx, x.shape[2])

        for processing_idx in range(n_processings):
            ax = fig.add_subplot(n_rows, n_cols, processing_idx + 1)
            x_2d = x[:, processing_idx, :]

            self._plot_samples(ax, x_values, x_2d, base_subset, idx_to_pos,
                               color='steelblue', alpha=alpha_original, linewidth=0.8)
            self._plot_samples(ax, x_values, x_2d, aug_subset, idx_to_pos,
                               color='darkorange', alpha=alpha_augmented, linewidth=0.8)

            apply_x_axis_limits(ax, x_values)
            ax.set_xlabel(x_label, fontsize=9)
            ax.set_ylabel('Intensity', fontsize=9)
            ax.set_title(self._shorten_processing_name(processing_ids[processing_idx]), fontsize=10)

            # Legend only on the first subplot to avoid clutter.
            if processing_idx == 0:
                legend_elements = [
                    Line2D([0], [0], color='steelblue', linewidth=2,
                           label=f'Original ({len(base_subset)})'),
                    Line2D([0], [0], color='darkorange', linewidth=2,
                           label=f'Augmented ({len(aug_subset)})'),
                ]
                ax.legend(handles=legend_elements, loc='upper right', fontsize=8)

        plt.tight_layout(rect=(0, 0, 1, 0.92), h_pad=4.0)  # type: ignore[arg-type]
        return fig

    def _create_details_chart(
        self,
        x: np.ndarray,
        base_indices: np.ndarray,
        augmented_indices: List[int],
        all_indices: np.ndarray,
        processing_ids: List[str],
        dataset: 'SpectroDataset',
        source_idx: int,
        alpha_original: float,
        alpha_augmented: float,
        max_samples: int,
    ) -> Figure:
        """
        Create details chart showing raw on top-left, then each augmented transformation separately.

        Only the first processing ``x[:, 0, :]`` is drawn. Augmented samples are
        grouped by ``_group_augmented_by_transformer``; each group gets its own
        subplot with the (subsampled) originals drawn in light gray behind it.
        """
        transformer_groups = self._group_augmented_by_transformer(dataset, augmented_indices)
        n_transformers = len(transformer_groups)

        # Raw panel + one panel per transformer.
        n_rows, n_cols = self._grid_shape(1 + n_transformers)
        fig = plt.figure(figsize=(6 * n_cols, 5 * n_rows))
        fig.suptitle(
            self._chart_title(dataset, "Augmentation Details", source_idx),
            fontsize=16, fontweight='bold', y=0.98,
        )

        idx_to_pos = self._position_map(all_indices)
        x_2d = x[:, 0, :]  # First processing
        x_values, x_label = self._resolve_x_axis(dataset, source_idx, x_2d.shape[1])
        base_subset = self._subsample(self._as_list(base_indices), max_samples, self._SUBSET_SEED)

        # Panel 1: original samples only.
        ax1 = fig.add_subplot(n_rows, n_cols, 1)
        self._plot_samples(ax1, x_values, x_2d, base_subset, idx_to_pos,
                           color='steelblue', alpha=alpha_original, linewidth=0.8)
        apply_x_axis_limits(ax1, x_values)
        ax1.set_xlabel(x_label, fontsize=9)
        ax1.set_ylabel('Intensity', fontsize=9)
        processing_name = self._shorten_processing_name(processing_ids[0])
        ax1.set_title(f"Original ({len(base_subset)} samples) - {processing_name}", fontsize=10)

        # One qualitative color per transformer.
        colors = plt.colormaps['Set2'](np.linspace(0, 1, max(n_transformers, 1)))

        for t_idx, (transformer_name, aug_indices) in enumerate(transformer_groups.items()):
            ax = fig.add_subplot(n_rows, n_cols, t_idx + 2)

            # Originals as a faint background reference.
            self._plot_samples(ax, x_values, x_2d, base_subset, idx_to_pos,
                               color='lightgray', alpha=0.3, linewidth=0.5)

            # Per-group seed keeps each panel's subset independent yet reproducible.
            aug_subset = self._subsample(aug_indices, max_samples, self._SUBSET_SEED + t_idx)
            self._plot_samples(ax, x_values, x_2d, aug_subset, idx_to_pos,
                               color=colors[t_idx], alpha=alpha_augmented, linewidth=0.8)

            apply_x_axis_limits(ax, x_values)
            ax.set_xlabel(x_label, fontsize=9)
            ax.set_ylabel('Intensity', fontsize=9)
            ax.set_title(f"{transformer_name} ({len(aug_indices)} samples)", fontsize=10)

        plt.tight_layout(rect=(0, 0, 1, 0.95), h_pad=3.0)  # type: ignore[arg-type]
        return fig

    def _group_augmented_by_transformer(
        self,
        dataset: 'SpectroDataset',
        augmented_indices: List[int],
    ) -> Dict[str, List[int]]:
        """
        Group augmented samples by their transformer type.

        Uses the 'augmentation' column in the indexer DataFrame (rows whose
        'sample' is in ``augmented_indices``) to identify transformer types.
        Returns a dict: {transformer_name: [sample_indices]}. If no matching rows
        are found, every augmented index is put in a single "Augmented" group.
        """
        import polars as pl

        groups: Dict[str, List[int]] = {}
        if len(augmented_indices) == 0:
            return groups

        df = dataset._indexer.df  # noqa: SLF001
        aug_df = df.filter(pl.col("sample").is_in(list(augmented_indices)))

        for row in aug_df.iter_rows(named=True):
            label = self._transformer_label(row.get("augmentation", None))
            groups.setdefault(label, []).append(row["sample"])

        if not groups:
            groups["Augmented"] = list(augmented_indices)
        return groups

    @staticmethod
    def _transformer_label(aug_type: Any) -> str:
        """Map a raw 'augmentation' value to a short display name for chart titles."""
        if aug_type is None:
            return "Augmented"
        name = str(aug_type)
        if "Rotate_Translate" in name:
            return "Rotate_Translate"
        if "Spline" in name:
            spline_variants = (
                ("Y_Perturbations", "Spline_Y"),
                ("X_Perturbations", "Spline_X"),
                ("Curve", "Spline_Curve"),
                ("Simplification", "Spline_Simplify"),
            )
            for marker, label in spline_variants:
                if marker in name:
                    return label
            return "Spline"
        if "Random" in name:
            return "Random_Op"
        return name

    @staticmethod
    def _read_config(step: Any, keyword: str) -> Dict[str, Any]:
        """Return the option dict stored under ``keyword`` in a dict step, else ``{}``.

        A missing key, ``None``, or any other non-dict value yields ``{}``
        rather than failing on ``.get``.
        """
        if not isinstance(step, dict):
            return {}
        config = step.get(keyword)
        return config if isinstance(config, dict) else {}

    @staticmethod
    def _as_list(indices: Any) -> List[int]:
        """Return ``indices`` (NumPy array or any iterable) as a plain Python list."""
        return indices.tolist() if hasattr(indices, 'tolist') else list(indices)

    @classmethod
    def _position_map(cls, all_indices: Any) -> Dict[int, int]:
        """Map each sample index to its row position in the spectra array."""
        return {idx: pos for pos, idx in enumerate(cls._as_list(all_indices))}

    @staticmethod
    def _subsample(indices: List[int], max_samples: int, seed: int) -> List[int]:
        """Return at most ``max_samples`` of ``indices``, chosen reproducibly.

        A private ``RandomState`` is used so the global NumPy RNG is never
        reseeded; for a given seed it draws the same subset that
        ``np.random.seed(seed); np.random.choice(...)`` would.
        """
        if len(indices) <= max_samples:
            return indices
        rng = np.random.RandomState(seed)
        return rng.choice(indices, max_samples, replace=False).tolist()

    @staticmethod
    def _grid_shape(n_panels: int, max_cols: int = 3) -> Tuple[int, int]:
        """Return ``(n_rows, n_cols)`` fitting ``n_panels`` subplots, never smaller than 1x1."""
        n_cols = max(1, min(max_cols, n_panels))
        n_rows = max(1, (n_panels + n_cols - 1) // n_cols)
        return n_rows, n_cols

    @staticmethod
    def _chart_title(dataset: 'SpectroDataset', suffix: str, source_idx: int) -> str:
        """Build the figure title, tagging the source index for multi-source datasets."""
        title = f"{dataset.name} - {suffix}"
        if dataset.is_multi_source():
            title += f" (Source {source_idx})"
        return title

    @staticmethod
    def _resolve_x_axis(dataset: 'SpectroDataset', source_idx: int, n_features: int) -> Tuple[Any, str]:
        """Return ``(x_values, x_label)`` for a source via the shared header-unit helper.

        Falls back to the "cm-1" unit when the dataset cannot report one.
        """
        spectra_headers = dataset.headers(source_idx)
        try:
            header_unit = dataset.header_unit(source_idx)
        except (AttributeError, IndexError):
            header_unit = "cm-1"
        return get_x_values_and_label(spectra_headers, header_unit, n_features)

    @staticmethod
    def _plot_samples(ax: Any, x_values: Any, x_2d: np.ndarray, indices: List[int],
                      idx_to_pos: Dict[int, int], **line_kwargs: Any) -> None:
        """Plot one curve per index found in ``idx_to_pos``; unknown indices are skipped."""
        for idx in indices:
            pos = idx_to_pos.get(idx)
            if pos is not None:
                ax.plot(x_values, x_2d[pos], **line_kwargs)

    @classmethod
    def _shorten_processing_name(cls, processing_name: str) -> str:
        """Shorten preprocessing names for chart titles.

        Applies the abbreviations in ``_PROCESSING_ABBREVIATIONS`` in order, then
        turns ``_<digits>_`` separators into ``>`` and drops remaining
        ``_<digits>`` suffixes.
        """
        import re

        for long, short in cls._PROCESSING_ABBREVIATIONS:
            processing_name = processing_name.replace(long, short)
        processing_name = re.sub(r'_\d+_', '>', processing_name)
        processing_name = re.sub(r'_\d+', '', processing_name)
        return processing_name