# Source code for nirs4all.controllers.data.sample_partitioner

"""
Sample Partitioner Controller for sample-based branching.

This controller partitions the dataset into multiple branches based on a
sample filter (e.g., outlier detection). Unlike OutlierExcluderController
which excludes samples from training, this controller creates separate
branches where each branch contains a different subset of samples.

For example, with Y-outlier detection:
    - Branch "outliers": Contains ONLY the outlier samples
    - Branch "inliers": Contains ONLY the non-outlier samples

This enables training separate models for different data subsets and
comparing their performance.

Example:
    >>> pipeline = [
    ...     ShuffleSplit(n_splits=5),
    ...     {"branch": {
    ...         "by": "sample_partitioner",
    ...         "filter": {"method": "y_outlier", "threshold": 3.0},
    ...     }},
    ...     PLSRegression(n_components=10),
    ... ]
"""

import copy
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Optional, TYPE_CHECKING

import numpy as np

from nirs4all.controllers.controller import OperatorController
from nirs4all.controllers.registry import register_controller
from nirs4all.core.logging import get_logger
from nirs4all.operators.filters.base import SampleFilter
from nirs4all.operators.filters.x_outlier import XOutlierFilter
from nirs4all.operators.filters.y_outlier import YOutlierFilter
from nirs4all.pipeline.execution.result import StepOutput

logger = get_logger(__name__)

if TYPE_CHECKING:
    from nirs4all.data.dataset import SpectroDataset
    from nirs4all.pipeline.config.context import ExecutionContext, RuntimeContext
    from nirs4all.pipeline.steps.parser import ParsedStep


def _create_partition_filter(filter_config: Dict[str, Any]) -> SampleFilter:
    """
    Build the sample filter used to split samples into two partitions.

    Args:
        filter_config: Dict holding a 'method' key (defaults to "y_outlier"
                      when absent) plus optional method-specific parameters.
                      Recognized methods:
                      - "y_outlier": YOutlierFilter with its "iqr" method;
                        reads 'threshold' (default 1.5)
                      - "x_outlier", "isolation_forest": XOutlierFilter with
                        method "isolation_forest"
                      - "mahalanobis", "robust_mahalanobis", "lof",
                        "pca_residual": XOutlierFilter with the same method
                      - "leverage": XOutlierFilter with method "pca_leverage"
                      For the X-based methods, 'contamination', 'threshold',
                      'n_components' and 'random_state' are forwarded only
                      when they are present in the config.

    Returns:
        SampleFilter: Configured filter instance.

    Raises:
        ValueError: If method is not recognized.
    """
    requested = filter_config.get("method", "y_outlier")

    # User-facing method name -> method name handed to XOutlierFilter.
    x_method_aliases = {
        "x_outlier": "isolation_forest",
        "isolation_forest": "isolation_forest",
        "mahalanobis": "mahalanobis",
        "robust_mahalanobis": "robust_mahalanobis",
        "lof": "lof",
        "leverage": "pca_leverage",
        "pca_residual": "pca_residual",
    }

    # Y-based detection is the only path that does not go through XOutlierFilter.
    if requested == "y_outlier":
        return YOutlierFilter(
            method="iqr",
            threshold=filter_config.get("threshold", 1.5),
        )

    if requested not in x_method_aliases:
        supported = ["y_outlier", *x_method_aliases]
        raise ValueError(
            f"Unknown partition method '{requested}'. "
            f"Valid methods: {supported}"
        )

    x_filter_args = {
        "method": x_method_aliases[requested],
        "reason": f"partition_{requested}",
    }
    # Forward optional tuning parameters only when the caller supplied them,
    # so XOutlierFilter keeps its own defaults otherwise.
    for option in ("contamination", "threshold", "n_components", "random_state"):
        if option in filter_config:
            x_filter_args[option] = filter_config[option]

    return XOutlierFilter(**x_filter_args)


