Source code for nirs4all.controllers.data.outlier_excluder

"""
Outlier Excluder Controller for sample-based branching.

This controller enables creating multiple branches with different outlier
exclusion strategies. Each branch trains on a different subset of samples
based on the outlier detection method applied.

This is useful for comparing how different outlier handling approaches
affect model performance without creating separate pipeline runs.

Example:
    >>> pipeline = [
    ...     ShuffleSplit(n_splits=5),
    ...     {"branch": {
    ...         "by": "outlier_excluder",
    ...         "strategies": [
    ...             None,  # No exclusion (baseline)
    ...             {"method": "isolation_forest", "contamination": 0.05},
    ...             {"method": "mahalanobis", "threshold": 3.0},
    ...             {"method": "leverage", "threshold": 2.0},
    ...             {"method": "lof", "contamination": 0.05},
    ...         ],
    ...     }},
    ...     PLSRegression(n_components=10),
    ... ]
"""

import copy
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Optional, TYPE_CHECKING

import numpy as np

from nirs4all.controllers.controller import OperatorController
from nirs4all.controllers.registry import register_controller
from nirs4all.core.logging import get_logger
from nirs4all.operators.filters.base import SampleFilter
from nirs4all.operators.filters.x_outlier import XOutlierFilter
from nirs4all.pipeline.execution.result import StepOutput

logger = get_logger(__name__)

if TYPE_CHECKING:
    from nirs4all.data.dataset import SpectroDataset
    from nirs4all.pipeline.config.context import ExecutionContext, RuntimeContext
    from nirs4all.pipeline.steps.parser import ParsedStep


