"""
Centralized header unit utilities for consistent handling across the codebase.
This module provides a single source of truth for:
- Axis labels based on header unit type
- X-values extraction from headers
- Axis orientation (for wavenumber inversion)
All visualization code should use these utilities instead of inline logic.
"""
from typing import List, Optional, Tuple, Union
import numpy as np
from nirs4all.data._features import HeaderUnit, normalize_header_unit
# Canonical axis labels - single source of truth.
# Maps each HeaderUnit member to the human-readable x-axis title; look
# labels up here rather than hard-coding the strings at call sites.
# Note that NONE and INDEX share the same "Feature Index" label.
AXIS_LABELS = {
    HeaderUnit.WAVENUMBER: "Wavenumber (cm⁻¹)",
    HeaderUnit.WAVELENGTH: "Wavelength (nm)",
    HeaderUnit.NONE: "Feature Index",
    HeaderUnit.TEXT: "Features",
    HeaderUnit.INDEX: "Feature Index",
}
# Default label when the unit is unknown or invalid, or when the headers
# cannot be used as axis values; the lookup helpers fall back to it.
DEFAULT_AXIS_LABEL: str = "Features"
[docs]
def get_axis_label(unit: Union[str, HeaderUnit]) -> str:
    """Return the human-readable x-axis label for a header unit.

    String units are passed through ``normalize_header_unit`` first; any
    other value (normally a ``HeaderUnit`` member) is used as the lookup key
    directly. If normalization raises ``ValueError``, or the key has no entry
    in ``AXIS_LABELS``, ``DEFAULT_AXIS_LABEL`` is returned instead of raising.

    Args:
        unit: Header unit, either as a string such as ``"cm-1"`` or ``"nm"``
            or as a ``HeaderUnit`` enum member.

    Returns:
        The axis label string.

    Examples:
        >>> get_axis_label("cm-1")
        'Wavenumber (cm⁻¹)'
        >>> get_axis_label(HeaderUnit.WAVELENGTH)
        'Wavelength (nm)'
        >>> get_axis_label("unknown")  # Falls back gracefully
        'Features'
    """
    try:
        if isinstance(unit, str):
            key = normalize_header_unit(unit)
        else:
            key = unit
        label = AXIS_LABELS.get(key, DEFAULT_AXIS_LABEL)
    except ValueError:
        # Unrecognized unit: degrade to the generic label rather than failing.
        label = DEFAULT_AXIS_LABEL
    return label
[docs]
def get_x_values_and_label(
    headers: Optional[List[str]],
    header_unit: Union[str, HeaderUnit],
    n_features: int,
) -> Tuple[np.ndarray, str]:
    """Compute x-axis values and the matching axis label for a chart.

    This is the main entry point for chart x-axis setup:

    - Headers missing, or not one per feature: positional indices
      ``0..n_features-1`` with ``DEFAULT_AXIS_LABEL``.
    - Numeric units (wavenumber, wavelength, none, index): headers parsed
      as floats with the unit's label. If any header fails to parse, the
      result is positional indices with ``DEFAULT_AXIS_LABEL``.
    - Any other unit (e.g. text): positional indices with that unit's label.

    A unit string that ``normalize_header_unit`` rejects with ``ValueError``
    is treated as ``HeaderUnit.NONE``.

    Args:
        headers: Header strings (wavelengths, feature names, ...), or None.
        header_unit: Header unit type ("cm-1", "nm", "none", "text", "index")
            or a ``HeaderUnit`` member.
        n_features: Number of features, used for the index fallback and to
            validate the header count.

    Returns:
        Tuple of (x_values array, axis_label string).

    Examples:
        >>> x_vals, label = get_x_values_and_label(["4000", "4500", "5000"], "cm-1", 3)
        >>> x_vals
        array([4000., 4500., 5000.])
        >>> label
        'Wavenumber (cm⁻¹)'
        >>> x_vals, label = get_x_values_and_label(None, "cm-1", 5)
        >>> x_vals
        array([0, 1, 2, 3, 4])
        >>> label
        'Features'
    """
    # Without exactly one header per feature, positions are the only axis.
    if headers is None or len(headers) != n_features:
        return np.arange(n_features), DEFAULT_AXIS_LABEL

    # Resolve the unit; unrecognized strings behave like "no unit".
    if isinstance(header_unit, str):
        try:
            unit = normalize_header_unit(header_unit)
        except ValueError:
            unit = HeaderUnit.NONE
    else:
        unit = header_unit

    numeric_units = (
        HeaderUnit.WAVENUMBER,
        HeaderUnit.WAVELENGTH,
        HeaderUnit.NONE,
        HeaderUnit.INDEX,
    )
    if unit not in numeric_units:
        # Categorical headers (e.g. TEXT): plot by position, keep the unit label.
        return np.arange(n_features), AXIS_LABELS.get(unit, DEFAULT_AXIS_LABEL)

    try:
        positions = np.array([float(h) for h in headers])
        axis_label = AXIS_LABELS.get(unit, DEFAULT_AXIS_LABEL)
    except (ValueError, TypeError):
        # At least one header is not numeric, so fall back to indices.
        return np.arange(n_features), DEFAULT_AXIS_LABEL
    return positions, axis_label
[docs]
def should_invert_x_axis(x_values: np.ndarray) -> bool:
    """Return True when the x-axis should be drawn in descending order.

    In spectroscopy, wavenumber (cm⁻¹) axes are often displayed from high
    to low when the data is stored that way. Only the first and last values
    are compared, so non-monotonic data is classified by its endpoints.

    Args:
        x_values: Sequence of x-axis values (NumPy array or list).

    Returns:
        ``True`` if the first value is strictly greater than the last one.
        Otherwise ``False``, which also covers fewer than two values and a
        NaN at either end (NaN comparisons are false). The result is always
        a builtin ``bool``, never ``numpy.bool_``.
    """
    if len(x_values) < 2:
        return False
    # Comparing NumPy scalars yields numpy.bool_; coerce so the declared
    # ``-> bool`` holds and identity checks like ``is True`` work.
    return bool(x_values[0] > x_values[-1])
[docs]
def apply_x_axis_limits(ax, x_values: np.ndarray) -> None:
    """Pin the x-axis limits so descending data is drawn high-to-low.

    Matplotlib autoscales the x-axis to ascending limits regardless of the
    order in which points are supplied. When ``x_values`` is descending,
    this sets the limits to ``(x_values[0], x_values[-1])``. Because the left
    limit is then larger than the right one, matplotlib draws the axis
    inverted, preserving the data's ordering. Ascending or too-short data is
    left to matplotlib's default autoscaling; no limits are set.

    The descending test is delegated to ``should_invert_x_axis`` so the
    orientation rule is defined in one place.

    Args:
        ax: Matplotlib Axes object (anything exposing ``set_xlim``).
        x_values: Array of x-axis values, in plotting order.
    """
    if should_invert_x_axis(x_values):
        ax.set_xlim(x_values[0], x_values[-1])