# Source code for nirs4all.data.loaders.loader

# dataset_loader.py

import hashlib
import json
from pathlib import Path
import numpy as np
import pandas as pd
from nirs4all.data.config_parser import parse_config
from nirs4all.data.dataset import SpectroDataset
from nirs4all.data.loaders.csv_loader import load_csv
from nirs4all.data.signal_type import SignalType, normalize_signal_type
from typing import Any, Dict, List, Tuple, Union, Optional

# Import the new loader system
from nirs4all.data.loaders.base import LoaderRegistry, FormatNotSupportedError


def create_synthetic_dataset(config: Dict) -> SpectroDataset:
    """
    Create a synthetic SpectroDataset for testing purposes.

    The samples are split into a "train" partition and a "test" partition.
    Targets are added in the same order: train targets first, then test targets.

    Args:
        config: Dictionary with keys:
            - X: Feature matrix (n_samples, n_features); converted with np.asarray.
            - y: Target values (n_samples,); converted with np.asarray.
            - train: Fraction of samples placed in "train" (default 0.8). The
              count is truncated with int(); the remaining samples go to "test".
            - random_state: Optional seed. If the key is present, samples are
              shuffled with a private np.random.RandomState seeded from it
              before splitting. If absent, the original order is kept.
            - folds, val, test: Accepted for compatibility but not used here.

    Returns:
        SpectroDataset: Synthetic dataset ready for pipeline use.
    """
    # asarray makes row selection positional even for pandas inputs.
    X = np.asarray(config['X'])
    y = np.asarray(config['y'])

    dataset = SpectroDataset(name="synthetic_test_dataset")

    n_samples = X.shape[0]
    train_ratio = config.get('train', 0.8)
    n_train = int(n_samples * train_ratio)

    indices = np.arange(n_samples)
    if 'random_state' in config:
        # A local legacy generator yields the same permutation that seeding the
        # global RNG would, without altering np.random's global state.
        rng = np.random.RandomState(config['random_state'])
        indices = rng.permutation(indices)

    train_indices = indices[:n_train]
    test_indices = indices[n_train:]

    X_train, X_test = X[train_indices], X[test_indices]
    y_train, y_test = y[train_indices], y[test_indices]

    # Add samples with partition information (empty partitions are skipped).
    if len(X_train) > 0:
        dataset.add_samples(X_train, {"partition": "train"})
    if len(X_test) > 0:
        dataset.add_samples(X_test, {"partition": "test"})

    if len(y_train) > 0:
        dataset.add_targets(y_train)
    if len(y_test) > 0:
        dataset.add_targets(y_test)

    return dataset
