Source code for nirs4all.data.metadata

"""
Metadata management for SpectroDataset.

This module contains the Metadata class for managing sample-level auxiliary data.
Metadata has one row per sample and aligns with the indexer's row indices.
"""

from typing import Dict, List, Optional, Union, Literal, Any
import numpy as np
import polars as pl
import pandas as pd
from sklearn.preprocessing import LabelEncoder


class Metadata:
    """Lightweight metadata manager for sample-level auxiliary data.

    All rows live in one Polars DataFrame, ``self.df``, whose first column
    ``row_id`` holds an integer identifier handed out sequentially as rows are
    added. ``row_id`` is an internal key: it is used to select rows but is
    never included in the frames or column lists returned to callers.
    """

    def __init__(self):
        """Initialize empty metadata block."""
        self.df: Optional[pl.DataFrame] = None
        # Maps "<column>_<method>" to the (full encoded array, encoding info)
        # pair produced by to_numeric(); emptied whenever the data changes.
        self._numeric_cache: Dict[str, tuple[np.ndarray, Dict]] = {}
        self._row_counter: int = 0  # Next row_id to hand out

    def add_metadata(self,
                     data: Union[np.ndarray, pl.DataFrame, pd.DataFrame],
                     headers: Optional[List[str]] = None) -> None:
        """
        Add metadata rows.

        ``None`` and empty inputs are ignored. New rows receive consecutive
        ``row_id`` values continuing from the rows added previously, and the
        numeric-encoding cache is cleared.

        Args:
            data: 2D array (n_samples, n_cols) or DataFrame. A 1D array is
                treated as a single column.
            headers: Column names for ndarray input. When omitted, columns
                are named ``meta_0``, ``meta_1``, ... Ignored for DataFrames.

        Raises:
            ValueError: If ``data`` has an unsupported type, or if ``headers``
                does not provide exactly one name per array column.
        """
        if data is None:
            return

        # Convert input to a Polars DataFrame
        if isinstance(data, np.ndarray):
            if data.size == 0:
                return
            if data.ndim == 1:
                data = data.reshape(-1, 1)
            if headers is None:
                headers = [f"meta_{i}" for i in range(data.shape[1])]
            elif len(headers) != data.shape[1]:
                # Too few names would silently drop the trailing columns and
                # too many would fail with a bare IndexError; reject both.
                raise ValueError(
                    f"Got {len(headers)} headers for {data.shape[1]} metadata columns"
                )
            new_df = pl.DataFrame({col: data[:, i] for i, col in enumerate(headers)})
        elif isinstance(data, pl.DataFrame):
            # Clone so that inserting row_id below never touches the caller's frame
            new_df = data.clone()
        elif isinstance(data, pd.DataFrame):
            new_df = pl.from_pandas(data)
        else:
            raise ValueError(f"Unsupported data type: {type(data)}")

        n_rows = len(new_df)
        if n_rows == 0:
            return

        # Add row_id column
        row_ids = list(range(self._row_counter, self._row_counter + n_rows))
        new_df = new_df.insert_column(0, pl.Series("row_id", row_ids, dtype=pl.Int32))

        # Append or initialize
        if self.df is None:
            self.df = new_df
        else:
            # Diagonal strategy tolerates batches whose column sets differ
            self.df = pl.concat([self.df, new_df], how="diagonal_relaxed")

        self._row_counter += n_rows

        # Clear numeric cache when data changes
        self._numeric_cache.clear()

    def _select_rows(self,
                     indices: Optional[Union[List[int], np.ndarray]]) -> pl.DataFrame:
        """Return ``self.df`` restricted to the rows whose row_id is in ``indices``.

        ``None`` and an empty selection both return the whole frame. Rows come
        back in stored order, not in the order of ``indices``.
        """
        if indices is None or len(indices) == 0:
            return self.df
        return self.df.filter(pl.col("row_id").is_in(indices))

    def _positions_of(self, indices: Union[List[int], np.ndarray]) -> np.ndarray:
        """Translate row ids into positional offsets within ``self.df``.

        The result follows the order of ``indices``. Should a row id occur more
        than once in the frame, its first occurrence is used.

        Raises:
            IndexError: If a requested row id is not present.
        """
        # One pass to build the lookup instead of one array scan per index
        first_position: Dict[int, int] = {}
        for position, row_id in enumerate(self.df["row_id"].to_list()):
            first_position.setdefault(row_id, position)

        positions = np.empty(len(indices), dtype=np.intp)
        for slot, idx in enumerate(indices):
            try:
                positions[slot] = first_position[int(idx)]
            except KeyError:
                raise IndexError(f"Row index {idx} not found in metadata") from None
        return positions

    def get(self,
            indices: Optional[Union[List[int], np.ndarray]] = None,
            columns: Optional[List[str]] = None) -> pl.DataFrame:
        """
        Get metadata as DataFrame.

        Args:
            indices: Row indices to select (None or empty = all rows)
            columns: Columns to return (None = all except row_id)

        Returns:
            Polars DataFrame (without row_id column); an empty DataFrame when
            no metadata has been added. Rows are in stored order.

        Raises:
            ValueError: If any requested column does not exist.
        """
        if self.df is None:
            return pl.DataFrame()

        result = self._select_rows(indices)

        if columns is None:
            wanted = result.columns
        else:
            missing = [c for c in columns if c not in result.columns]
            if missing:
                raise ValueError(f"Columns not found: {missing}")
            wanted = columns

        # row_id is an internal key and is never part of the output
        return result.select([c for c in wanted if c != "row_id"])

    def get_column(self,
                   column: str,
                   indices: Optional[Union[List[int], np.ndarray]] = None) -> np.ndarray:
        """
        Get single column as numpy array.

        Args:
            column: Column name
            indices: Row indices to select (None or empty = all rows)

        Returns:
            Numpy array of column values, in stored row order

        Raises:
            ValueError: If no metadata has been added or the column is unknown.
        """
        if self.df is None:
            raise ValueError("No metadata available")

        if column not in self.df.columns:
            raise ValueError(f"Column '{column}' not found. Available: {self.columns}")

        return self._select_rows(indices)[column].to_numpy()

    def to_numeric(self,
                   column: str,
                   indices: Optional[Union[List[int], np.ndarray]] = None,
                   method: Literal["label", "onehot"] = "label") -> tuple[np.ndarray, Dict]:
        """
        Convert categorical column to numeric encoding.

        The encoding is always computed over the full column, so the codes
        returned for a subset of rows agree with those returned for any other
        subset. The full result is cached per (column, method) until the
        metadata changes.

        Args:
            column: Column name
            indices: Row indices (None or empty = all rows). When given, the
                result rows follow the order of ``indices``.
            method: "label" for label encoding, "onehot" for one-hot

        Returns:
            (numeric_array, encoding_info) tuple. ``numeric_array`` is float32:
            1D for "label", (n_rows, n_classes) for "onehot".
            ``encoding_info`` contains the method and class mappings; for an
            already-numeric column encoded with "label" it reports method
            "numeric" and the source dtype instead.

        Raises:
            ValueError: If no metadata has been added, the column is unknown,
                or ``method`` is not recognised.
            IndexError: If a requested row index does not exist.
        """
        if self.df is None:
            raise ValueError("No metadata available")

        if column not in self.df.columns:
            raise ValueError(f"Column '{column}' not found. Available: {self.columns}")

        subset_requested = indices is not None and len(indices) > 0
        cache_key = f"{column}_{method}"

        # Serve from cache before touching the column data at all
        cached = self._numeric_cache.get(cache_key)
        if cached is not None:
            full_numeric, encoding_info = cached
            if subset_requested:
                return full_numeric[self._positions_of(indices)], encoding_info
            # Copy so callers cannot alter the cached array
            return full_numeric.copy(), encoding_info

        # Encode the full column so codes stay consistent across subsets
        col_data = self.df[column].to_numpy()

        if method == "label":
            if np.issubdtype(col_data.dtype, np.number):
                # Already numeric: a cast is all that is needed
                numeric = col_data.astype(np.float32)
                encoding_info = {"method": "numeric", "dtype": str(col_data.dtype)}
            else:
                encoder = LabelEncoder()
                numeric = encoder.fit_transform(col_data).astype(np.float32)
                encoding_info = {
                    "method": "label",
                    "classes": encoder.classes_.tolist()
                }
        elif method == "onehot":
            # `inverse` holds, for every row, the index of its value in unique_vals
            unique_vals, inverse = np.unique(col_data, return_inverse=True)
            n_samples = len(col_data)
            numeric = np.zeros((n_samples, len(unique_vals)), dtype=np.float32)
            numeric[np.arange(n_samples), inverse.reshape(-1)] = 1.0
            encoding_info = {
                "method": "onehot",
                "classes": unique_vals.tolist()
            }
        else:
            raise ValueError(f"Unknown method: {method}. Use 'label' or 'onehot'")

        # Cache for consistency
        self._numeric_cache[cache_key] = (numeric.copy(), encoding_info)

        if subset_requested:
            return numeric[self._positions_of(indices)], encoding_info
        return numeric, encoding_info

    def update_metadata(self,
                        indices: Union[List[int], np.ndarray],
                        column: str,
                        values: Union[List, np.ndarray]) -> None:
        """
        Update metadata values for specific rows.

        When an index is listed more than once, the last value given for it
        wins. Indices that match no row are ignored.

        Args:
            indices: Row indices to update
            column: Column name
            values: New values (must match length of indices)

        Raises:
            ValueError: If no metadata has been added, the column is unknown,
                or ``indices`` and ``values`` differ in length.
        """
        if self.df is None:
            raise ValueError("No metadata available")

        if column not in self.df.columns:
            raise ValueError(f"Column '{column}' not found")

        if len(indices) != len(values):
            raise ValueError(f"Length mismatch: {len(indices)} indices vs {len(values)} values")

        # Clear cache since data is changing
        self._numeric_cache.clear()

        # Overwrite the target cell of each addressed row, leaving every other
        # row of the column as it was.
        for row_id, new_value in dict(zip(indices, values)).items():
            self.df = self.df.with_columns(
                pl.when(pl.col("row_id") == row_id)
                .then(pl.lit(new_value))
                .otherwise(pl.col(column))
                .alias(column)
            )

    def add_column(self, column: str, values: Union[List, np.ndarray]) -> None:
        """
        Add new metadata column.

        Args:
            column: Column name
            values: Column values (must match number of rows)

        Raises:
            ValueError: If no metadata has been added, ``values`` has the wrong
                length, or the column already exists.
        """
        if self.df is None:
            raise ValueError("No metadata available. Add metadata first.")

        if len(values) != len(self.df):
            raise ValueError(f"Values length {len(values)} != metadata rows {len(self.df)}")

        if column in self.df.columns:
            raise ValueError(f"Column '{column}' already exists")

        # Add column
        self.df = self.df.with_columns(pl.Series(column, values))

        # Clear cache since structure changed
        self._numeric_cache.clear()

    @property
    def num_rows(self) -> int:
        """Number of metadata rows."""
        return 0 if self.df is None else len(self.df)

    @property
    def columns(self) -> List[str]:
        """List of metadata column names (excluding row_id)."""
        if self.df is None:
            return []
        return [c for c in self.df.columns if c != "row_id"]

    def __repr__(self) -> str:
        if self.df is None:
            return "Metadata(empty)"
        return f"Metadata(rows={self.num_rows}, columns={self.columns})"

    def __str__(self) -> str:
        return self.__repr__()