def _generate_branch_names(filter_config: Dict[str, Any]) -> Tuple[str, str]:
    """
    Generate branch names for the two partitions.

    Args:
        filter_config: Filter configuration dict.

    Returns:
        Tuple of (outliers_name, inliers_name).
    """
    method = filter_config.get("method", "y_outlier")

    # Allow custom names
    if "branch_names" in filter_config:
        names = filter_config["branch_names"]
        if len(names) >= 2:
            return (names[0], names[1])

    # Default names based on method
    if method == "y_outlier":
        return ("y_outliers", "y_inliers")
    elif method in ("x_outlier", "isolation_forest", "mahalanobis", "lof", "leverage", "pca_residual"):
        return ("x_outliers", "x_inliers")
    else:
        return ("outliers", "inliers")


@register_controller
class SamplePartitionerController(OperatorController):
    """
    Controller for sample-based branching via partitioning.

    This controller creates two branches by partitioning samples based on a
    filter (e.g., outlier detection). Each branch contains a different subset
    of samples:
        - "outliers" branch: samples where filter returns False (outliers)
        - "inliers" branch: samples where filter returns True (non-outliers)

    Unlike OutlierExcluderController which only excludes from training,
    this controller truly partitions the samples so each branch trains
    and predicts only on its subset.

    Key behaviors:
        - Each branch contains a disjoint subset of samples
        - Samples are partitioned, not excluded
        - Models train and predict only on their partition
        - Supports Y-outlier and X-outlier detection methods

    Attributes:
        priority: Controller priority (set to 3 to run before outlier excluder).
    """

    priority = 3  # Higher priority than OutlierExcluderController (4)

    @classmethod
    def matches(cls, step: Any, operator: Any, keyword: str) -> bool:
        """
        Check if the step matches the sample_partitioner branch pattern.

        Matches: {"branch": {"by": "sample_partitioner", "filter": {...}}}

        Args:
            step: Original step configuration
            operator: Deserialized operator (not inspected here)
            keyword: Step keyword

        Returns:
            True if this is a sample_partitioner branch definition.
        """
        if keyword != "branch" or not isinstance(step, dict):
            return False

        branch_def = step.get("branch")
        return (
            isinstance(branch_def, dict)
            and branch_def.get("by") == "sample_partitioner"
        )

    @classmethod
    def use_multi_source(cls) -> bool:
        """Sample partitioner operates on dataset level."""
        return True

    @classmethod
    def supports_prediction_mode(cls) -> bool:
        """
        Sample partitioner should execute in prediction mode.

        In prediction mode, we need to reconstruct the branch contexts
        and apply the same sample partitioning.
        """
        return True

    def execute(
        self,
        step_info: "ParsedStep",
        dataset: "SpectroDataset",
        context: "ExecutionContext",
        runtime_context: "RuntimeContext",
        source: int = -1,
        mode: str = "train",
        loaded_binaries: Optional[List[Tuple[str, Any]]] = None,
        prediction_store: Optional[Any] = None
    ) -> Tuple["ExecutionContext", StepOutput]:
        """
        Execute the sample partitioner branch step.

        Creates two branches: one for outliers and one for inliers.
        Each branch contains only its subset of samples. The filter is always
        fitted on the current training partition of ``dataset``.

        Args:
            step_info: Parsed step containing branch definitions
            dataset: Dataset to operate on
            context: Pipeline execution context
            runtime_context: Runtime infrastructure context
            source: Data source index (not read by this method)
            mode: Execution mode ("train", "predict" or "explain")
            loaded_binaries: Pre-loaded binary objects for prediction mode
                (not read by this method)
            prediction_store: External prediction store for model predictions
                (not read by this method)

        Returns:
            Tuple of (updated_context, StepOutput with collected artifacts).
            When no training samples are found, the incoming context is
            returned unchanged together with an empty StepOutput.
        """
        # Parse filter config from step
        branch_def = step_info.original_step.get("branch", {})
        filter_config = branch_def.get("filter", {"method": "y_outlier"})

        logger.info("Creating sample partitioner branches")

        # Store initial context as snapshot
        initial_context = context.copy()
        initial_processing = copy.deepcopy(context.selector.processing)

        # Snapshot dataset features
        initial_features = self._snapshot_features(dataset)

        # Get training sample indices (base samples only, no augmented ones)
        train_context = context.with_partition("train")
        train_selector = train_context.selector.copy()
        train_selector.include_augmented = False
        train_sample_indices = dataset._indexer.x_indices(
            train_selector, include_augmented=False, include_excluded=False
        )

        # len() rather than truthiness: the indices may be a NumPy array.
        if len(train_sample_indices) == 0:
            logger.warning("No training samples found, skipping partitioner")
            return context, StepOutput()

        # Get X and Y data for filter
        train_X = dataset.x(train_selector, layout="2d", concat_source=True)
        train_y = dataset.y(train_selector)

        # Create and fit filter
        filter_obj = _create_partition_filter(filter_config)
        filter_obj.fit(train_X, train_y)

        # Get mask: True = inlier (keep), False = outlier (remove).
        # Coerce both operands so boolean indexing is well defined: on an
        # integer 0/1 mask, `~mask` would be a bitwise NOT and select the
        # wrong rows instead of inverting the selection.
        train_sample_indices = np.asarray(train_sample_indices)
        mask = np.asarray(filter_obj.get_mask(train_X, train_y), dtype=bool)

        # Compute partition indices
        outlier_indices = train_sample_indices[~mask]
        inlier_indices = train_sample_indices[mask]

        n_outliers = len(outlier_indices)
        n_inliers = len(inlier_indices)
        n_total = len(train_sample_indices)

        logger.info(f" Partition: {n_outliers} outliers, {n_inliers} inliers "
                    f"({100 * n_outliers / n_total:.1f}% / {100 * n_inliers / n_total:.1f}%)")

        # Generate branch names
        outliers_name, inliers_name = _generate_branch_names(filter_config)

        # In predict/explain mode, filter to target branch if specified
        target_branch_id = None
        if mode in ("predict", "explain"):
            target_model = getattr(runtime_context, "target_model", None)
            if target_model:
                target_branch_id = target_model.get("branch_id")

        # Create branch contexts
        branch_contexts: List[Dict[str, Any]] = []
        all_artifacts = []

        # (branch_id, branch name, partition type, sample indices)
        partitions = (
            (0, outliers_name, "outliers", outlier_indices),
            (1, inliers_name, "inliers", inlier_indices),
        )
        for branch_id, branch_name, partition_type, indices in partitions:
            if target_branch_id is not None and target_branch_id != branch_id:
                continue

            if branch_id == 1:
                # Reset the dataset features to the snapshot before building
                # the second (inliers) branch.
                self._restore_features(dataset, initial_features)

            if len(indices) == 0:
                logger.warning(
                    f"Partition '{branch_name}' is empty; "
                    f"steps running on this branch will receive no samples"
                )

            branch_contexts.append(
                self._build_partition_branch(
                    base_context=initial_context,
                    processing=initial_processing,
                    branch_id=branch_id,
                    branch_name=branch_name,
                    partition_type=partition_type,
                    sample_indices=indices,
                    filter_config=filter_config,
                )
            )
            logger.info(f" Branch {branch_id}: {branch_name} ({len(indices)} samples)")

        # Persist filter for prediction mode
        if mode == "train" and runtime_context.saver is not None:
            artifact = runtime_context.saver.persist_artifact(
                step_number=runtime_context.step_number,
                name=f"partition_filter_{runtime_context.next_op()}",
                obj=filter_obj,
                format_hint='sklearn',
            )
            all_artifacts.append(artifact)

        # Handle nested branching (multiply with existing branches)
        existing_branches = context.custom.get("branch_contexts", [])
        if existing_branches:
            new_branch_contexts = self._multiply_branch_contexts(
                existing_branches, branch_contexts
            )
        else:
            new_branch_contexts = branch_contexts

        # Update result context
        result_context = context.copy()
        result_context.custom["branch_contexts"] = new_branch_contexts
        result_context.custom["in_branch_mode"] = True
        result_context.custom["sample_partitioner_active"] = True

        logger.success(f"Sample partitioner completed with {len(new_branch_contexts)} branch(es)")

        return result_context, StepOutput(
            artifacts=all_artifacts,
            metadata={
                "branch_count": len(new_branch_contexts),
                "sample_partitioner": True,
                "n_outliers": n_outliers,
                "n_inliers": n_inliers,
            }
        )

    def _build_partition_branch(
        self,
        base_context: "ExecutionContext",
        processing: Any,
        branch_id: int,
        branch_name: str,
        partition_type: str,
        sample_indices: Any,
        filter_config: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Build the branch-context dict for one partition.

        Args:
            base_context: Context copied to become the branch context
            processing: Processing state deep-copied onto the branch selector
            branch_id: Identifier of the branch (0 = outliers, 1 = inliers)
            branch_name: Name given to the branch
            partition_type: "outliers" or "inliers"
            sample_indices: Sample indices that belong to this partition
            filter_config: Filter configuration recorded with the partition

        Returns:
            Dict with "branch_id", "name", "context" and "partition_info".
        """
        indices = np.asarray(sample_indices)
        n_samples = len(indices)

        branch_context = base_context.copy()
        branch_context.selector = branch_context.selector.with_branch(
            branch_id=branch_id,
            branch_name=branch_name
        )
        branch_context.selector.processing = copy.deepcopy(processing)

        # Store sample partition info. Two separate lists are created so the
        # context entry and the branch entry never share a mutable object.
        branch_context.custom["sample_partition"] = {
            "sample_indices": indices.tolist(),
            "partition_type": partition_type,
            "n_samples": n_samples,
            "filter_config": filter_config,
        }

        return {
            "branch_id": branch_id,
            "name": branch_name,
            "context": branch_context,
            "partition_info": {
                "type": partition_type,
                "n_samples": n_samples,
                "sample_indices": indices.tolist(),
            }
        }

    def _snapshot_features(self, dataset: "SpectroDataset") -> List[Any]:
        """Create a deep copy of dataset features for branch isolation."""
        return copy.deepcopy(dataset._features.sources)

    def _restore_features(
        self,
        dataset: "SpectroDataset",
        snapshot: List[Any]
    ) -> None:
        """Restore dataset features from snapshot."""
        # Deep-copy again so the snapshot itself stays reusable.
        dataset._features.sources = copy.deepcopy(snapshot)

    def _multiply_branch_contexts(
        self,
        existing: List[Dict[str, Any]],
        new: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Multiply existing branch contexts with new ones for nested branching.

        Creates Cartesian product: each existing branch × each new branch.
        The combined branches are renumbered with consecutive ids, and each
        combined context is a copy of the child's context.

        Args:
            existing: List of existing branch context dicts
            new: List of new branch context dicts

        Returns:
            Combined list of branch contexts
        """
        result = []
        flattened_id = 0

        for parent in existing:
            parent_id = parent["branch_id"]
            parent_name = parent["name"]

            for child in new:
                child_id = child["branch_id"]
                combined_name = f"{parent_name}_{child['name']}"

                # Create combined context
                combined_context = child["context"].copy()
                combined_context.selector.branch_id = flattened_id
                combined_context.selector.branch_name = combined_name

                result.append({
                    "branch_id": flattened_id,
                    "name": combined_name,
                    "context": combined_context,
                    "parent_branch_id": parent_id,
                    "child_branch_id": child_id,
                    "partition_info": child.get("partition_info", {}),
                })
                flattened_id += 1

        return result