[docs] @dataclass class OutlierExclusionResult: """Result of applying an outlier exclusion strategy.""" strategy_name: str n_total: int n_excluded: int excluded_indices: List[int] exclusion_mask: np.ndarray # True = keep, False = exclude
def _create_outlier_filter(strategy: Dict[str, Any]) -> SampleFilter: """ Create an outlier filter from a strategy specification. Supports methods from XOutlierFilter plus additional high-level shortcuts. Args: strategy: Strategy dict with 'method' and method-specific parameters. Supported methods: - "isolation_forest": Isolation Forest anomaly detection - "mahalanobis": Mahalanobis distance from centroid - "robust_mahalanobis": Robust Mahalanobis (MinCovDet) - "leverage": High leverage points (alias for pca_leverage) - "lof": Local Outlier Factor - "pca_residual": Q-statistic from PCA reconstruction Returns: SampleFilter: Configured outlier filter instance. Raises: ValueError: If method is not recognized. """ method = strategy.get("method", "isolation_forest") # Map high-level method names to XOutlierFilter methods method_mapping = { "isolation_forest": "isolation_forest", "mahalanobis": "mahalanobis", "robust_mahalanobis": "robust_mahalanobis", "leverage": "pca_leverage", # Alias for high leverage detection "lof": "lof", "pca_residual": "pca_residual", "pca_leverage": "pca_leverage", } if method not in method_mapping: valid_methods = list(method_mapping.keys()) raise ValueError( f"Unknown outlier method '{method}'. " f"Valid methods: {valid_methods}" ) filter_method = method_mapping[method] # Build filter kwargs from strategy filter_kwargs = { "method": filter_method, "reason": f"outlier_{method}", } # Map strategy parameters to filter parameters if "contamination" in strategy: filter_kwargs["contamination"] = strategy["contamination"] if "threshold" in strategy: filter_kwargs["threshold"] = strategy["threshold"] if "n_components" in strategy: filter_kwargs["n_components"] = strategy["n_components"] if "random_state" in strategy: filter_kwargs["random_state"] = strategy["random_state"] return XOutlierFilter(**filter_kwargs) def _strategy_to_name(strategy: Optional[Dict[str, Any]], idx: int) -> str: """ Generate a human-readable name for an outlier strategy. Args: strategy: Strategy dict or None for baseline. idx: Strategy index for fallback naming. Returns: str: Branch name like "baseline", "if_0.05", "mahal_3.0", etc. """ if strategy is None: return "baseline" method = strategy.get("method", "unknown") # Create abbreviated name method_abbrev = { "isolation_forest": "if", "mahalanobis": "mahal", "robust_mahalanobis": "rmahal", "leverage": "lev", "pca_leverage": "lev", "lof": "lof", "pca_residual": "pca_q", } abbrev = method_abbrev.get(method, method[:4]) # Add threshold/contamination to name if "contamination" in strategy: return f"{abbrev}_{strategy['contamination']}" elif "threshold" in strategy: return f"{abbrev}_{strategy['threshold']}" else: return f"{abbrev}_{idx}"
[docs] @register_controller class OutlierExcluderController(OperatorController): """ Controller for sample-based branching with outlier exclusion strategies. This controller creates multiple branches, each with a different outlier exclusion strategy. Samples identified as outliers are excluded from training in that branch, but predictions still cover all samples. Key behaviors: - Each branch applies a different outlier detection method - Outlier detection runs on training data only - Exclusion is per-branch (tracked in context, not in indexer) - Predictions include exclusion metadata for analysis - Branch 0 with None strategy serves as baseline Attributes: priority: Controller priority (set to 4 to run before regular branch controller). """ priority = 4 # Higher priority than BranchController (5)
[docs] @classmethod def matches(cls, step: Any, operator: Any, keyword: str) -> bool: """ Check if the step matches the outlier excluder branch pattern. Matches: {"branch": {"by": "outlier_excluder", "strategies": [...]}} Args: step: Original step configuration operator: Deserialized operator keyword: Step keyword Returns: True if this is an outlier_excluder branch definition. """ if keyword != "branch": return False # Check if it's the outlier_excluder pattern if isinstance(step, dict) and "branch" in step: branch_def = step["branch"] if isinstance(branch_def, dict) and branch_def.get("by") == "outlier_excluder": return True return False
[docs] @classmethod def use_multi_source(cls) -> bool: """Outlier excluder operates on dataset level.""" return True
[docs] @classmethod def supports_prediction_mode(cls) -> bool: """ Outlier excluder should execute in prediction mode. In prediction mode, we need to reconstruct the branch contexts but NOT apply sample exclusion (we predict on all samples). """ return True
[docs] def execute( self, step_info: "ParsedStep", dataset: "SpectroDataset", context: "ExecutionContext", runtime_context: "RuntimeContext", source: int = -1, mode: str = "train", loaded_binaries: Optional[List[Tuple[str, Any]]] = None, prediction_store: Optional[Any] = None ) -> Tuple["ExecutionContext", StepOutput]: """ Execute the outlier excluder branch step. Creates branches for each outlier exclusion strategy. In train mode, applies outlier detection and marks exclusions. In predict mode, reconstructs branch contexts without applying exclusions. Args: step_info: Parsed step containing branch definitions dataset: Dataset to operate on context: Pipeline execution context runtime_context: Runtime infrastructure context source: Data source index mode: Execution mode ("train" or "predict") loaded_binaries: Pre-loaded binary objects for prediction mode prediction_store: External prediction store for model predictions Returns: Tuple of (updated_context, StepOutput with collected artifacts) """ # Parse strategies from step branch_def = step_info.original_step.get("branch", {}) strategies = branch_def.get("strategies", []) if not strategies: logger.warning("No outlier strategies defined, skipping outlier excluder") return context, StepOutput() n_strategies = len(strategies) # In predict/explain mode, filter to target branch if specified target_branch_id = None if mode in ("predict", "explain") and hasattr(runtime_context, 'target_model') and runtime_context.target_model: target_branch_id = runtime_context.target_model.get("branch_id") if target_branch_id is not None: if target_branch_id >= n_strategies: raise ValueError( f"Target branch_id={target_branch_id} not found. " f"Pipeline has {n_strategies} strategies (0-{n_strategies-1})." ) strategies = [strategies[target_branch_id]] logger.info(f"Predict mode: using outlier strategy {target_branch_id}") logger.info(f"Creating {len(strategies)} outlier excluder branches") # Store initial context as snapshot initial_context = context.copy() initial_processing = copy.deepcopy(context.selector.processing) # Snapshot dataset features initial_features = self._snapshot_features(dataset) # Get training data for outlier detection train_X = None train_y = None train_sample_indices = None if mode == "train": train_context = context.with_partition("train") train_selector = train_context.selector.copy() train_selector.include_augmented = False train_sample_indices = dataset._indexer.x_indices( train_selector, include_augmented=False, include_excluded=False ) if len(train_sample_indices) > 0: train_X = dataset.x(train_selector, layout="2d", concat_source=True) train_y = dataset.y(train_selector) # Process each strategy branch_contexts: List[Dict[str, Any]] = [] all_artifacts = [] initial_op_count = runtime_context.operation_count for idx, strategy in enumerate(strategies): # Determine branch_id (preserve original if in predict mode with filter) if target_branch_id is not None: branch_id = target_branch_id runtime_context.operation_count = initial_op_count + branch_id else: branch_id = idx # Generate branch name from strategy branch_name = _strategy_to_name(strategy, branch_id) logger.info(f" Branch {branch_id}: {branch_name}") # Create isolated context for this branch branch_context = initial_context.copy() branch_context.selector = branch_context.selector.with_branch( branch_id=branch_id, branch_name=branch_name ) branch_context.selector.processing = copy.deepcopy(initial_processing) # Restore dataset features self._restore_features(dataset, initial_features) # Apply outlier exclusion in train mode only exclusion_info = { "strategy": strategy, "n_excluded": 0, "exclusion_rate": 0.0, "excluded_indices": [], } if mode == "train" and strategy is not None and train_X is not None and len(train_X) > 0: try: result = self._apply_outlier_strategy( strategy=strategy, X=train_X, y=train_y, sample_indices=train_sample_indices, branch_id=branch_id, branch_name=branch_name, ) exclusion_info.update({ "n_excluded": result.n_excluded, "exclusion_rate": result.n_excluded / result.n_total if result.n_total > 0 else 0.0, "excluded_indices": result.excluded_indices, }) # Store exclusion mask in branch context for use by models branch_context.custom["outlier_exclusion"] = { "mask": result.exclusion_mask, # True = keep "sample_indices": train_sample_indices.tolist() if isinstance(train_sample_indices, np.ndarray) else train_sample_indices, "n_excluded": result.n_excluded, "strategy": strategy, } logger.info(f" Excluded {result.n_excluded}/{result.n_total} samples " f"({100 * exclusion_info['exclusion_rate']:.1f}%)") # Persist the fitted filter for prediction mode reference if runtime_context.saver is not None: filter_obj = _create_outlier_filter(strategy) filter_obj.fit(train_X, train_y) artifact = runtime_context.saver.persist_artifact( step_number=runtime_context.step_number, name=f"outlier_filter_b{branch_id}_{runtime_context.next_op()}", obj=filter_obj, format_hint='sklearn', branch_id=branch_id, branch_name=branch_name ) all_artifacts.append(artifact) except Exception as e: logger.warning(f"Outlier detection failed: {e}") # Continue with no exclusion on failure elif strategy is None: logger.info(" No exclusion (baseline)") # Store exclusion info in context for downstream use branch_context.custom["exclusion_info"] = exclusion_info # Store branch context branch_contexts.append({ "branch_id": branch_id, "name": branch_name, "context": branch_context, "exclusion_info": exclusion_info, }) logger.success(f" Branch {branch_id} ({branch_name}) completed") # Handle nested branching (multiply with existing branches) existing_branches = context.custom.get("branch_contexts", []) if existing_branches: new_branch_contexts = self._multiply_branch_contexts( existing_branches, branch_contexts ) else: new_branch_contexts = branch_contexts # Update result context result_context = context.copy() result_context.custom["branch_contexts"] = new_branch_contexts result_context.custom["in_branch_mode"] = True result_context.custom["outlier_excluder_active"] = True logger.success(f"Outlier excluder completed with {len(new_branch_contexts)} branch(es)") return result_context, StepOutput( artifacts=all_artifacts, metadata={ "branch_count": len(new_branch_contexts), "outlier_excluder": True, } )
def _apply_outlier_strategy( self, strategy: Dict[str, Any], X: np.ndarray, y: Optional[np.ndarray], sample_indices: np.ndarray, branch_id: int, branch_name: str, ) -> OutlierExclusionResult: """ Apply an outlier detection strategy and return exclusion results. Args: strategy: Strategy configuration dict X: Feature array (n_samples, n_features) y: Target array (n_samples,) or None sample_indices: Original sample indices in dataset branch_id: Branch identifier branch_name: Branch name Returns: OutlierExclusionResult with exclusion details """ n_samples = len(X) # Create and fit filter filter_obj = _create_outlier_filter(strategy) filter_obj.fit(X, y) # Get mask (True = keep, False = exclude) mask = filter_obj.get_mask(X, y) # Get excluded indices excluded_mask = ~mask excluded_sample_indices = sample_indices[excluded_mask].tolist() if isinstance(sample_indices, np.ndarray) else [ sample_indices[i] for i, ex in enumerate(excluded_mask) if ex ] return OutlierExclusionResult( strategy_name=branch_name, n_total=n_samples, n_excluded=int(np.sum(excluded_mask)), excluded_indices=excluded_sample_indices, exclusion_mask=mask, ) def _snapshot_features(self, dataset: "SpectroDataset") -> List[Any]: """Create a deep copy of dataset features for branch isolation.""" return copy.deepcopy(dataset._features.sources) def _restore_features( self, dataset: "SpectroDataset", snapshot: List[Any] ) -> None: """Restore dataset features from snapshot.""" dataset._features.sources = copy.deepcopy(snapshot) def _multiply_branch_contexts( self, existing: List[Dict[str, Any]], new: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: """ Multiply existing branch contexts with new ones for nested branching. Creates Cartesian product: each existing branch × each new branch. Args: existing: List of existing branch context dicts new: List of new branch context dicts Returns: Combined list of branch contexts """ result = [] flattened_id = 0 for parent in existing: parent_id = parent["branch_id"] parent_name = parent["name"] for child in new: child_id = child["branch_id"] child_name = child["name"] child_context = child["context"] # Create combined context combined_context = child_context.copy() combined_context.selector.branch_id = flattened_id combined_context.selector.branch_name = f"{parent_name}_{child_name}" # Merge exclusion info parent_exclusion = parent.get("exclusion_info", {}) child_exclusion = child.get("exclusion_info", {}) result.append({ "branch_id": flattened_id, "name": f"{parent_name}_{child_name}", "context": combined_context, "parent_branch_id": parent_id, "child_branch_id": child_id, "exclusion_info": { "parent": parent_exclusion, "child": child_exclusion, } }) flattened_id += 1 return result