# Source code for nirs4all.data.loaders.parquet_loader

"""
Parquet file loader implementation.

This module provides the ParquetLoader class for loading Apache Parquet files.
Requires pyarrow or fastparquet as a dependency.
"""

from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union

import pandas as pd

from .base import (
    FileLoadError,
    FileLoader,
    LoaderResult,
    register_loader,
)


def _check_parquet_available() -> str:
    """Check if a Parquet engine is available.

    Returns:
        Name of the available engine ('pyarrow' or 'fastparquet').

    Raises:
        ImportError: If no Parquet engine is available.
    """
    try:
        import pyarrow
        return "pyarrow"
    except ImportError:
        pass

    try:
        import fastparquet
        return "fastparquet"
    except ImportError:
        pass

    raise ImportError(
        "No Parquet engine available. Install pyarrow or fastparquet: "
        "pip install pyarrow  # or pip install fastparquet"
    )


@register_loader
class ParquetLoader(FileLoader):
    """Loader for Apache Parquet files.

    Requires pyarrow or fastparquet to be installed.

    Supports:
        - Single Parquet files (.parquet, .pq)
        - Partitioned datasets (directory containing .parquet/.pq files,
          directly or in nested sub-directories)
        - Column selection for efficient loading

    Parameters:
        columns: List of column names to load (default: all columns).
        engine: Parquet engine to use ('auto', 'pyarrow', or 'fastparquet').
        filters: Row group filters for predicate pushdown (pyarrow only).
        header_unit: Unit for headers ('cm-1', 'nm', 'text', etc.)

    Example:
        >>> loader = ParquetLoader()
        >>> result = loader.load(
        ...     Path("data.parquet"),
        ...     columns=["feature_1", "feature_2"],
        ... )
    """

    supported_extensions: ClassVar[Tuple[str, ...]] = (".parquet", ".pq")
    name: ClassVar[str] = "Parquet Loader"
    priority: ClassVar[int] = 35  # Higher priority for Parquet

    @classmethod
    def supports(cls, path: Path) -> bool:
        """Check if this loader supports the given file or directory.

        Returns True when ``path`` has a supported extension, or when it is a
        directory containing at least one file with a supported extension at
        any depth (partitioned dataset).
        """
        if path.suffix.lower() in cls.supported_extensions:
            return True
        if not path.is_dir():
            return False
        # "**/" also matches zero directories, so a single recursive glob
        # covers both top-level and nested partitions. Stop at the first hit
        # instead of listing a potentially very large partition tree.
        return any(
            next(path.glob(f"**/*{ext}"), None) is not None
            for ext in cls.supported_extensions
        )

    @staticmethod
    def _coerce_to_numeric(data: pd.DataFrame) -> List[str]:
        """Convert every non-numeric column of ``data`` to numeric, in place.

        Unparseable values become NaN (``pd.to_numeric(errors="coerce")``).

        Returns:
            One warning message per column in which at least one previously
            non-missing value was turned into NaN by the conversion.
        """
        messages: List[str] = []
        for col in data.columns:
            original = data[col]
            if pd.api.types.is_numeric_dtype(original):
                continue
            converted = pd.to_numeric(original, errors="coerce")
            # Count only values lost by coercion, not pre-existing NAs.
            n_lost = int((converted.isna() & original.notna()).sum())
            if n_lost:
                messages.append(
                    f"Column '{col}': {n_lost} non-numeric value(s) coerced to NaN."
                )
            data[col] = converted
        return messages

    def load(
        self,
        path: Path,
        columns: Optional[List[str]] = None,
        engine: str = "auto",
        filters: Optional[List] = None,
        header_unit: str = "text",
        data_type: str = "x",
        **params: Any,
    ) -> LoaderResult:
        """Load data from a Parquet file or partitioned dataset directory.

        Rows containing any NA value (after numeric coercion for X data) are
        removed. Failures never raise: they are reported through
        ``report["error"]`` on a result that carries only ``report`` and
        ``header_unit``.

        Args:
            path: Path to the Parquet file or directory.
            columns: List of column names to load. If None, loads all columns.
            engine: Parquet engine ('auto', 'pyarrow', or 'fastparquet').
                'auto' picks pyarrow if importable, else fastparquet.
            filters: Row group filters for predicate pushdown (pyarrow only;
                ignored with a warning for other engines).
            header_unit: Unit type for headers.
            data_type: Type of data ('x', 'y', or 'metadata'). For 'x', all
                non-numeric columns are coerced to numeric.
            **params: Additional parameters passed to ``pd.read_parquet``.

        Returns:
            LoaderResult with the cleaned data, the load report, the NA mask
            (indexed like the rows as read, before removal), string headers
            and ``header_unit``.
        """
        report: Dict[str, Any] = {
            "file_path": str(path),
            "format": "parquet",
            "engine": None,
            "columns_requested": columns,
            "columns_loaded": None,
            "initial_shape": None,
            "final_shape": None,
            "na_handling": {
                "strategy": "remove",
                "na_detected": False,
                "nb_removed_rows": 0,
                "removed_rows_indices": [],
            },
            "warnings": [],
            "error": None,
        }

        try:
            if not path.exists():
                raise FileNotFoundError(f"File or directory not found: {path}")

            # Resolve the engine up front so the report records the real one.
            if engine == "auto":
                try:
                    engine = _check_parquet_available()
                except ImportError as e:
                    report["error"] = str(e)
                    return LoaderResult(report=report, header_unit=header_unit)
            report["engine"] = engine

            read_kwargs: Dict[str, Any] = {"engine": engine}
            if columns is not None:
                read_kwargs["columns"] = columns
            if filters is not None:
                if engine == "pyarrow":
                    read_kwargs["filters"] = filters
                else:
                    report["warnings"].append(
                        "Filters are only supported with pyarrow engine. Ignoring."
                    )
            read_kwargs.update(params)

            try:
                data = pd.read_parquet(path, **read_kwargs)
            except ImportError as e:
                report["error"] = f"Parquet engine not available: {e}"
                return LoaderResult(report=report, header_unit=header_unit)
            except Exception as e:
                report["error"] = f"Failed to read Parquet file: {e}"
                return LoaderResult(report=report, header_unit=header_unit)

            report["initial_shape"] = data.shape
            report["columns_loaded"] = data.columns.tolist()

            # Ensure column names are strings
            data.columns = data.columns.astype(str)

            if data.empty:
                report["warnings"].append("Loaded DataFrame is empty.")
                report["final_shape"] = (0, 0)  # shape of the frame returned below
                return LoaderResult(
                    data=pd.DataFrame(),
                    report=report,
                    na_mask=pd.Series(dtype=bool),
                    headers=[],
                    header_unit=header_unit,
                )

            if data_type == "x":
                report["warnings"].extend(self._coerce_to_numeric(data))

            # Drop every row that has at least one NA value.
            na_mask = data.isna().any(axis=1)
            if na_mask.any():
                report["na_handling"].update(
                    na_detected=True,
                    nb_removed_rows=int(na_mask.sum()),
                    removed_rows_indices=data.index[na_mask].tolist(),
                )
                data = data[~na_mask].copy()
                if data.empty:
                    report["warnings"].append(
                        "All rows were removed because they contained NA values."
                    )

            report["final_shape"] = data.shape
            return LoaderResult(
                data=data,
                report=report,
                na_mask=na_mask,
                headers=data.columns.tolist(),
                header_unit=header_unit,
            )

        except FileNotFoundError as e:
            report["error"] = str(e)
            return LoaderResult(report=report, header_unit=header_unit)
        except Exception as e:
            import traceback

            report["error"] = f"Error loading Parquet file: {e}\n{traceback.format_exc()}"
            return LoaderResult(report=report, header_unit=header_unit)
def load_parquet(
    path,
    columns: Optional[List[str]] = None,
    engine: str = "auto",
    header_unit: str = "text",
    **params,
):
    """Load a Parquet file and return the legacy tuple interface.

    Convenience wrapper around a fresh :class:`ParquetLoader` for callers
    that prefer a plain tuple over a ``LoaderResult``.

    Args:
        path: Path to the Parquet file or directory; converted with ``Path()``.
        columns: Column names to load; ``None`` loads every column.
        engine: Parquet engine to use ('auto', 'pyarrow', or 'fastparquet').
        header_unit: Unit type for headers.
        **params: Extra keyword arguments forwarded to ``ParquetLoader.load``.

    Returns:
        Tuple of (DataFrame, report, na_mask, headers, header_unit).
    """
    outcome = ParquetLoader().load(
        Path(path),
        columns=columns,
        engine=engine,
        header_unit=header_unit,
        **params,
    )
    ordered_fields = ("data", "report", "na_mask", "headers", "header_unit")
    return tuple(getattr(outcome, field) for field in ordered_fields)