def _merge_params(local_params, handler_params, global_params): """ Merge parameters from local, handler, and global scopes. Parameters: - local_params (dict): Local parameters specific to the data subset. - handler_params (dict): Parameters specific to the handler. - global_params (dict): Global parameters that apply to all handlers. Returns: - dict: Merged parameters with precedence: local > handler > global. """ merged_params = {} if global_params is None else global_params.copy() if handler_params is not None: merged_params.update(handler_params) if local_params is not None: merged_params.update(local_params) return merged_params # Known loading parameter keys that can appear at root level of config _LOADING_PARAM_KEYS = frozenset({ 'delimiter', 'decimal_separator', 'has_header', 'na_policy', 'header_unit', 'categorical_mode' }) def _get_effective_global_params(config: Dict[str, Any]) -> Optional[Dict[str, Any]]: """ Get effective global params by merging root-level loading params with global_params. Root-level params have lowest precedence (global_params overrides them). This allows users to write simpler configs like: {"x_train": "path.csv", "delimiter": ",", "has_header": True} instead of: {"x_train": "path.csv", "global_params": {"delimiter": ",", "has_header": True}} Parameters: - config (dict): The full configuration dictionary. Returns: - dict or None: Merged params with precedence: global_params > root-level params. 
""" # Extract known loading params from root level root_params = {k: v for k, v in config.items() if k in _LOADING_PARAM_KEYS} global_params = config.get('global_params') if not root_params and not global_params: return None # Merge: root_params (lowest) < global_params (higher) effective = root_params.copy() if root_params else {} if global_params: effective.update(global_params) return effective if effective else None def _load_file_with_registry( file_path: Union[str, Path], header_unit: str = "cm-1", data_type: str = "x", **params: Any, ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any], Optional[pd.Series], List[str], str]: """Load a file using the LoaderRegistry for format detection. This function provides automatic format detection and loading using the registered file loaders. It falls back to CSV loading for unknown formats. Args: file_path: Path to the file to load. header_unit: Unit for headers ('cm-1', 'nm', etc.). data_type: Type of data ('x', 'y', or 'metadata'). **params: Additional loading parameters. Returns: Tuple of (DataFrame, report, na_mask, headers, header_unit). """ path = Path(file_path) if isinstance(file_path, str) else file_path # Try to use the registry for format detection try: registry = LoaderRegistry.get_instance() loader = registry.get_loader(path) result = loader.load( path, header_unit=header_unit, data_type=data_type, **params, ) return ( result.data, result.report, result.na_mask, result.headers, result.header_unit, ) except FormatNotSupportedError: # Fall back to CSV loader for unknown formats return load_csv(file_path, header_unit=header_unit, data_type=data_type, **params) except Exception as e: # On any other error, try CSV as a fallback try: return load_csv(file_path, header_unit=header_unit, data_type=data_type, **params) except Exception: # If CSV also fails, re-raise the original error raise e
def _checked_column_positions(indices, n_columns, type_error_msg, range_error_msg):
    """Validate a positional column selection and return it as non-negative ints.

    Accepts any iterable of Python or NumPy integers. Negative indices count
    from the end, as with ``DataFrame.iloc``.

    Raises:
        ValueError: ``type_error_msg`` if ``indices`` is not iterable, is empty,
            or contains a non-integer; ``range_error_msg`` if an index lies
            outside ``[-n_columns, n_columns)``.
    """
    try:
        positions = list(indices)
    except TypeError:
        raise ValueError(type_error_msg) from None
    if not all(isinstance(i, (int, np.integer)) for i in positions):
        raise ValueError(type_error_msg)
    if not positions:
        raise ValueError(f"{type_error_msg} (empty selection)")
    if any(i >= n_columns or i < -n_columns for i in positions):
        raise ValueError(range_error_msg)
    # Normalise negatives so membership tests against range(n_columns) work.
    return [int(i) % n_columns for i in positions]


def _load_metadata_frame(m_path, m_params):
    """Load a metadata table, keeping every row even where cells are missing.

    The registry loader is called with defaults categorical_mode='preserve',
    data_type='metadata' and na_policy='remove' (caller values win). If the
    report says rows were removed, the file is re-read with ``pandas.read_csv``
    using the delimiter, decimal and header settings from that report, so no
    metadata row is lost.

    NOTE(review): the re-read assumes a delimited-text file and a report that
    contains 'delimiter', 'decimal_separator' and 'has_header'.

    Returns:
        tuple: (DataFrame, list of column names).

    Raises:
        ValueError: If the metadata cannot be loaded.
    """
    load_params = dict(m_params) if m_params else {}
    load_params.setdefault('categorical_mode', 'preserve')  # keep original types for metadata
    load_params.setdefault('data_type', 'metadata')
    load_params.setdefault('na_policy', 'remove')
    try:
        m_df, m_report, _, m_headers, _ = _load_file_with_registry(m_path, **load_params)
        if m_report.get("error") is not None or m_df is None:
            raise ValueError(
                f"Failed to load metadata from {m_path}: {m_report.get('error') or 'Unknown error'}"
            )
        na_handling = m_report.get('na_handling') or {}
        if na_handling.get('nb_removed_rows', 0) > 0:
            # Rows were dropped for NAs; re-read so every metadata row is kept.
            m_df = pd.read_csv(
                m_path,
                sep=m_report['delimiter'],
                decimal=m_report['decimal_separator'],
                header=0 if m_report['has_header'] else None,
                na_filter=True,  # still detect NAs, but don't remove rows
                keep_default_na=True,
                engine='python',
            )
            m_df.columns = m_df.columns.astype(str)
            m_headers = m_df.columns.tolist()
    except Exception as e:
        raise ValueError(f"Error loading metadata from {m_path}: {str(e)}") from e
    return m_df, m_headers


def load_XY(x_path, x_filter, x_params, y_path, y_filter, y_params, m_path=None, m_filter=None, m_params=None):
    """
    Load X, Y, and metadata from single paths.

    For multi-source datasets this is called once per source. None of the
    parameter dictionaries passed in is modified.

    Parameters:
    - x_path (str or Path): Single path to the X data file.
    - x_filter: Filter to apply to X data (not implemented yet; must be None).
    - x_params (dict or None): Parameters for loading X data, including:
        - header_unit: Unit for headers ("cm-1", "nm", "none", "text", "index"), default "cm-1"
        - signal_type: Signal type ("absorbance", "reflectance", "reflectance%", etc.),
          default None (auto-detect)
        - delimiter, decimal_separator, has_header, na_policy, etc.
    - y_path (str or Path or None): Path to the Y data file.
    - y_filter (iterable of int or None): Positional column indices for Y. If
      y_path is None, these columns are moved out of X into Y. Otherwise they
      select columns of the Y file.
    - y_params (dict or None): Parameters for loading Y data.
    - m_path (str or Path or None): Path to a metadata file.
    - m_filter: Filter to apply to metadata (not implemented yet; must be None).
    - m_params (dict or None): Parameters for loading metadata.

    Returns:
    - tuple: (x, y, m, x_headers, m_headers, x_header_unit, x_signal_type) where:
        - x is a float32 numpy array (n_rows, n_x_columns)
        - y is a numpy array; shape (n_rows, 0) when no Y is loaded
        - m is the metadata DataFrame, or None when there is none
        - x_headers, m_headers are lists of column names
        - x_header_unit is the unit string reported by the loader for X headers
        - x_signal_type is a SignalType, or None for auto-detect

    Raises:
    - ValueError: If data is invalid, cannot be loaded, or row counts disagree.
    - NotImplementedError: If x_filter, or m_filter together with m_path, is given.
    """
    if x_path is None:
        raise ValueError("Invalid x definition: x_path is None")
    if x_filter is not None:
        # Fail fast rather than loading a file only to reject the request.
        raise NotImplementedError("Auto-filtering not implemented yet")

    # Work on a copy so the defaults and pops below never leak to the caller.
    x_params = dict(x_params) if x_params else {}
    x_params.setdefault('categorical_mode', 'auto')
    x_params.setdefault('data_type', 'x')
    # header_unit / signal_type are consumed here instead of forwarded as extras.
    x_header_unit = x_params.pop('header_unit', 'cm-1')
    x_signal_type_raw = x_params.pop('signal_type', None)
    x_signal_type: Optional[SignalType] = None
    if x_signal_type_raw is not None:
        x_signal_type = normalize_signal_type(x_signal_type_raw)

    try:
        x_df, x_report, _, _, x_unit = _load_file_with_registry(
            x_path, header_unit=x_header_unit, **x_params
        )
        if x_report.get("error") is not None or x_df is None:
            raise ValueError(
                f"Failed to load X data from {x_path}: {x_report.get('error') or 'Unknown error'}"
            )
    except Exception as e:
        raise ValueError(f"Error loading X data from {x_path}: {str(e)}") from e

    if y_path is None and y_filter is None:
        # No Y requested: an empty frame aligned with X's rows.
        y_df = pd.DataFrame(index=x_df.index)
    elif y_path is None:
        # Y lives in columns of the X file: split them off by position.
        y_cols = _checked_column_positions(
            y_filter,
            x_df.shape[1],
            "Invalid y definition: y_filter is not a list of integers",
            f"Y filter indices {y_filter} exceed X columns ({x_df.shape[1]})",
        )
        y_df = x_df.iloc[:, y_cols]
        taken = set(y_cols)
        # Positional selection (not drop-by-label) so duplicate column names
        # cannot cause extra columns to be removed from X.
        x_df = x_df.iloc[:, [c for c in range(x_df.shape[1]) if c not in taken]]
    else:
        # Y is in a separate file.
        y_load_params = dict(y_params) if y_params else {}
        y_load_params.setdefault('categorical_mode', 'auto')
        y_load_params.setdefault('data_type', 'y')
        try:
            y_df, y_report, _, _, _ = _load_file_with_registry(y_path, **y_load_params)
            if y_report.get("error") is not None or y_df is None:
                raise ValueError(
                    f"Failed to load Y data from {y_path}: {y_report.get('error') or 'Unknown error'}"
                )
        except Exception as e:
            raise ValueError(f"Error loading Y data from {y_path}: {str(e)}") from e
        if y_filter is not None:
            y_cols = _checked_column_positions(
                y_filter,
                y_df.shape[1],
                "Invalid y_filter: must be list of integers",
                f"Y filter indices {y_filter} exceed Y columns ({y_df.shape[1]})",
            )
            y_df = y_df.iloc[:, y_cols]

    # Only a column-less Y (i.e. "no Y") may differ in length from X.
    if y_df.shape[1] > 0 and x_df.shape[0] != y_df.shape[0]:
        raise ValueError(f"Row count mismatch: X({x_df.shape[0]}) Y({y_df.shape[0]})")

    m_df = pd.DataFrame()
    m_headers = []
    if m_path is not None:
        if m_filter is not None:
            raise NotImplementedError("Metadata filtering not implemented yet")
        m_df, m_headers = _load_metadata_frame(m_path, m_params)
        if m_df.shape[1] > 0 and x_df.shape[0] != m_df.shape[0]:
            raise ValueError(f"Row count mismatch: X({x_df.shape[0]}) Metadata({m_df.shape[0]})")

    # Refresh headers after potential column removal (Y extracted from X).
    x_headers = x_df.columns.tolist()

    try:
        if x_df.shape[1] == 0:
            # Keep the row count even when no feature columns remain.
            x = np.empty(x_df.shape, dtype=np.float32)
        else:
            x = x_df.astype(np.float32).to_numpy()
        y = y_df.to_numpy() if y_df.shape[1] > 0 else np.empty((x_df.shape[0], 0))  # match X rows, 0 columns
        # Keep metadata as a DataFrame (no numeric conversion).
        m = m_df if not m_df.empty else None
    except Exception as e:
        raise ValueError(f"Error converting data to numpy arrays: {str(e)}") from e

    return x, y, m, x_headers, m_headers, x_unit, x_signal_type
def _preloaded_arrays_result(config, t_set, x_array, y_source, m_source, effective_global_params):
    """Build handle_data's return tuple when ``{t_set}_x`` is already a NumPy array.

    Y is taken from ``y_source`` when it is a NumPy array or a pandas
    Series/DataFrame (converted with ``to_numpy``). Anything else yields None.
    Metadata is kept only when it is a DataFrame or a NumPy array. X headers are
    generated as ``feature_<i>``, and the header unit is ``HeaderUnit.WAVENUMBER``.
    ``signal_type`` is read from the X params merged with the same
    local > handler > global precedence used for file-based loading.
    """
    if isinstance(y_source, np.ndarray):
        y_array = y_source
    elif isinstance(y_source, (pd.Series, pd.DataFrame)):
        y_array = y_source.to_numpy()
    else:
        y_array = None
    m_data = m_source if isinstance(m_source, (pd.DataFrame, np.ndarray)) else None

    n_features = x_array.shape[1] if x_array.ndim > 1 else 1
    x_headers = [f"feature_{i}" for i in range(n_features)]

    if isinstance(m_data, pd.DataFrame):
        m_headers = list(m_data.columns)
    elif isinstance(m_data, np.ndarray) and m_data.ndim > 1:
        m_headers = [f"meta_{i}" for i in range(m_data.shape[1])]
    else:
        m_headers = []

    from nirs4all.data._features import HeaderUnit
    x_header_unit = HeaderUnit.WAVENUMBER.value

    # A list of per-source params has no meaning for a single array; ignore it.
    local_x_params = config.get(f'{t_set}_x_params')
    if not isinstance(local_x_params, dict):
        local_x_params = None
    x_params = _merge_params(local_x_params, config.get(f'{t_set}_params'), effective_global_params)
    raw_signal_type = x_params.get('signal_type')
    x_signal_type = normalize_signal_type(raw_signal_type) if raw_signal_type is not None else None

    return x_array, y_array, m_data, x_headers, m_headers, x_header_unit, x_signal_type


def _multi_source_x_params(x_params_config, index, handler_params, global_params):
    """Build the merged loading parameters for X source number ``index``.

    ``x_params_config`` may be:
    - a list: element ``index`` holds that source's params. Sources beyond the
      list's length get no local params.
    - a dict: shared by all sources. List-valued ``header_unit`` / ``signal_type``
      entries are indexed per source, and missing entries fall back to "cm-1" /
      None respectively.
    - anything else (e.g. None): no local params.
    """
    if isinstance(x_params_config, list):
        local = x_params_config[index] if index < len(x_params_config) else None
    elif isinstance(x_params_config, dict):
        local = x_params_config.copy()
        for key, fallback in (('header_unit', "cm-1"), ('signal_type', None)):
            per_source = x_params_config.get(key)
            if isinstance(per_source, list):
                local[key] = per_source[index] if index < len(per_source) else fallback
    else:
        local = None
    return _merge_params(local, handler_params, global_params)


def handle_data(config, t_set):
    """
    Handle data loading for a given dataset type (train, test).

    Supports single-source datasets, multi-source datasets (``{t_set}_x`` is a
    list of paths), and pre-loaded data (``{t_set}_x`` is a NumPy array).

    Parameters:
    - config (dict): Data configuration dictionary. It reads ``{t_set}_x``,
      ``{t_set}_y`` and ``{t_set}_group`` (metadata), plus their ``_filter`` and
      ``_params`` variants, ``{t_set}_params`` and the global loading params.
    - t_set (str): The dataset type ('train', 'test').

    Returns:
    - tuple: (x, y, m, x_headers, m_headers, x_header_unit, x_signal_type) where:
        - x is a numpy array, or a list of arrays for multi-source
        - y is a numpy array (or None when absent for pre-loaded/multi-source data)
        - m is a DataFrame or None (metadata)
        - x_headers is a list of column names, or a list of lists for multi-source
        - m_headers is a list of metadata column names
        - x_header_unit is a string, or a list of strings for multi-source
        - x_signal_type is a SignalType or None, or a list of them for multi-source

    Raises:
    - ValueError: If the config is missing or malformed, or any source fails to load.
    """
    if config is None:
        raise ValueError(f"Configuration for {t_set} dataset is None")
    if not isinstance(config, dict):
        raise ValueError(f"Invalid config type for {t_set}: {type(config)}")

    # Global params include root-level loading params (lowest precedence).
    effective_global_params = _get_effective_global_params(config)
    handler_params = config.get(f'{t_set}_params')

    x_path = config.get(f'{t_set}_x')
    y_path = config.get(f'{t_set}_y')
    m_path = config.get(f'{t_set}_group')  # metadata uses the 'group' key

    if isinstance(x_path, np.ndarray):
        return _preloaded_arrays_result(config, t_set, x_path, y_path, m_path, effective_global_params)

    x_filter = config.get(f'{t_set}_x_filter')
    y_filter = config.get(f'{t_set}_y_filter')
    m_filter = config.get(f'{t_set}_group_filter')

    # Y / metadata params do not depend on the source index; compute them once.
    y_params = _merge_params(config.get(f'{t_set}_y_params'), handler_params, effective_global_params)
    m_params = _merge_params(config.get(f'{t_set}_group_params'), handler_params, effective_global_params)

    if not isinstance(x_path, list):
        x_params = _merge_params(config.get(f'{t_set}_x_params'), handler_params, effective_global_params)
        return load_XY(x_path, x_filter, x_params, y_path, y_filter, y_params, m_path, m_filter, m_params)

    x_params_config = config.get(f'{t_set}_x_params')
    x_arrays, headers_arrays, header_units, signal_types = [], [], [], []
    y_array = None
    m_data = None
    m_headers = []

    for i, single_x_path in enumerate(x_path):
        source_x_params = _multi_source_x_params(x_params_config, i, handler_params, effective_global_params)
        try:
            if i == 0:
                # Only the first source extracts Y and metadata.
                x_single, y_array, m_data, x_headers, m_headers, x_unit, x_sig_type = load_XY(
                    single_x_path, x_filter, source_x_params,
                    y_path, y_filter, y_params,
                    m_path, m_filter, m_params,
                )
            else:
                x_single, _, _, x_headers, _, x_unit, x_sig_type = load_XY(
                    single_x_path, x_filter, source_x_params,
                    None, None, y_params,
                    None, None, None,
                )
        except Exception as e:
            raise ValueError(f"Error loading X source {i} from {single_x_path}: {str(e)}") from e
        x_arrays.append(x_single)
        headers_arrays.append(x_headers)
        header_units.append(x_unit)
        signal_types.append(x_sig_type)

    return x_arrays, y_array, m_data, headers_arrays, m_headers, header_units, signal_types