Source code for nirs4all.data.indexer

from typing import Dict, List, Union, Any, Optional, overload, Mapping
import numpy as np
import polars as pl

from nirs4all.data.types import Selector, SampleIndices, PartitionType, ProcessingList, IndexDict
from nirs4all.data._indexer import (
    IndexStore,
    QueryBuilder,
    SampleManager,
    AugmentationTracker,
    ProcessingManager,
    ParameterNormalizer,
)


[docs] class Indexer: """ Index manager for samples used in ML/DL pipelines. Optimizes contiguous access and manages filtering. This class is designed to retrieve data during ML pipelines. For example, it can be used to get all test samples from branch 2, including augmented samples, for specific processings such as ["raw", "savgol", "gaussian"]. The Indexer uses a component-based architecture for maintainability: - IndexStore: DataFrame storage and queries - QueryBuilder: Selector to Polars expression conversion - SampleManager: ID generation - AugmentationTracker: Origin/augmented relationships - ProcessingManager: Processing list operations - ParameterNormalizer: Input validation """ def __init__(self): # Initialize components self._store = IndexStore() self._query_builder = QueryBuilder(valid_columns=self._store.columns) self._sample_manager = SampleManager(self._store) self._augmentation_tracker = AugmentationTracker(self._store, self._query_builder) self._processing_manager = ProcessingManager(self._store) self._parameter_normalizer = ParameterNormalizer(default_processings=["raw"]) @property def df(self) -> pl.DataFrame: """ Get the underlying DataFrame for backward compatibility. Returns: pl.DataFrame: The complete index DataFrame. Note: Direct DataFrame access is provided for backward compatibility. Prefer using indexer methods when possible. """ return self._store.df @property def default_values(self) -> Dict[str, Any]: """ Get default values for backward compatibility. Returns: Dict[str, Any]: Default values used when parameters are None. """ return { "partition": "train", "processings": ["raw"], } def _ensure_selector_dict(self, selector: Any) -> Dict[str, Any]: """Ensure selector is a dictionary.""" if selector is None: return {} # Handle ExecutionContext (duck typing) if hasattr(selector, "selector") and hasattr(selector, "state"): return dict(selector.selector) if isinstance(selector, Mapping): return dict(selector) return {} def _apply_filters(self, selector: Selector) -> pl.DataFrame: """Apply selector filters and return filtered DataFrame.""" selector = self._ensure_selector_dict(selector) condition = self._query_builder.build(selector, exclude_columns=["processings"]) return self._store.query(condition) def _build_filter_condition(self, selector: Selector) -> pl.Expr: """Build a Polars filter expression from selector.""" selector = self._ensure_selector_dict(selector) return self._query_builder.build(selector, exclude_columns=["processings"])
[docs] def x_indices(self, selector: Selector, include_augmented: bool = True, include_excluded: bool = False) -> np.ndarray: """ Get sample indices with optional augmented sample aggregation. This method implements two-phase selection to prevent data leakage: 1. Phase 1: Get base samples (sample == origin) 2. Phase 2: Get augmented versions of those base samples Args: selector: Filter criteria dictionary. Supported keys: - partition: "train"|"test"|"val" or list - group: int or list of ints - branch: int or list of ints - augmentation: str, list, or None for null check - Any other indexed columns include_augmented: If True, include augmented versions of selected samples. If False, return only base samples (sample == origin). Default True for backward compatibility. include_excluded: If True, include samples marked as excluded. If False (default), exclude samples marked as excluded=True. Use True for diagnostics, reporting, or viewing excluded samples. Returns: np.ndarray: Array of sample indices (dtype: np.int32). When include_augmented=True, includes base samples and their augmented versions. When False, only base samples where sample == origin. Raises: KeyError: If selector contains invalid column names. Examples: >>> indexer = Indexer() >>> indexer.add_samples(5, partition="train") >>> indexer.augment_rows([0, 1], 2, "flip") >>> >>> # Get all train samples (base + augmented) >>> all_train = indexer.x_indices({"partition": "train"}) >>> # Returns: [0, 1, 2, 3, 4, 5, 6, 7, 8] (5 base + 4 augmented) >>> >>> # Get only base train samples >>> base_train = indexer.x_indices({"partition": "train"}, include_augmented=False) >>> # Returns: [0, 1, 2, 3, 4] (5 base only) >>> >>> # Mark sample as excluded and filter it >>> indexer.mark_excluded([0], reason="outlier") >>> filtered = indexer.x_indices({"partition": "train"}) >>> # Returns: [1, 2, 3, 4, ...] (sample 0 and its augmentations excluded) >>> >>> # Include excluded samples (for diagnostics) >>> all_samples = indexer.x_indices({"partition": "train"}, include_excluded=True) Note: The two-phase selection ensures that augmented samples from other partitions are NOT included, preventing data leakage in cross-validation scenarios. """ # Build the exclusion filter excluded_filter = self._query_builder.build_excluded_filter(include_excluded) if not include_augmented: # Simple case: filter for base samples only condition = self._build_filter_condition(selector) base_condition = condition & self._query_builder.build_base_samples_filter() & excluded_filter filtered_df = self._store.query(base_condition) return filtered_df.select(pl.col("sample")).to_series().to_numpy().astype(np.int32) # Two-phase selection using augmentation tracker # Pass both the base condition AND the exclusion filter condition = self._build_filter_condition(selector) & excluded_filter return self._augmentation_tracker.get_all_samples_with_augmentations( condition, additional_filter=excluded_filter )
[docs] def y_indices(self, selector: Selector, include_augmented: bool = True, include_excluded: bool = False) -> np.ndarray: """ Get y indices for samples. Returns origin indices for y-value lookup. For augmented samples, this method maps them to their base samples (origins) since y-values only exist for base samples. This enables proper target retrieval when working with augmented data. Args: selector: Filter criteria dictionary. Same format as x_indices(). See x_indices() for supported keys. include_augmented: If True (default), include augmented samples mapped to their origins. If False, return only base sample origins (sample == origin). Default True for backward compatibility with original behavior. include_excluded: If True, include samples marked as excluded. If False (default), exclude samples marked as excluded=True. Use True for diagnostics, reporting, or viewing excluded samples. Returns: np.ndarray: Array of origin sample indices for y-value lookup (dtype: np.int32). When include_augmented=True (default), augmented samples are included and each is mapped to its origin. When False, only base samples are returned (sample == origin). Examples: >>> indexer = Indexer() >>> indexer.add_samples(5, partition="train") >>> indexer.augment_rows([0, 1], 2, "flip") >>> >>> # Get origins for all train samples (base + augmented) >>> y_idx = indexer.y_indices({"partition": "train"}) >>> # Returns: [0, 1, 2, 3, 4, 0, 0, 1, 1] >>> # (5 base origins + 4 augmented mapped to origins 0, 0, 1, 1) >>> >>> # Use with targets >>> targets = np.array([10, 20, 30, 40, 50]) # 5 base samples >>> x_idx = indexer.x_indices({"partition": "train"}) >>> y_idx = indexer.y_indices({"partition": "train"}) >>> X = all_spectra[x_idx] # Get spectra (includes augmented) >>> y = targets[y_idx] # Get targets (augmented samples use origin's target) >>> >>> # Get only base sample origins >>> base_origins = indexer.y_indices({"partition": "train"}, include_augmented=False) >>> # Returns: [0, 1, 2, 3, 4] >>> >>> # Exclude filtered samples >>> indexer.mark_excluded([0], reason="outlier") >>> filtered_y = indexer.y_indices({"partition": "train"}) >>> # Sample 0 and its augmentations excluded from result Note: The length and order of y_indices() output always corresponds to x_indices() output with the same selector and include_augmented parameters. This ensures X and y arrays are properly aligned for training. """ # Build the exclusion filter excluded_filter = self._query_builder.build_excluded_filter(include_excluded) filtered_df = self._apply_filters(selector) if selector else self._store.df # Apply exclusion filter filtered_df = filtered_df.filter(excluded_filter) if not include_augmented: # Return only base sample origins (sample == origin) base_condition = self._query_builder.build_base_samples_filter() return filtered_df.filter(base_condition).select(pl.col("origin")).to_series().to_numpy().astype(np.int32) # Include augmented samples: all origins are returned (with augmented samples mapped) return filtered_df.select(pl.col("origin")).to_series().to_numpy().astype(np.int32)
[docs] def get_augmented_for_origins(self, origin_samples: List[int]) -> np.ndarray: """ Get all augmented samples for given origin sample IDs. This method is used to retrieve augmented versions of base samples, enabling two-phase selection that prevents data leakage across CV folds. Args: origin_samples: List of origin sample IDs to find augmented versions for. Can be empty list. Returns: np.ndarray: Array of augmented sample IDs (dtype: np.int32). Only includes samples where origin is in origin_samples AND sample != origin (actual augmented samples, not base samples). Examples: >>> indexer = Indexer() >>> indexer.add_samples(3, partition="train") >>> indexer.augment_rows([0, 1], 2, "flip") >>> >>> # Get base samples >>> base_samples = indexer.x_indices({"partition": "train"}, include_augmented=False) >>> # base_samples: [0, 1, 2] >>> >>> # Get their augmented versions >>> augmented = indexer.get_augmented_for_origins(base_samples.tolist()) >>> # augmented: [3, 4, 5, 6] (2 augmented each for samples 0 and 1) >>> >>> # Combine for full dataset >>> all_samples = np.concatenate([base_samples, augmented]) >>> # all_samples: [0, 1, 2, 3, 4, 5, 6] Note: This method does not filter by partition, group, or other criteria. It returns ALL augmented samples for the given origins, regardless of their attributes. Use x_indices() for filtered retrieval with automatic augmentation handling. """ return self._augmentation_tracker.get_augmented_for_origins(origin_samples)
[docs] def get_origin_for_sample(self, sample_id: int) -> Optional[int]: """ Get origin sample ID for a given sample. With the current design, all samples have origin set: - Base samples: origin == sample (self-referencing) - Augmented samples: origin != sample (references base sample) Args: sample_id: Sample ID to look up. Returns: Optional[int]: Origin sample ID, or None if sample not found in index. Examples: >>> indexer = Indexer() >>> indexer.add_samples(2, partition="train") >>> indexer.augment_rows([0], 1, "flip") >>> >>> # For augmented sample >>> origin = indexer.get_origin_for_sample(2) # Sample 2 is augmentation of 0 >>> print(origin) # 0 >>> >>> # For base sample >>> origin = indexer.get_origin_for_sample(0) # Sample 0 is base >>> print(origin) # 0 (self-referencing) >>> >>> # For non-existent sample >>> origin = indexer.get_origin_for_sample(999) >>> print(origin) # None Note: This is a single-sample lookup. For batch operations, use y_indices() which is more efficient for retrieving origins for multiple samples. """ return self._augmentation_tracker.get_origin_for_sample(sample_id)
[docs] def replace_processings(self, source_processings: List[str], new_processings: List[str]) -> None: """ Replace processing names across all samples. Creates a mapping from old to new processing names and applies it to all processing lists in the index. Args: source_processings: List of existing processing names to replace. new_processings: List of new processing names to set. Must have same length as source_processings. Raises: ValueError: If source_processings and new_processings have different lengths. ValueError: If source_processings or new_processings is empty. Examples: >>> indexer = Indexer() >>> indexer.add_samples(5, processings=["raw", "old_msc", "savgol"]) >>> >>> # Replace single processing >>> indexer.replace_processings(["old_msc"], ["msc"]) >>> # Now all samples have ["raw", "msc", "savgol"] >>> >>> # Replace multiple processings >>> indexer.replace_processings( ... ["raw", "savgol"], ... ["raw_v2", "savgol_v2"] ... ) >>> # Now all samples have ["raw_v2", "msc", "savgol_v2"] Note: - Operates on ALL rows in the index - Non-matched processings are left unchanged - Case-sensitive matching - Use this method when renaming processings after pipeline changes """ self._processing_manager.replace_processings(source_processings, new_processings)
[docs] def reset_processings(self, new_processings: List[str]) -> None: """ Reset processing names for all samples to a new list. This replaces the entire processing list for every sample with the provided list. Used when resetting feature storage (e.g. after merge). Args: new_processings: List of new processing names. Raises: ValueError: If new_processings is empty. """ self._processing_manager.reset_processings(new_processings)
[docs] def add_processings(self, new_processings: List[str]) -> None: """ Append processing names to all existing processing lists. Adds new processings to the end of each sample's processing list. This is useful when applying additional transformations to all data. Args: new_processings: List of new processing names to add to existing lists. Raises: ValueError: If new_processings is empty. Examples: >>> indexer = Indexer() >>> indexer.add_samples(5, processings=["raw", "msc"]) >>> >>> # Add single processing >>> indexer.add_processings(["normalize"]) >>> # All samples now have ["raw", "msc", "normalize"] >>> >>> # Add multiple processings >>> indexer.add_processings(["scale", "center"]) >>> # All samples now have ["raw", "msc", "normalize", "scale", "center"] Note: - Operates on ALL rows in the index - Appends to the end of each processing list - Does not check for duplicates (allows intentional reprocessing) - Use this method when adding pipeline steps to existing data """ self._processing_manager.add_processings(new_processings)
def _normalize_indices(self, indices: SampleIndices, count: int, param_name: str) -> List[int]: """Normalize various index formats to a list of integers (internal helper).""" return self._parameter_normalizer.normalize_indices(indices, count, param_name) def _normalize_single_or_list(self, value: Union[Any, List[Any]], count: int, param_name: str, allow_none: bool = False) -> List[Any]: """Normalize single value or list to a list of specified length (internal helper).""" return self._parameter_normalizer.normalize_single_or_list(value, count, param_name, allow_none) def _prepare_processings(self, processings: Union[ProcessingList, List[ProcessingList], str, List[str], None], count: int) -> List[List[str]]: """Prepare processings list with proper validation (internal helper).""" return self._parameter_normalizer.prepare_processings(processings, count) def _convert_indexdict_to_params(self, index_dict: IndexDict, count: int) -> Dict[str, Any]: """Convert IndexDict to method parameters (internal helper).""" return self._parameter_normalizer.convert_indexdict_to_params(index_dict, count) def _append(self, count: int, *, partition: PartitionType = "train", sample_indices: Optional[SampleIndices] = None, origin_indices: Optional[SampleIndices] = None, group: Optional[Union[int, List[int]]] = None, branch: Optional[Union[int, List[int]]] = None, processings: Union[ProcessingList, List[ProcessingList], str, List[str], None] = None, augmentation: Optional[Union[str, List[str]]] = None, **overrides) -> List[int]: """ Core method to append samples to the indexer (internal). Args: count: Number of samples to add partition: Data partition ("train", "test", "val") sample_indices: Specific sample IDs to use. If None, auto-increment origin_indices: Original sample IDs for augmented samples group: Group ID(s) - single value or list of values branch: Branch ID(s) - single value or list of values processings: Processing steps - single list or list of lists augmentation: Augmentation type(s) - single value or list **overrides: Additional column overrides Returns: List of sample indices that were added """ if count <= 0: return [] # Generate row and sample IDs using SampleManager row_ids = self._sample_manager.generate_row_ids(count) if sample_indices is None: sample_ids = self._sample_manager.generate_sample_ids(count) if origin_indices is None: # Base samples: origin = sample (self-referencing) origins = sample_ids.copy() else: origins = self._normalize_indices(origin_indices, count, "origin_indices") else: sample_ids = self._normalize_indices(sample_indices, count, "sample_indices") if origin_indices is None: # Base samples: origin = sample (self-referencing) origins = [int(x) for x in sample_ids] else: origins = self._normalize_indices(origin_indices, count, "origin_indices") # Normalize column values groups = self._normalize_single_or_list(group, count, "group") branches = self._normalize_single_or_list(branch, count, "branch") processings_list = self._prepare_processings(processings, count) # Now returns List[List[str]] augmentations = self._normalize_single_or_list(augmentation, count, "augmentation", allow_none=True) # Handle additional overrides additional_cols = {} for col, value in overrides.items(): if col in self._store.columns and col not in ["row", "sample", "origin", "partition", "group", "branch", "processings", "augmentation"]: if isinstance(value, (list, np.ndarray)): if len(value) != count: raise ValueError(f"{col} length ({len(value)}) must match count ({count})") additional_cols[col] = list(value) else: additional_cols[col] = [value] * count # Create new DataFrame with native List type for processings new_data = { "row": pl.Series(row_ids, dtype=pl.Int32), "sample": pl.Series(sample_ids, dtype=pl.Int32), "origin": pl.Series(origins, dtype=pl.Int32), "partition": pl.Series([partition] * count, dtype=pl.Categorical), "group": pl.Series(groups, dtype=pl.Int8), "branch": pl.Series(branches, dtype=pl.Int8), "processings": pl.Series(processings_list, dtype=pl.List(pl.Utf8)), # Native list! "augmentation": pl.Series(augmentations, dtype=pl.Categorical), "excluded": pl.Series([False] * count, dtype=pl.Boolean), # Default: not excluded "exclusion_reason": pl.Series([None] * count, dtype=pl.Utf8), # Default: no reason } # Add additional columns with proper casting for col, values in additional_cols.items(): expected_dtype = self._store.schema[col] new_data[col] = pl.Series(values, dtype=expected_dtype) # Append to store self._store.append(new_data) return sample_ids
[docs] def add_samples( self, count: int, partition: PartitionType = "train", sample_indices: Optional[SampleIndices] = None, origin_indices: Optional[SampleIndices] = None, group: Optional[Union[int, List[int]]] = None, branch: Optional[Union[int, List[int]]] = None, processings: Union[ProcessingList, List[ProcessingList], None] = None, augmentation: Optional[Union[str, List[str]]] = None, **kwargs ) -> List[int]: """ Add multiple samples to the indexer efficiently. This is the primary method for registering samples in the index. Samples can be base samples or augmented samples, with flexible parameter specification. Args: count: Number of samples to add. Must be positive. partition: Data partition ("train", "test", "val"). Default "train". sample_indices: Specific sample IDs to use. If None, auto-increment from current max. Can be: - int: Single ID repeated for all samples - List[int]: One ID per sample (length must match count) - np.ndarray: One ID per sample (length must match count) origin_indices: Original sample IDs for augmented samples. If None, samples are treated as base samples (origin = sample). Same format options as sample_indices. group: Group ID(s) for sample categorization. Can be: - int: Single group for all samples - List[int]: One group per sample (length must match count) - None: No group assignment branch: Pipeline branch ID(s). Same format as group. processings: Processing transformations applied. Can be: - None: Uses default ["raw"] - List[str]: Single list for all samples (e.g., ["raw", "msc"]) - List[List[str]]: One list per sample (length must match count) augmentation: Augmentation type(s). Same format as group, but allows None values. **kwargs: Additional column values. Must match count if list/array. Returns: List[int]: List of sample IDs that were added. Length equals count. Raises: ValueError: If count <= 0, or if list/array parameter lengths don't match count. TypeError: If parameter types are invalid. Examples: >>> indexer = Indexer() >>> >>> # Add 5 base train samples with default settings >>> ids = indexer.add_samples(5) >>> # ids: [0, 1, 2, 3, 4] >>> >>> # Add test samples with specific processings >>> test_ids = indexer.add_samples( ... 3, ... partition="test", ... processings=["raw", "msc", "savgol"] ... ) >>> >>> # Add samples with different groups >>> grouped_ids = indexer.add_samples( ... 4, ... partition="train", ... group=[1, 1, 2, 2], ... processings=["raw"] ... ) >>> >>> # Add augmented samples (references existing samples as origins) >>> aug_ids = indexer.add_samples( ... 2, ... partition="train", ... origin_indices=[0, 1], # Augmentations of samples 0 and 1 ... augmentation="flip" ... ) Note: - Auto-incrementing sample IDs start from 0 or next available ID - Base samples have origin == sample (self-referencing) - Augmented samples have origin != sample (references base sample) - Single values are broadcast to all samples - Lists/arrays must match count exactly """ return self._append( count, partition=partition, sample_indices=sample_indices, origin_indices=origin_indices, group=group, branch=branch, processings=processings, augmentation=augmentation, **kwargs )
[docs] def add_samples_dict( self, count: int, indices: Optional[IndexDict] = None, **kwargs ) -> List[int]: """ Add multiple samples using dictionary-based parameter specification. This method provides a cleaner API for specifying sample parameters using a dictionary, similar to the filtering API pattern. Args: count: Number of samples to add indices: Dictionary containing column specifications { "partition": "train|test|val", "sample": [list of sample IDs] or single ID, "origin": [list of origin IDs] or single ID, "group": [list of groups] or single group, "branch": [list of branches] or single branch, "processings": processing configuration, "augmentation": augmentation type, ... (any other column) } **kwargs: Additional column overrides (take precedence over indices) Returns: List of sample indices that were added Example: # Add samples with dictionary specification indexer.add_samples_dict(3, { "partition": "train", "group": [1, 2, 1], "processings": ["raw", "msc"] }) """ if indices is None: indices = {} params = self._convert_indexdict_to_params(indices, count) params.update(kwargs) return self._append(count, **params)
[docs] def add_rows(self, n_rows: int, new_indices: Optional[Dict[str, Any]] = None) -> List[int]: """Add rows to the indexer with optional column overrides.""" if n_rows <= 0: return [] new_indices = new_indices or {} # Extract arguments for _append kwargs = {} # Handle special mappings if "sample" in new_indices: kwargs["sample_indices"] = new_indices["sample"] if "origin" in new_indices: kwargs["origin_indices"] = new_indices["origin"] elif "sample" not in new_indices: # For add_rows, default origin to sample indices when not explicitly set next_sample_idx = self.next_sample_index() kwargs["origin_indices"] = list(range(next_sample_idx, next_sample_idx + n_rows)) # Handle direct mappings for key in ["partition", "group", "branch", "processings", "augmentation"]: if key in new_indices: kwargs[key] = new_indices[key] # Handle any other overrides for key, value in new_indices.items(): if key not in ["sample", "origin", "partition", "group", "branch", "processings", "augmentation"]: kwargs[key] = value return self._append(n_rows, **kwargs)
[docs] def add_rows_dict( self, n_rows: int, indices: IndexDict, **kwargs ) -> List[int]: """ Add rows using dictionary-based parameter specification. This method provides a cleaner API for specifying row parameters using a dictionary, similar to the filtering API pattern. Args: n_rows: Number of rows to add indices: Dictionary containing column specifications { "partition": "train|test|val", "sample": [list of sample IDs] or single ID, "origin": [list of origin IDs] or single ID, "group": [list of groups] or single group, "branch": [list of branches] or single branch, "processings": processing configuration, "augmentation": augmentation type, ... (any other column) } **kwargs: Additional column overrides (take precedence over indices) Returns: List of sample indices that were added Example: # Add rows with dictionary specification indexer.add_rows_dict(2, { "partition": "val", "sample": [100, 101], "group": 5 }) """ if n_rows <= 0: return [] params = self._convert_indexdict_to_params(indices, n_rows) params.update(kwargs) # kwargs take precedence return self._append(n_rows, **params)
[docs] def register_samples(self, count: int, partition: PartitionType = "train") -> List[int]: """Register samples using the unified _append method.""" return self._append(count, partition=partition)
[docs] def register_samples_dict( self, count: int, indices: IndexDict, **kwargs ) -> List[int]: """ Register samples using dictionary-based parameter specification. Args: count: Number of samples to register indices: Dictionary containing column specifications **kwargs: Additional column overrides (take precedence over indices) Returns: List of sample indices that were registered Example: indexer.register_samples_dict(5, {"partition": "test", "group": 2}) """ params = self._convert_indexdict_to_params(indices, count) params.update(kwargs) # kwargs take precedence return self._append(count, **params)
[docs] def update_by_filter(self, selector: Selector, updates: Dict[str, Any]) -> None: """ Update rows matching a selector filter. Args: selector: Filter criteria dictionary (same format as x_indices). updates: Dictionary of column:value pairs to update. Example: >>> indexer.update_by_filter({"partition": "train", "group": 1}, {"branch": 2}) """ condition = self._build_filter_condition(selector) self._store.update_by_condition(condition, updates)
[docs] def update_by_indices(self, sample_indices: SampleIndices, updates: Dict[str, Any]) -> None: """ Update rows by sample indices. Args: sample_indices: Sample IDs to update (int, list, or array). updates: Dictionary of column:value pairs to update. Example: >>> indexer.update_by_indices([0, 1, 2], {"group": 5}) """ count = len(sample_indices) if isinstance(sample_indices, (list, np.ndarray)) else 1 sample_ids = self._normalize_indices(sample_indices, count, "sample_indices") condition = self._query_builder.build_sample_filter(sample_ids) self._store.update_by_condition(condition, updates)
[docs] def next_row_index(self) -> int: """ Get the next available row index. Returns: int: Next row index (max + 1, or 0 if empty). Example: >>> next_idx = indexer.next_row_index() """ return self._sample_manager.next_row_id()
[docs] def next_sample_index(self) -> int: """ Get the next available sample index. Returns: int: Next sample index (max + 1, or 0 if empty). Example: >>> next_idx = indexer.next_sample_index() """ return self._sample_manager.next_sample_id()
[docs] def get_column_values(self, col: str, filters: Optional[Dict[str, Any]] = None) -> List[Any]: """ Get column values, optionally filtered. Args: col: Column name to retrieve. filters: Optional selector dictionary for filtering. Returns: List[Any]: Column values. Example: >>> partitions = indexer.get_column_values("partition") >>> train_groups = indexer.get_column_values("group", {"partition": "train"}) """ condition = self._build_filter_condition(filters) if filters else None return self._store.get_column(col, condition)
[docs] def uniques(self, col: str) -> List[Any]: """ Get unique values in a column. Args: col: Column name. Returns: List[Any]: Unique values in the column. Example: >>> unique_partitions = indexer.uniques("partition") """ return self._store.get_unique(col)
[docs] def augment_rows(self, samples: List[int], count: Union[int, List[int]], augmentation_id: str) -> List[int]: """ Create augmented samples based on existing samples. This method creates new augmented samples that reference existing base samples as their origins. The augmented samples inherit all attributes (partition, group, branch, processings) from their origin samples. Args: samples: List of sample IDs to augment. Must exist in the index. count: Number of augmentations per sample. Can be: - int: Same count for all samples - List[int]: One count per sample (length must match samples) augmentation_id: String identifier for the augmentation type (e.g., "flip", "rotate", "noise"). Returns: List[int]: List of new sample IDs for the augmented samples. Raises: ValueError: If samples list is empty, if count list length doesn't match samples length, or if any sample IDs are not found. Examples: >>> indexer = Indexer() >>> base_ids = indexer.add_samples(3, partition="train", processings=["raw", "msc"]) >>> >>> # Create 2 augmentations for each base sample >>> aug_ids = indexer.augment_rows(base_ids, 2, "flip") >>> # aug_ids: [3, 4, 5, 6, 7, 8] (2 per sample) >>> >>> # Different counts per sample >>> aug_ids2 = indexer.augment_rows([0, 1], [1, 3], "rotate") >>> # aug_ids2: [9, 10, 11, 12] (1 for sample 0, 3 for sample 1) >>> >>> # Verify augmented samples reference their origins >>> origin = indexer.get_origin_for_sample(aug_ids[0]) >>> print(origin) # base_ids[0] Note: - Augmented samples inherit partition, group, branch, and processings from origins - origin field is set to the base sample ID - augmentation field is set to augmentation_id - Useful for data augmentation in ML pipelines (flips, rotations, noise, etc.) """ if not samples: return [] # Normalize count to list if isinstance(count, int): count_list = [count] * len(samples) else: count_list = list(count) if len(count_list) != len(samples): raise ValueError("count must be an int or a list with the same length as samples") total_augmentations = sum(count_list) if total_augmentations == 0: return [] # Get sample data for the samples to augment sample_filter = self._query_builder.build_sample_filter(samples) filtered_df = self._store.query(sample_filter).sort("sample") if len(filtered_df) != len(samples): missing = set(samples) - set(filtered_df["sample"].to_list()) raise ValueError(f"Samples not found in indexer: {missing}") # Prepare data for augmented samples origin_indices = [] partitions = [] groups = [] branches = [] processings_list = [] for i, (sample_id, sample_count) in enumerate(zip(samples, count_list)): if sample_count <= 0: continue # Get the original sample data sample_row = filtered_df.filter(pl.col("sample") == sample_id).row(0, named=True) # Repeat data for each augmentation of this sample origin_indices.extend([sample_id] * sample_count) partitions.extend([sample_row["partition"]] * sample_count) groups.extend([sample_row["group"]] * sample_count) branches.extend([sample_row["branch"]] * sample_count) # Processings are now native lists, not strings processings_list.extend([sample_row["processings"]] * sample_count) # Create augmented samples using _append partition = partitions[0] if partitions else "train" augmented_ids = self._append( total_augmentations, partition=partition, origin_indices=origin_indices, group=groups, branch=branches, processings=processings_list, augmentation=augmentation_id ) return augmented_ids
[docs] def __repr__(self) -> str: """ String representation showing the DataFrame. Returns: str: String representation of the index DataFrame. """ return str(self._store.df)
[docs] def __str__(self) -> str: """ Human-readable summary of indexed samples. Returns: str: Summary showing sample counts by combination of attributes. """ df = self._store.df cols_to_include = [col for col in df.columns if col not in ["sample", "origin", "row"]] if not cols_to_include: return "No indexable columns found" if len(df) == 0: return "No rows found" # Group by all columns and count combinations = df.select(cols_to_include).group_by(cols_to_include).agg( pl.len().alias("count") ).sort("count", descending=True) # Format output summary = [] for row in combinations.to_dicts(): parts = [] for col in cols_to_include: value = row[col] if value is None: continue # Handle native list type for processings if col == "processings" and isinstance(value, list): parts.append(f"{col} - {value}") elif isinstance(value, str): parts.append(f'{col} - "{value}"') else: parts.append(f"{col} - {value}") if parts: combination_str = ", ".join(parts) count = row["count"] summary.append(f"{combination_str}: {count} samples") parts_str = "\n- ".join(summary) return f"Indexes:\n- {parts_str}"
# ==================== Sample Filtering Methods ====================
[docs] def mark_excluded( self, sample_indices: SampleIndices, reason: Optional[str] = None, cascade_to_augmented: bool = True ) -> int: """ Mark samples as excluded from training. Excluded samples are automatically filtered out from x_indices() and y_indices() calls unless include_excluded=True is explicitly passed. This provides a non-destructive way to remove outliers or corrupted samples from training. Args: sample_indices: Sample IDs to exclude. Can be: - int: Single sample ID - List[int]: List of sample IDs - np.ndarray: Array of sample IDs reason: Optional string describing why samples are excluded (e.g., "outlier", "corrupted", "low_quality"). cascade_to_augmented: If True (default), also exclude augmented samples derived from the specified base samples. This prevents data leakage from augmented versions of excluded samples. Returns: int: Number of samples marked as excluded. Raises: ValueError: If sample_indices is empty. Examples: >>> indexer = Indexer() >>> indexer.add_samples(5, partition="train") >>> indexer.augment_rows([0, 1], 2, "flip") >>> >>> # Mark sample 0 as excluded (outlier detection) >>> n_excluded = indexer.mark_excluded([0], reason="iqr_outlier") >>> # n_excluded: 3 (sample 0 + 2 augmented versions) >>> >>> # Verify exclusion >>> train_samples = indexer.x_indices({"partition": "train"}) >>> # Sample 0 and its augmentations no longer included >>> >>> # View excluded samples >>> excluded_df = indexer.get_excluded_samples() Note: - Exclusion is non-destructive: data remains in the indexer - Use mark_included() to reverse exclusion - Excluded samples can still be accessed via include_excluded=True - Cascade prevents data leakage from augmented versions """ count = len(sample_indices) if isinstance(sample_indices, (list, np.ndarray)) else 1 sample_ids = self._normalize_indices(sample_indices, count, "sample_indices") if not sample_ids: return 0 all_samples_to_exclude = set(sample_ids) # Cascade to augmented samples if requested if cascade_to_augmented: augmented = self._augmentation_tracker.get_augmented_for_origins(sample_ids) all_samples_to_exclude.update(augmented.tolist()) # Build condition and update condition = self._query_builder.build_sample_filter(list(all_samples_to_exclude)) updates = {"excluded": True} if reason is not None: updates["exclusion_reason"] = reason self._store.update_by_condition(condition, updates) return len(all_samples_to_exclude)
[docs] def mark_included( self, sample_indices: Optional[SampleIndices] = None, cascade_to_augmented: bool = True ) -> int: """ Remove exclusion flag from samples. This method reverses the effect of mark_excluded(), re-including samples in x_indices() and y_indices() results. Args: sample_indices: Sample IDs to include. Can be: - int: Single sample ID - List[int]: List of sample IDs - np.ndarray: Array of sample IDs - None: Include ALL currently excluded samples cascade_to_augmented: If True (default), also include augmented samples derived from the specified base samples. Returns: int: Number of samples marked as included. Examples: >>> indexer = Indexer() >>> indexer.add_samples(5, partition="train") >>> indexer.mark_excluded([0, 1], reason="outlier") >>> >>> # Re-include sample 0 >>> n_included = indexer.mark_included([0]) >>> # n_included: 1 >>> >>> # Re-include all excluded samples >>> n_included = indexer.mark_included() # No argument = all excluded Note: - Clears both the excluded flag and exclusion_reason - Useful for iterative filtering or correcting previous exclusions """ if sample_indices is None: # Include all currently excluded samples condition = pl.col("excluded") == True # noqa: E712 count = len(self._store.query(condition)) self._store.update_by_condition(condition, {"excluded": False, "exclusion_reason": None}) return count count = len(sample_indices) if isinstance(sample_indices, (list, np.ndarray)) else 1 sample_ids = self._normalize_indices(sample_indices, count, "sample_indices") if not sample_ids: return 0 all_samples_to_include = set(sample_ids) # Cascade to augmented samples if requested if cascade_to_augmented: augmented = self._augmentation_tracker.get_augmented_for_origins(sample_ids) all_samples_to_include.update(augmented.tolist()) # Build condition and update condition = self._query_builder.build_sample_filter(list(all_samples_to_include)) self._store.update_by_condition(condition, {"excluded": False, "exclusion_reason": None}) return len(all_samples_to_include)
[docs] def get_excluded_samples(self, selector: Optional[Selector] = None) -> pl.DataFrame: """ Get DataFrame of excluded samples with their exclusion reasons. Args: selector: Optional filter criteria to narrow down the query. If None, returns all excluded samples. Returns: pl.DataFrame: DataFrame containing excluded samples with columns: sample, origin, partition, group, branch, exclusion_reason. Examples: >>> indexer = Indexer() >>> indexer.add_samples(5, partition="train") >>> indexer.mark_excluded([0, 1], reason="outlier") >>> >>> # Get all excluded samples >>> excluded_df = indexer.get_excluded_samples() >>> print(excluded_df) >>> >>> # Get excluded samples from train partition only >>> train_excluded = indexer.get_excluded_samples({"partition": "train"}) Note: Returns a Polars DataFrame for efficient processing. Use .to_pandas() if pandas DataFrame is needed. """ # Start with excluded=True filter condition = pl.col("excluded") == True # noqa: E712 # Add selector conditions if provided if selector: selector_condition = self._build_filter_condition(selector) condition = condition & selector_condition return self._store.query(condition).select([ "sample", "origin", "partition", "group", "branch", "augmentation", "exclusion_reason" ])
[docs] def get_exclusion_summary(self) -> Dict[str, Any]: """ Get summary statistics of exclusions by reason. Returns: Dict[str, Any]: Dictionary containing: - total_excluded: Total number of excluded samples - total_samples: Total number of samples in indexer - exclusion_rate: Ratio of excluded to total samples - by_reason: Dict mapping reason strings to counts - by_partition: Dict mapping partition names to excluded counts Examples: >>> indexer = Indexer() >>> indexer.add_samples(10, partition="train") >>> indexer.mark_excluded([0, 1], reason="outlier") >>> indexer.mark_excluded([2], reason="low_quality") >>> >>> summary = indexer.get_exclusion_summary() >>> print(summary) >>> # { >>> # 'total_excluded': 3, >>> # 'total_samples': 10, >>> # 'exclusion_rate': 0.3, >>> # 'by_reason': {'outlier': 2, 'low_quality': 1}, >>> # 'by_partition': {'train': 3} >>> # } """ df = self._store.df total_samples = len(df) # Filter to excluded samples excluded_df = df.filter(pl.col("excluded") == True) # noqa: E712 total_excluded = len(excluded_df) # Group by reason by_reason = {} if total_excluded > 0: reason_counts = excluded_df.group_by("exclusion_reason").agg( pl.len().alias("count") ).to_dicts() for row in reason_counts: reason = row["exclusion_reason"] if row["exclusion_reason"] else "unspecified" by_reason[reason] = row["count"] # Group by partition by_partition = {} if total_excluded > 0: partition_counts = excluded_df.group_by("partition").agg( pl.len().alias("count") ).to_dicts() for row in partition_counts: by_partition[row["partition"]] = row["count"] return { "total_excluded": total_excluded, "total_samples": total_samples, "exclusion_rate": total_excluded / total_samples if total_samples > 0 else 0.0, "by_reason": by_reason, "by_partition": by_partition, }
[docs] def reset_exclusions(self, selector: Optional[Selector] = None) -> int: """ Remove all exclusion flags matching the selector. This is a convenience method equivalent to calling mark_included() on all excluded samples matching the selector. Args: selector: Optional filter criteria. If None, resets ALL exclusions. Returns: int: Number of samples reset. Examples: >>> # Reset all exclusions >>> n_reset = indexer.reset_exclusions() >>> >>> # Reset only train partition exclusions >>> n_reset = indexer.reset_exclusions({"partition": "train"}) """ # Build condition for excluded samples condition = pl.col("excluded") == True # noqa: E712 # Add selector conditions if provided if selector: selector_condition = self._build_filter_condition(selector) condition = condition & selector_condition # Count before update count = len(self._store.query(condition)) # Update to include self._store.update_by_condition(condition, {"excluded": False, "exclusion_reason": None}) return count