"""
Partition assigner for dataset configuration.
This module provides flexible partition assignment for DataFrames, supporting
multiple assignment methods including static, column-based, percentage-based,
and index-based partitions.
Example:
>>> assigner = PartitionAssigner()
>>> # Static partition
>>> result = assigner.assign(df, partition="train")
>>> # Column-based partition
>>> result = assigner.assign(df, {
... "column": "split",
... "train_values": ["train", "training"],
... "test_values": ["test", "validation"]
... })
>>> # Percentage-based partition
>>> result = assigner.assign(df, {
... "train": "80%",
... "test": "20%",
... "shuffle": True,
... "random_state": 42
... })
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
import numpy as np
import pandas as pd
class PartitionError(Exception):
    """Raised when partition assignment fails.

    Used for invalid partition specifications, missing partition/stratify
    columns or index files, out-of-range indices and overlapping partitions.
    """
# Type alias for the partition specification accepted by PartitionAssigner.assign
PartitionSpec = Union[
    str,  # Static partition: "train", "test" or "predict" (whole DataFrame)
    Dict[str, Any],  # Column-, percentage/range-, index- or index-file-based spec
    None,  # No partitioning: assign() returns an empty PartitionResult
]

# Names of the partitions a row can be assigned to
PartitionName = Literal["train", "test", "predict"]
@dataclass
class PartitionResult:
    """Result of a partition assignment operation.

    Attributes:
        train_indices: Row indices assigned to the training partition.
        test_indices: Row indices assigned to the test partition.
        predict_indices: Row indices assigned to the predict partition (no targets).
        train_data: DataFrame subset for training, or None.
        test_data: DataFrame subset for testing, or None.
        predict_data: DataFrame subset for prediction, or None.
        partition_column: Name of the column used for partitioning (column-based
            assignment only).
    """

    train_indices: List[int] = field(default_factory=list)
    test_indices: List[int] = field(default_factory=list)
    predict_indices: List[int] = field(default_factory=list)
    train_data: Optional[pd.DataFrame] = None
    test_data: Optional[pd.DataFrame] = None
    predict_data: Optional[pd.DataFrame] = None
    partition_column: Optional[str] = None

    @property
    def has_train(self) -> bool:
        """Whether any rows are assigned to the training partition."""
        return len(self.train_indices) > 0

    @property
    def has_test(self) -> bool:
        """Whether any rows are assigned to the test partition."""
        return len(self.test_indices) > 0

    @property
    def has_predict(self) -> bool:
        """Whether any rows are assigned to the predict partition."""
        return len(self.predict_indices) > 0

    def get_indices(self, partition: PartitionName) -> List[int]:
        """Return the indices of ``partition`` ("train", "test" or "predict").

        Raises:
            PartitionError: If ``partition`` is not a known partition name.
        """
        return self._field_for(partition, "indices")

    def get_data(self, partition: PartitionName) -> Optional[pd.DataFrame]:
        """Return the DataFrame of ``partition`` ("train", "test" or "predict").

        Raises:
            PartitionError: If ``partition`` is not a known partition name.
        """
        return self._field_for(partition, "data")

    def _field_for(self, partition: str, kind: str) -> Any:
        """Return the ``<partition>_<kind>`` attribute after validating the name."""
        # Tuple membership compares with ==, mirroring a chain of equality checks.
        if partition not in ("train", "test", "predict"):
            raise PartitionError(f"Unknown partition: {partition}")
        return getattr(self, f"{partition}_{kind}")
class PartitionAssigner:
    """Flexible partition assigner for DataFrames.

    Supports multiple partition methods:

    - Static: ``"train"``, ``"test"``, ``"predict"`` (assign the entire DataFrame)
    - Column-based: ``{"column": "split", "train_values": [...], "test_values": [...]}``
    - Percentage-based: ``{"train": "80%", "test": "20%", "shuffle": True}``
    - Index-based: ``{"train": [0, 1, 2], "test": [3, 4, 5]}``
    - Index file: ``{"train_file": "train_idx.txt", "test_file": "test_idx.txt"}``

    Every index in a returned PartitionResult is a positional (``iloc``) row
    position in the input DataFrame.

    Example:
        >>> assigner = PartitionAssigner()
        >>> result = assigner.assign(df, {"train": "80%", "test": "20%"})
        >>> print(len(result.train_data), len(result.test_data))
    """

    # Recognized partition names; also the order in which simple percentages
    # ("80%") are laid out back to back.
    PARTITION_NAMES = ("train", "test", "predict")
    # Default column values mapped to the training partition
    DEFAULT_TRAIN_VALUES = ("train", "training", "cal", "calibration")
    # Default column values mapped to the test partition
    DEFAULT_TEST_VALUES = ("test", "testing", "val", "validation", "valid")
    # Default column values mapped to the predict partition
    DEFAULT_PREDICT_VALUES = ("predict", "prediction", "unknown")
    # Accepted values of the column-based ``unknown_policy`` option
    UNKNOWN_POLICIES = ("error", "ignore", "train")

    def __init__(
        self,
        default_random_state: Optional[int] = None,
        base_path: Optional[Path] = None,
    ):
        """Initialize the partition assigner.

        Args:
            default_random_state: Random state used for shuffling when a
                specification does not provide ``random_state``.
            base_path: Base directory for resolving relative index-file paths.
        """
        self.default_random_state = default_random_state
        self.base_path = base_path

    def assign(
        self,
        df: pd.DataFrame,
        partition: PartitionSpec,
    ) -> PartitionResult:
        """Assign rows to partitions.

        Args:
            df: The DataFrame to partition.
            partition: Partition specification. Can be:
                - str: Static partition ("train", "test", "predict")
                - dict: Column-based, percentage/range, index or index-file spec
                - None: No partitioning (returns an empty result)

        Returns:
            PartitionResult with indices and data for each partition.

        Raises:
            PartitionError: If the partition specification is invalid.
        """
        if partition is None:
            return PartitionResult()
        if isinstance(partition, str):
            return self._assign_static(df, partition)
        if isinstance(partition, dict):
            return self._assign_from_dict(df, partition)
        raise PartitionError(
            f"Unsupported partition type: {type(partition).__name__}. "
            f"Expected str, dict, or None."
        )

    # ------------------------------------------------------------------
    # Dispatch
    # ------------------------------------------------------------------

    def _assign_static(
        self,
        df: pd.DataFrame,
        partition: str,
    ) -> PartitionResult:
        """Assign the entire DataFrame to a single partition.

        Args:
            df: The DataFrame to assign.
            partition: Partition name ("train", "test" or "predict"),
                case-insensitive; surrounding whitespace is ignored.

        Returns:
            PartitionResult with all rows (and a copy of ``df``) in the
            specified partition.

        Raises:
            PartitionError: If ``partition`` is not a known partition name.
        """
        partition_lower = partition.strip().lower()
        if partition_lower not in self.PARTITION_NAMES:
            raise PartitionError(
                f"Invalid partition name: '{partition}'. "
                f"Expected one of: {self.PARTITION_NAMES}"
            )

        indices = list(range(len(df)))
        result = PartitionResult()
        # Data is set even for an empty df, unlike the dict-based methods.
        if partition_lower == "train":
            result.train_indices = indices
            result.train_data = df.copy()
        elif partition_lower == "test":
            result.test_indices = indices
            result.test_data = df.copy()
        else:
            result.predict_indices = indices
            result.predict_data = df.copy()
        return result

    def _assign_from_dict(
        self,
        df: pd.DataFrame,
        partition: Dict[str, Any],
    ) -> PartitionResult:
        """Dispatch a dictionary specification to the matching method.

        Detection order:
        - "column": column-based partitioning
        - "train_file"/"test_file"/"predict_file": index-file-based
        - "train"/"test"/"predict": percentage/range-based if any of them is a
          string containing "%" or ":", otherwise explicit index lists

        Raises:
            PartitionError: If no method can be determined.
        """
        if "column" in partition:
            return self._assign_by_column(df, partition)

        if any(f"{name}_file" in partition for name in self.PARTITION_NAMES):
            return self._assign_by_index_file(df, partition)

        if any(name in partition for name in self.PARTITION_NAMES):
            specs = [partition.get(name) for name in self.PARTITION_NAMES]
            if any(self._is_range_spec(spec) for spec in specs):
                return self._assign_by_percentage(df, partition)
            return self._assign_by_indices(df, partition)

        raise PartitionError(
            f"Cannot determine partition method from specification: {partition}. "
            f"Expected 'column', 'train'/'test'/'predict' with values/percentages, "
            f"or 'train_file'/'test_file'/'predict_file'."
        )

    @staticmethod
    def _is_range_spec(spec: Any) -> bool:
        """Return True if ``spec`` is a percentage/range string ("80%", "0:80%")."""
        return isinstance(spec, str) and ("%" in spec or ":" in spec)

    # ------------------------------------------------------------------
    # Column-based
    # ------------------------------------------------------------------

    def _assign_by_column(
        self,
        df: pd.DataFrame,
        partition: Dict[str, Any],
    ) -> PartitionResult:
        """Assign rows based on the values of a label column.

        Cell values are compared case-insensitively by their string form
        (``str(value).lower()``). A value listed for several partitions goes
        to the first match in train -> test -> predict order.

        Args:
            df: The DataFrame to partition.
            partition: Dict with keys:
                - column: Column name containing partition labels (required)
                - train_values: Values indicating training data
                - test_values: Values indicating test data
                - predict_values: Values indicating predict data
                  (each defaults to the matching ``DEFAULT_*_VALUES``; a single
                  string counts as one value)
                - unknown_policy: How to handle rows matching no list:
                  "error" (default), "ignore" (leave unassigned) or "train"
                  (append them to the training partition)

        Raises:
            PartitionError: If the column is missing, ``unknown_policy`` is not
                recognized, or unknown values occur with ``unknown_policy="error"``.
        """
        column = partition.get("column")
        if column is None:
            raise PartitionError("Column-based partition requires 'column' key.")
        if column not in df.columns:
            raise PartitionError(
                f"Partition column '{column}' not found in DataFrame. "
                f"Available columns: {df.columns.tolist()[:10]}"
            )

        unknown_policy = partition.get("unknown_policy", "error")
        if unknown_policy not in self.UNKNOWN_POLICIES:
            # A typo used to fall through silently as "ignore".
            raise PartitionError(
                f"Invalid unknown_policy: {unknown_policy!r}. "
                f"Expected one of: {self.UNKNOWN_POLICIES}"
            )

        train_values = self._normalize_values(
            partition.get("train_values", self.DEFAULT_TRAIN_VALUES)
        )
        test_values = self._normalize_values(
            partition.get("test_values", self.DEFAULT_TEST_VALUES)
        )
        predict_values = self._normalize_values(
            partition.get("predict_values", self.DEFAULT_PREDICT_VALUES)
        )

        train_indices: List[int] = []
        test_indices: List[int] = []
        predict_indices: List[int] = []
        unknown_indices: List[int] = []
        # Read the column on its own rather than row by row via iloc. A row
        # Series takes the common dtype of all columns (e.g. int 1 becomes 1.0
        # next to float columns), which would break matching against "1".
        for position, value in enumerate(df[column].tolist()):
            label = str(value).lower()
            if label in train_values:
                train_indices.append(position)
            elif label in test_values:
                test_indices.append(position)
            elif label in predict_values:
                predict_indices.append(position)
            else:
                unknown_indices.append(position)

        if unknown_indices:
            if unknown_policy == "error":
                unknown_values = df[column].iloc[unknown_indices].unique().tolist()
                raise PartitionError(
                    f"Unknown partition values in column '{column}': {unknown_values}. "
                    f"Expected train: {sorted(train_values)}, "
                    f"test: {sorted(test_values)}, "
                    f"predict: {sorted(predict_values)}. "
                    f"Set unknown_policy='ignore' to skip or 'train' to include in training."
                )
            if unknown_policy == "train":
                train_indices.extend(unknown_indices)
            # "ignore": unknown rows stay unassigned

        return self._build_result(
            df, train_indices, test_indices, predict_indices, partition_column=column
        )

    @staticmethod
    def _normalize_values(values: Any) -> set:
        """Return the lower-cased string forms of ``values`` as a set.

        A single string (or other non-iterable scalar) counts as one value
        instead of being iterated character by character. None means no values.
        """
        if values is None:
            return set()
        if isinstance(values, (str, bytes)) or not hasattr(values, "__iter__"):
            values = [values]
        return {str(value).lower() for value in values}

    # ------------------------------------------------------------------
    # Percentage / range-based
    # ------------------------------------------------------------------

    def _assign_by_percentage(
        self,
        df: pd.DataFrame,
        partition: Dict[str, Any],
    ) -> PartitionResult:
        """Assign rows based on percentage or row-range splits.

        Rows are first ordered (as-is, shuffled, or stratified). Each partition
        then takes a contiguous slice of that order.

        Simple percentages such as ``"80%"`` are laid out back to back in
        train -> test -> predict order. So ``{"train": "80%", "test": "20%"}``
        gives train the first 80% and test the remaining 20%. Explicit ranges
        (``"80%:100%"``, ``"10:50"``) are absolute and do not move that cursor.

        Args:
            df: The DataFrame to partition.
            partition: Dict with keys:
                - train / test / predict: Percentage ("80%") or range
                  ("0:80%", "80%:100%") for each partition
                - shuffle: Whether to shuffle before splitting (default: False)
                - random_state: Random state for shuffling
                - stratify: Column name for stratified splitting. Rows are
                  shuffled within each class and interleaved proportionally;
                  ``shuffle`` is not consulted in that case.

        Raises:
            PartitionError: On malformed specs, percentages outside [0%, 100%]
                or simple percentages summing past 100%, a missing stratify
                column, or overlapping partitions.
        """
        n_rows = len(df)
        shuffle = partition.get("shuffle", False)
        random_state = partition.get("random_state", self.default_random_state)
        stratify_column = partition.get("stratify")

        positions = np.arange(n_rows)
        if stratify_column:
            if stratify_column not in df.columns:
                raise PartitionError(
                    f"Stratify column '{stratify_column}' not found in DataFrame."
                )
            positions = self._stratified_shuffle(df, stratify_column, random_state)
        elif shuffle:
            np.random.RandomState(random_state).shuffle(positions)

        selected: Dict[str, List[int]] = {name: [] for name in self.PARTITION_NAMES}
        cursor_pct = 0.0  # where the next simple percentage starts
        for name in self.PARTITION_NAMES:
            spec = partition.get(name)
            if not spec:
                continue
            if isinstance(spec, str) and ":" not in spec:
                start, end, cursor_pct = self._simple_percentage_range(
                    spec, n_rows, cursor_pct
                )
            else:
                start, end = self._parse_percentage_spec(spec, n_rows)
            selected[name] = positions[start:end].tolist()

        self._validate_no_overlap(selected["train"], selected["test"], selected["predict"])
        return self._build_result(
            df, selected["train"], selected["test"], selected["predict"]
        )

    def _parse_percentage_spec(
        self,
        spec: str,
        n_rows: int,
    ) -> tuple:
        """Parse a percentage specification into ``(start_idx, end_idx)``.

        Percent bounds map to ``int(n_rows * pct / 100)``. Bounds without "%"
        are raw row positions. An empty start/end means 0 / ``n_rows``.

        Formats:
            "80%"      -> (0, int(n_rows * 80 / 100))
            "0:80%"    -> (0, int(n_rows * 80 / 100))
            "80%:100%" -> (int(n_rows * 80 / 100), n_rows)
            "20%:80%"  -> (int(n_rows * 20 / 100), int(n_rows * 80 / 100))

        Raises:
            PartitionError: If ``spec`` is not a string, is malformed, or
                contains a percentage outside [0%, 100%].
        """
        if not isinstance(spec, str):
            raise PartitionError(f"Percentage spec must be a string, got: {type(spec)}")
        spec = spec.strip()

        if ":" not in spec:
            start_idx, end_idx, _ = self._simple_percentage_range(spec, n_rows, 0.0)
            return (start_idx, end_idx)

        parts = spec.split(":")
        if len(parts) != 2:
            raise PartitionError(f"Invalid percentage range format: '{spec}'")
        start_text, end_text = parts
        start_idx = self._parse_bound(start_text, n_rows, 0, spec)
        end_idx = self._parse_bound(end_text, n_rows, n_rows, spec)
        return (start_idx, end_idx)

    def _simple_percentage_range(
        self,
        spec: str,
        n_rows: int,
        start_pct: float,
    ) -> tuple:
        """Resolve a simple percentage ("20%") placed at ``start_pct`` percent.

        Returns:
            ``(start_idx, end_idx, end_pct)``, covering the rows between
            ``start_pct`` and ``start_pct + pct`` percent.

        Raises:
            PartitionError: If the percentage is malformed or the running total
                exceeds 100%.
        """
        pct = self._parse_percent(spec, spec)
        # round() absorbs binary floating-point error when fractional
        # percentages are summed across partitions.
        end_pct = round(start_pct + pct, 9)
        if end_pct > 100.0:
            raise PartitionError(
                f"Percentage '{spec}' starting at {start_pct:g}% exceeds 100% of the rows."
            )
        return (
            self._pct_to_index(start_pct, n_rows),
            self._pct_to_index(end_pct, n_rows),
            end_pct,
        )

    def _parse_bound(self, text: str, n_rows: int, default: int, spec: str) -> int:
        """Convert one side of a range ("80%", "10" or "") into a row position."""
        text = text.strip()
        if not text:
            return default
        if text.endswith("%"):
            return self._pct_to_index(self._parse_percent(text, spec), n_rows)
        try:
            return int(text)
        except ValueError:
            raise PartitionError(f"Invalid percentage range format: '{spec}'") from None

    @staticmethod
    def _parse_percent(text: str, spec: str) -> float:
        """Parse ``"NN%"`` into the float NN, requiring 0 <= NN <= 100.

        ``spec`` is the full specification, used in error messages.
        """
        text = text.strip()
        if not text.endswith("%"):
            raise PartitionError(f"Invalid percentage format: '{spec}'")
        try:
            value = float(text.rstrip("%"))
        except ValueError:
            raise PartitionError(f"Invalid percentage format: '{spec}'") from None
        # Also rejects NaN, which fails every comparison.
        if not 0.0 <= value <= 100.0:
            raise PartitionError(
                f"Percentage out of range in '{spec}': expected 0% to 100%."
            )
        return value

    @staticmethod
    def _pct_to_index(pct: float, n_rows: int) -> int:
        """Map a percentage (0-100) to a row boundary.

        Multiplying before dividing keeps whole percentages exact:
        ``int(100 * 29 / 100) == 29``, whereas ``int(100 * 0.29) == 28``.
        """
        return int(n_rows * pct / 100.0)

    # ------------------------------------------------------------------
    # Index-based
    # ------------------------------------------------------------------

    def _assign_by_indices(
        self,
        df: pd.DataFrame,
        partition: Dict[str, Any],
    ) -> PartitionResult:
        """Assign rows from explicit index lists.

        Args:
            df: The DataFrame to partition.
            partition: Dict with optional keys "train", "test" and "predict",
                each a single index or an iterable of indices (negative values
                count from the end).

        Raises:
            PartitionError: On invalid or out-of-range indices, or overlap.
        """
        n_rows = len(df)
        selected = {
            name: self._validate_indices(partition.get(name, []), n_rows, name)
            for name in self.PARTITION_NAMES
        }
        self._validate_no_overlap(selected["train"], selected["test"], selected["predict"])
        return self._build_result(
            df, selected["train"], selected["test"], selected["predict"]
        )

    def _assign_by_index_file(
        self,
        df: pd.DataFrame,
        partition: Dict[str, Any],
    ) -> PartitionResult:
        """Assign rows from index files.

        Args:
            df: The DataFrame to partition.
            partition: Dict with optional keys:
                - train_file: Path to file with training indices
                - test_file: Path to file with test indices
                - predict_file: Path to file with predict indices

        Raises:
            PartitionError: If a file cannot be loaded, indices are invalid,
                or partitions overlap.
        """
        loaded = {}
        for name in self.PARTITION_NAMES:
            file_path = partition.get(f"{name}_file")
            loaded[name] = self._load_indices_from_file(file_path) if file_path else []

        n_rows = len(df)
        selected = {
            name: self._validate_indices(indices, n_rows, name)
            for name, indices in loaded.items()
        }
        self._validate_no_overlap(selected["train"], selected["test"], selected["predict"])
        return self._build_result(
            df, selected["train"], selected["test"], selected["predict"]
        )

    def _load_indices_from_file(self, file_path: str) -> List[int]:
        """Load row indices from a file.

        Supported formats, chosen by file extension:
        - ``.json``: a JSON list of indices
        - ``.yaml`` / ``.yml``: a YAML list of indices (needs the ``yaml`` package)
        - ``.csv``: indices in the first column, no header row
        - anything else: plain text, one index per line (blank lines skipped)

        Relative paths are resolved against ``base_path`` when it is set.

        Raises:
            PartitionError: If the file is missing, cannot be read or parsed,
                or does not contain a list of integers.
        """
        path = Path(file_path)
        if self.base_path and not path.is_absolute():
            path = Path(self.base_path) / path
        if not path.exists():
            raise PartitionError(f"Index file not found: {path}")

        suffix = path.suffix.lower()
        try:
            if suffix == ".json":
                import json

                with open(path, "r", encoding="utf-8") as handle:
                    indices = json.load(handle)
                if not isinstance(indices, list):
                    raise PartitionError(
                        f"JSON file {path} must contain a list of indices."
                    )
                return [int(i) for i in indices]
            if suffix in (".yaml", ".yml"):
                import yaml  # only needed for YAML index files

                with open(path, "r", encoding="utf-8") as handle:
                    indices = yaml.safe_load(handle)
                if not isinstance(indices, list):
                    raise PartitionError(
                        f"YAML file {path} must contain a list of indices."
                    )
                return [int(i) for i in indices]
            if suffix == ".csv":
                table = pd.read_csv(path, header=None)
                return table.iloc[:, 0].astype(int).tolist()
            with open(path, "r", encoding="utf-8") as handle:
                return [int(line.strip()) for line in handle if line.strip()]
        except PartitionError:
            # Already descriptive; do not wrap it a second time.
            raise
        except Exception as exc:
            # Normalize I/O, parse and import failures to PartitionError.
            raise PartitionError(f"Failed to load indices from {path}: {exc}") from exc

    # ------------------------------------------------------------------
    # Validation and helpers
    # ------------------------------------------------------------------

    def _validate_indices(
        self,
        indices: Any,
        n_rows: int,
        partition_name: str,
    ) -> List[int]:
        """Validate and normalize an index specification to row positions.

        Accepts a single index or any iterable of indices (list, tuple, range,
        numpy array, ...); None means no rows. Negative indices count from the
        end, as in Python sequences.

        Raises:
            PartitionError: If an index is not an integer or is out of range.
        """
        if indices is None:
            return []
        if isinstance(indices, (str, bytes)) or not hasattr(indices, "__iter__"):
            indices = [indices]

        validated: List[int] = []
        for raw in indices:
            try:
                requested = int(raw)
            except (TypeError, ValueError):
                raise PartitionError(
                    f"Invalid index {raw!r} in '{partition_name}': expected an integer."
                ) from None
            position = requested + n_rows if requested < 0 else requested
            if not 0 <= position < n_rows:
                raise PartitionError(
                    f"Index {requested} in '{partition_name}' out of range. "
                    f"DataFrame has {n_rows} rows (0-{n_rows - 1})."
                )
            validated.append(position)
        return validated

    def _validate_no_overlap(
        self,
        train_indices: List[int],
        test_indices: List[int],
        predict_indices: List[int],
    ) -> None:
        """Raise PartitionError if any two partitions share an index.

        Pairs are checked in train/test, train/predict, test/predict order. At
        most the first 10 overlapping indices are reported.
        """
        named_sets = (
            ("Train", set(train_indices)),
            ("Test", set(test_indices)),
            ("Predict", set(predict_indices)),
        )
        for i, (first_name, first_set) in enumerate(named_sets):
            for second_name, second_set in named_sets[i + 1:]:
                overlap = first_set & second_set
                if overlap:
                    raise PartitionError(
                        f"{first_name} and {second_name.lower()} partitions overlap at indices: "
                        f"{sorted(overlap)[:10]}"
                    )

    @staticmethod
    def _build_result(
        df: pd.DataFrame,
        train_indices: List[int],
        test_indices: List[int],
        predict_indices: List[int],
        partition_column: Optional[str] = None,
    ) -> PartitionResult:
        """Build a PartitionResult; each non-empty partition gets a copy of its rows."""

        def subset(indices: List[int]) -> Optional[pd.DataFrame]:
            return df.iloc[indices].copy() if indices else None

        return PartitionResult(
            train_indices=train_indices,
            test_indices=test_indices,
            predict_indices=predict_indices,
            train_data=subset(train_indices),
            test_data=subset(test_indices),
            predict_data=subset(predict_indices),
            partition_column=partition_column,
        )

    def _stratified_shuffle(
        self,
        df: pd.DataFrame,
        stratify_column: str,
        random_state: Optional[int],
    ) -> np.ndarray:
        """Return all row positions ordered so that contiguous slices are stratified.

        1. Positions within each stratum (value of ``stratify_column``) are shuffled.
        2. The strata are interleaved greedily. At every step, the next row comes
           from the stratum furthest behind its expected share.

        Example: with 80 rows of class 0 and 20 of class 1, the first 80% of
        the order holds about 64 rows of class 0 and 16 of class 1.
        """
        rng = np.random.RandomState(random_state)

        # A fresh RangeIndex makes group labels equal iloc positions, even when
        # df has a custom or duplicated index. dropna=False keeps rows with a
        # missing label as their own stratum instead of dropping them from
        # every partition.
        positional = df[[stratify_column]].reset_index(drop=True)
        strata: List[List[int]] = []
        for _, group in positional.groupby(stratify_column, dropna=False):
            members = group.index.tolist()
            rng.shuffle(members)
            strata.append(members)

        n_rows = len(df)
        taken = [0] * len(strata)  # rows already emitted from each stratum
        order: List[int] = []
        for step in range(1, n_rows + 1):
            best, best_deficit = -1, -float("inf")
            for g, members in enumerate(strata):
                if taken[g] >= len(members):
                    continue  # stratum exhausted
                deficit = step * len(members) / n_rows - taken[g]
                if deficit > best_deficit:
                    best, best_deficit = g, deficit
            if best >= 0:
                order.append(strata[best][taken[best]])
                taken[best] += 1
        return np.array(order, dtype=int)

    def concatenate_partitions(
        self,
        results: Sequence[PartitionResult],
        row_counts: Optional[Sequence[int]] = None,
    ) -> PartitionResult:
        """Concatenate multiple partition results.

        Useful when combining multiple files with the same partition. Each
        result's indices are shifted so they do not collide with those of
        earlier results. Each partition's DataFrames are concatenated with a
        fresh RangeIndex.

        Args:
            results: Sequence of PartitionResult objects, in concatenation order.
            row_counts: Optional row count of the DataFrame behind each result.
                When given, result ``k``'s indices are shifted by the sum of the
                first ``k`` counts. This makes them exact positions in the
                row-wise concatenation of those DataFrames. When omitted, the
                shift grows by (largest assigned index + 1) per result. That
                undercounts rows left unassigned at the end of a DataFrame.

        Returns:
            Combined PartitionResult. ``partition_column`` is kept only when
            all results share the same value.

        Raises:
            PartitionError: If ``row_counts`` does not have one entry per result.
        """
        if not results:
            return PartitionResult()
        if row_counts is not None and len(row_counts) != len(results):
            raise PartitionError(
                f"row_counts has {len(row_counts)} entries but "
                f"{len(results)} results were given."
            )

        combined: Dict[str, List[int]] = {name: [] for name in self.PARTITION_NAMES}
        frames: Dict[str, List[pd.DataFrame]] = {name: [] for name in self.PARTITION_NAMES}
        offset = 0
        for position, result in enumerate(results):
            for name in self.PARTITION_NAMES:
                indices = result.get_indices(name)
                if not indices:
                    continue
                combined[name].extend(index + offset for index in indices)
                data = result.get_data(name)
                if data is not None:
                    frames[name].append(data)

            if row_counts is not None:
                offset += int(row_counts[position])
            else:
                assigned = (
                    result.train_indices + result.test_indices + result.predict_indices
                )
                if assigned:
                    offset += max(assigned) + 1

        merged = {
            name: pd.concat(parts, ignore_index=True) if parts else None
            for name, parts in frames.items()
        }
        columns = {result.partition_column for result in results}
        partition_column = columns.pop() if len(columns) == 1 else None

        return PartitionResult(
            train_indices=combined["train"],
            test_indices=combined["test"],
            predict_indices=combined["predict"],
            train_data=merged["train"],
            test_data=merged["test"],
            predict_data=merged["predict"],
            partition_column=partition_column,
        )