"""
Aggregator for sample data aggregation.
This module provides the Aggregator class for aggregating sample data
during loading, with support for various aggregation methods and
custom aggregation functions.
Phase 8 Implementation - Dataset Configuration Roadmap
Section 8.1: Sample Aggregation Enhancements
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import pandas as pd
from nirs4all.core.logging import get_logger
logger = get_logger(__name__)
[docs]
class AggregationMethod(str, Enum):
"""Aggregation method for combining samples."""
MEAN = "mean"
MEDIAN = "median"
VOTE = "vote"
MIN = "min"
MAX = "max"
SUM = "sum"
STD = "std"
FIRST = "first"
LAST = "last"
COUNT = "count"
[docs]
@dataclass
class AggregationConfig:
"""Configuration for sample aggregation.
Attributes:
column: Column name to group by for aggregation.
If True, aggregate by y values.
If None, no aggregation.
method: Aggregation method or custom function.
exclude_outliers: Whether to exclude outliers before aggregation.
outlier_threshold: Z-score threshold for outlier detection.
min_samples: Minimum number of samples per group (groups with fewer are dropped).
custom_function: Optional custom aggregation function.
feature_method: Aggregation method for features (X), if different from targets.
target_method: Aggregation method for targets (Y), if different from features.
"""
column: Optional[Union[str, bool]] = None
method: Union[AggregationMethod, str] = AggregationMethod.MEAN
exclude_outliers: bool = False
outlier_threshold: float = 3.0
min_samples: int = 1
custom_function: Optional[Callable] = None
feature_method: Optional[Union[AggregationMethod, str]] = None
target_method: Optional[Union[AggregationMethod, str]] = None
[docs]
def __post_init__(self):
"""Normalize method values to enum."""
if isinstance(self.method, str):
try:
self.method = AggregationMethod(self.method.lower())
except ValueError:
pass # Keep as string for custom methods
if isinstance(self.feature_method, str):
try:
self.feature_method = AggregationMethod(self.feature_method.lower())
except ValueError:
pass
if isinstance(self.target_method, str):
try:
self.target_method = AggregationMethod(self.target_method.lower())
except ValueError:
pass
[docs]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AggregationConfig":
"""Create from configuration dictionary.
Args:
config: Configuration dictionary with aggregation settings.
Returns:
AggregationConfig instance.
"""
aggregate = config.get("aggregate")
if aggregate is None:
return cls(column=None)
return cls(
column=aggregate,
method=config.get("aggregate_method", AggregationMethod.MEAN),
exclude_outliers=config.get("aggregate_exclude_outliers", False),
outlier_threshold=config.get("aggregate_outlier_threshold", 3.0),
min_samples=config.get("aggregate_min_samples", 1),
feature_method=config.get("aggregate_feature_method"),
target_method=config.get("aggregate_target_method"),
)
[docs]
def is_enabled(self) -> bool:
"""Check if aggregation is enabled."""
return self.column is not None
[docs]
class AggregationError(Exception):
"""Exception raised when aggregation fails."""
pass
[docs]
class Aggregator:
"""Aggregates sample data during loading.
Supports grouping by metadata columns, target values, or sample IDs,
with configurable aggregation methods for features and targets.
Example:
```python
# Aggregate by sample_id column using mean
config = AggregationConfig(column="sample_id", method="mean")
aggregator = Aggregator(config)
X_agg, y_agg, meta_agg = aggregator.aggregate(X, y, metadata)
# Aggregate with outlier exclusion
config = AggregationConfig(
column="sample_id",
method="mean",
exclude_outliers=True,
outlier_threshold=2.5
)
aggregator = Aggregator(config)
result = aggregator.aggregate(X, y, metadata)
# Custom aggregation function
config = AggregationConfig(
column="sample_id",
custom_function=lambda x: np.percentile(x, 75, axis=0)
)
aggregator = Aggregator(config)
result = aggregator.aggregate(X, y, metadata)
```
"""
def __init__(self, config: AggregationConfig):
"""Initialize aggregator.
Args:
config: Aggregation configuration.
"""
self.config = config
self._custom_functions: Dict[str, Callable] = {}
[docs]
def register_function(self, name: str, func: Callable) -> None:
"""Register a custom aggregation function.
Args:
name: Name to reference the function.
func: Aggregation function that takes array and returns aggregated value.
"""
self._custom_functions[name] = func
[docs]
def aggregate(
self,
X: np.ndarray,
y: Optional[np.ndarray] = None,
metadata: Optional[pd.DataFrame] = None,
group_column: Optional[str] = None,
) -> tuple:
"""Aggregate data by groups.
Args:
X: Feature array of shape (n_samples, n_features).
y: Optional target array of shape (n_samples,) or (n_samples, n_targets).
metadata: Optional metadata DataFrame.
group_column: Override column to group by.
Returns:
Tuple of (X_aggregated, y_aggregated, metadata_aggregated).
Elements are None if not provided in input.
Raises:
AggregationError: If aggregation fails.
"""
if not self.config.is_enabled() and group_column is None:
return X, y, metadata
# Determine grouping column
column = group_column or self.config.column
# Get group labels
group_labels = self._get_group_labels(column, X, y, metadata)
if group_labels is None:
logger.warning("Could not determine group labels for aggregation")
return X, y, metadata
# Validate group labels
if len(group_labels) != len(X):
raise AggregationError(
f"Group labels length ({len(group_labels)}) does not match "
f"data length ({len(X)})"
)
# Get unique groups
unique_groups = np.unique(group_labels)
n_groups = len(unique_groups)
logger.debug(f"Aggregating {len(X)} samples into {n_groups} groups")
# Aggregate features
X_agg = self._aggregate_array(
X,
group_labels,
unique_groups,
self.config.feature_method or self.config.method
)
# Aggregate targets
y_agg = None
if y is not None:
target_method = self.config.target_method or self.config.method
# For classification (vote), use mode; for regression, use specified method
if target_method == AggregationMethod.VOTE:
y_agg = self._aggregate_vote(y, group_labels, unique_groups)
else:
y_agg = self._aggregate_array(
y.reshape(-1, 1) if y.ndim == 1 else y,
group_labels,
unique_groups,
target_method
)
if y.ndim == 1:
y_agg = y_agg.ravel()
# Aggregate metadata
meta_agg = None
if metadata is not None:
meta_agg = self._aggregate_metadata(metadata, group_labels, unique_groups)
return X_agg, y_agg, meta_agg
def _get_group_labels(
self,
column: Union[str, bool],
X: np.ndarray,
y: Optional[np.ndarray],
metadata: Optional[pd.DataFrame],
) -> Optional[np.ndarray]:
"""Get group labels for aggregation.
Args:
column: Column name, True for y-based grouping, or False/None for no grouping.
X: Feature array.
y: Target array.
metadata: Metadata DataFrame.
Returns:
Array of group labels, or None if cannot determine.
"""
if column is True:
# Group by target values
if y is None:
logger.warning("Cannot aggregate by y values: y is None")
return None
return y.astype(str) if y.dtype == object else y
if isinstance(column, str):
# Group by metadata column
if metadata is None:
logger.warning(f"Cannot aggregate by '{column}': metadata is None")
return None
if column not in metadata.columns:
# Try common column name variations
possible_names = [
column,
column.lower(),
column.upper(),
column.replace("_", ""),
column.replace("-", "_"),
]
found_column = None
for name in possible_names:
if name in metadata.columns:
found_column = name
break
# Case-insensitive search
for col in metadata.columns:
if col.lower() == name.lower():
found_column = col
break
if found_column:
break
if found_column is None:
logger.warning(
f"Aggregation column '{column}' not found in metadata. "
f"Available columns: {list(metadata.columns)}"
)
return None
column = found_column
return metadata[column].values
return None
def _aggregate_array(
self,
data: np.ndarray,
group_labels: np.ndarray,
unique_groups: np.ndarray,
method: Union[AggregationMethod, str, Callable],
) -> np.ndarray:
"""Aggregate an array by groups.
Args:
data: Array to aggregate, shape (n_samples, n_features).
group_labels: Array of group labels.
unique_groups: Unique group values.
method: Aggregation method.
Returns:
Aggregated array of shape (n_groups, n_features).
"""
n_groups = len(unique_groups)
n_features = data.shape[1] if data.ndim > 1 else 1
if data.ndim == 1:
data = data.reshape(-1, 1)
result = np.zeros((n_groups, n_features), dtype=data.dtype)
# Get the aggregation function
agg_func = self._get_aggregation_function(method)
for i, group in enumerate(unique_groups):
mask = group_labels == group
group_data = data[mask]
# Exclude outliers if configured
if self.config.exclude_outliers and len(group_data) > 2:
group_data = self._remove_outliers(group_data)
# Check minimum samples
if len(group_data) < self.config.min_samples:
logger.warning(
f"Group '{group}' has {len(group_data)} samples, "
f"less than minimum {self.config.min_samples}. Using available data."
)
if len(group_data) == 0:
result[i] = np.nan
else:
result[i] = agg_func(group_data, axis=0)
return result
def _aggregate_vote(
self,
y: np.ndarray,
group_labels: np.ndarray,
unique_groups: np.ndarray,
) -> np.ndarray:
"""Aggregate targets using majority voting.
Args:
y: Target array.
group_labels: Array of group labels.
unique_groups: Unique group values.
Returns:
Aggregated target array.
"""
result = np.zeros(len(unique_groups), dtype=y.dtype)
for i, group in enumerate(unique_groups):
mask = group_labels == group
group_y = y[mask]
if len(group_y) == 0:
result[i] = np.nan if np.issubdtype(y.dtype, np.floating) else 0
else:
# Get mode (most common value)
values, counts = np.unique(group_y, return_counts=True)
result[i] = values[np.argmax(counts)]
return result
def _aggregate_metadata(
self,
metadata: pd.DataFrame,
group_labels: np.ndarray,
unique_groups: np.ndarray,
) -> pd.DataFrame:
"""Aggregate metadata by groups.
For each group, takes the first row's metadata values.
Args:
metadata: Metadata DataFrame.
group_labels: Array of group labels.
unique_groups: Unique group values.
Returns:
Aggregated metadata DataFrame.
"""
rows = []
for group in unique_groups:
mask = group_labels == group
indices = np.where(mask)[0]
if len(indices) > 0:
# Take first row of each group
rows.append(metadata.iloc[indices[0]])
return pd.DataFrame(rows).reset_index(drop=True)
def _get_aggregation_function(
self,
method: Union[AggregationMethod, str, Callable]
) -> Callable:
"""Get the aggregation function for a method.
Args:
method: Aggregation method name or callable.
Returns:
Callable that performs aggregation.
"""
if callable(method):
return method
if self.config.custom_function is not None:
return self.config.custom_function
# Check custom registered functions
method_str = method.value if isinstance(method, AggregationMethod) else str(method)
if method_str in self._custom_functions:
return self._custom_functions[method_str]
# Built-in methods
method_map = {
AggregationMethod.MEAN: lambda x, axis=0: np.nanmean(x, axis=axis),
AggregationMethod.MEDIAN: lambda x, axis=0: np.nanmedian(x, axis=axis),
AggregationMethod.MIN: lambda x, axis=0: np.nanmin(x, axis=axis),
AggregationMethod.MAX: lambda x, axis=0: np.nanmax(x, axis=axis),
AggregationMethod.SUM: lambda x, axis=0: np.nansum(x, axis=axis),
AggregationMethod.STD: lambda x, axis=0: np.nanstd(x, axis=axis),
AggregationMethod.FIRST: lambda x, axis=0: x[0] if len(x) > 0 else np.nan,
AggregationMethod.LAST: lambda x, axis=0: x[-1] if len(x) > 0 else np.nan,
AggregationMethod.COUNT: lambda x, axis=0: np.sum(~np.isnan(x), axis=axis),
}
if isinstance(method, AggregationMethod):
return method_map.get(method, method_map[AggregationMethod.MEAN])
# Try to match string method
try:
method_enum = AggregationMethod(method_str.lower())
return method_map.get(method_enum, method_map[AggregationMethod.MEAN])
except ValueError:
logger.warning(f"Unknown aggregation method '{method}', using mean")
return method_map[AggregationMethod.MEAN]
def _remove_outliers(self, data: np.ndarray) -> np.ndarray:
"""Remove outlier rows from data using z-score.
Args:
data: Array of shape (n_samples, n_features).
Returns:
Array with outlier rows removed.
"""
if len(data) <= 2:
return data
# Calculate z-scores across features
mean = np.nanmean(data, axis=0)
std = np.nanstd(data, axis=0)
# Avoid division by zero
std = np.where(std == 0, 1, std)
z_scores = np.abs((data - mean) / std)
# A row is an outlier if any feature exceeds threshold
max_z = np.nanmax(z_scores, axis=1)
mask = max_z <= self.config.outlier_threshold
return data[mask]
[docs]
def aggregate_data(
X: np.ndarray,
y: Optional[np.ndarray] = None,
metadata: Optional[pd.DataFrame] = None,
column: Optional[Union[str, bool]] = None,
method: Union[str, AggregationMethod] = "mean",
exclude_outliers: bool = False,
**kwargs,
) -> tuple:
"""Convenience function to aggregate data.
Args:
X: Feature array.
y: Optional target array.
metadata: Optional metadata DataFrame.
column: Column to group by (str), or True for y-based grouping.
method: Aggregation method.
exclude_outliers: Whether to exclude outliers.
**kwargs: Additional aggregation config options.
Returns:
Tuple of (X_aggregated, y_aggregated, metadata_aggregated).
"""
if column is None:
return X, y, metadata
config = AggregationConfig(
column=column,
method=method,
exclude_outliers=exclude_outliers,
**kwargs,
)
aggregator = Aggregator(config)
return aggregator.aggregate(X, y, metadata)