Source code for nirs4all.pipeline.config._generator.strategies.cartesian_strategy

"""Cartesian strategy for staged pipeline expansion with selection.

This module handles _cartesian_ nodes that first compute the Cartesian product
of nested stages (each with _or_ choices), then apply pick/arrange selection
on the resulting complete pipelines.

This is the key pattern for preprocessing pipeline generation:
- Define stages (scatter, smooth, derivative, etc.)
- Each stage has multiple options via _or_
- Generate all combinations (Cartesian product)
- Then pick/arrange from the complete pipelines

Syntax:
    {"_cartesian_": [stage1, stage2, ...], "pick": N}
    {"_cartesian_": [stage1, stage2, ...], "arrange": N}

Examples:
    # Generate all pipeline combinations, then pick 2
    {"_cartesian_": [
        {"_or_": ["MSC", "SNV", "EMSC"]},
        {"_or_": ["SavGol", "Gaussian", None]},
        {"_or_": [None, "Deriv1", "Deriv2"]}
    ], "pick": 2}

    -> Generates all 27 pipelines (3×3×3), then forms every 2-combination
       of them (C(27, 2) = 351 results, each a list of 2 complete pipelines)
    -> Result: [[["MSC", "SavGol", None], ["SNV", "Gaussian", "Deriv1"]], ...]

    # Pick 1-3 complete pipelines
    {"_cartesian_": [...], "pick": (1, 3), "count": 20}
"""

from itertools import combinations, permutations, product
from math import comb, factorial, perm
from typing import Any, Callable, FrozenSet, List, Optional, Tuple

from .base import ExpansionStrategy, GeneratorNode, ExpandedResult, SizeSpec
from .registry import register_strategy
from ..keywords import (
    COUNT_KEYWORD,
    SEED_KEYWORD,
    PICK_KEYWORD,
    ARRANGE_KEYWORD,
    MUTEX_KEYWORD,
    REQUIRES_KEYWORD,
    EXCLUDE_KEYWORD,
    TAGS_KEYWORD,
    METADATA_KEYWORD,
)
from ..utils.sampling import sample_with_seed

# Key whose presence marks a dict node as a staged Cartesian-product spec.
CARTESIAN_KEYWORD: str = "_cartesian_"

# Every key a node may carry and still count as a "pure" cartesian node.
# A node containing any key outside this set is not handled by the
# cartesian strategy (see ``CartesianStrategy.handles``).
PURE_CARTESIAN_KEYS: FrozenSet[str] = frozenset(
    (
        CARTESIAN_KEYWORD,
        # Selection applied to the complete pipelines.
        PICK_KEYWORD,
        ARRANGE_KEYWORD,
        # Result limiting / reproducible sampling.
        COUNT_KEYWORD,
        SEED_KEYWORD,
        # Constraint filters.
        MUTEX_KEYWORD,
        REQUIRES_KEYWORD,
        EXCLUDE_KEYWORD,
        # Allowed on the node; not read anywhere in this module.
        TAGS_KEYWORD,
        METADATA_KEYWORD,
    )
)


@register_strategy
class CartesianStrategy(ExpansionStrategy):
    """Strategy for handling _cartesian_ nodes.

    Generates the Cartesian product of all stages first (each stage being
    an _or_ node or list of options), then applies pick or arrange
    selection to the complete pipelines.

    This differs from _grid_ which produces dicts. _cartesian_ produces
    lists (ordered stages) which is ideal for preprocessing pipelines.

    Supported formats:
        - Array of stages: [stage1, stage2, ...]
        - With pick: Select N combinations of complete pipelines
        - With arrange: Select N permutations of complete pipelines
        - With count: Limit number of results
        - With constraints: Filter invalid combinations

    Attributes:
        keywords: {_cartesian_, pick, arrange, count, ...}
        priority: 35 (high priority, checked before grid)
    """

    keywords: FrozenSet[str] = PURE_CARTESIAN_KEYS
    priority: int = 35  # High priority

    @classmethod
    def handles(cls, node: GeneratorNode) -> bool:
        """Check if node is a pure cartesian node.

        Args:
            node: Candidate node to check.

        Returns:
            True if ``node`` is a dict that contains ``_cartesian_`` and
            carries no key outside ``PURE_CARTESIAN_KEYS``; False otherwise
            (including for any non-dict value).
        """
        if not isinstance(node, dict):
            return False
        return CARTESIAN_KEYWORD in node and node.keys() <= PURE_CARTESIAN_KEYS

    def expand(
        self,
        node: GeneratorNode,
        seed: Optional[int] = None,
        expand_nested: Optional[Callable[..., Any]] = None,
    ) -> ExpandedResult:
        """Expand a cartesian node to list of pipeline combinations.

        The process:
            1. Expand each stage to get its options
            2. Compute Cartesian product of all stages -> complete pipelines
            3. If pick/arrange specified, select from complete pipelines
               (``pick`` takes precedence when both keys are present)
            4. Apply constraints if specified
            5. Apply count limit if specified

        Args:
            node: Cartesian specification node.
            seed: Fallback seed for random sampling when count is used; a
                seed stored on the node itself takes precedence.
            expand_nested: Callback to expand nested generator nodes. It is
                called with a single dict/list stage and must return the
                iterable of options for that stage. When omitted, every
                stage is treated as one literal option.

        Returns:
            List of pipeline combinations. An empty stage list yields
            ``[[]]`` (one empty pipeline).

        Raises:
            ValueError: If ``_cartesian_`` is not a list, if the pick/arrange
                spec is malformed, or (raised by itertools) if it contains a
                negative size.

        Examples:
            >>> strategy.expand({
            ...     "_cartesian_": [
            ...         {"_or_": ["A", "B"]},
            ...         {"_or_": ["X", "Y"]}
            ...     ],
            ...     "pick": 2
            ... })
            [[["A", "X"], ["A", "Y"]], [["A", "X"], ["B", "X"]], ...]
        """
        stages = node[CARTESIAN_KEYWORD]
        if not isinstance(stages, list):
            raise ValueError(
                f"_cartesian_ must be a list of stages, got {type(stages).__name__}"
            )

        # Handle empty stages: the product of zero stages is one empty pipeline.
        if not stages:
            return [[]]

        pick = node.get(PICK_KEYWORD)
        arrange = node.get(ARRANGE_KEYWORD)
        count = node.get(COUNT_KEYWORD)
        node_seed = node.get(SEED_KEYWORD, seed)

        # Extract constraints
        mutex_groups = node.get(MUTEX_KEYWORD, [])
        requires_groups = node.get(REQUIRES_KEYWORD, [])
        exclude_combos = node.get(EXCLUDE_KEYWORD, [])

        # Step 1: Expand each stage to get its options. Only dict/list stages
        # are delegated to the callback; scalars (and everything, when no
        # callback is supplied) become a single-option stage.
        expanded_stages = [
            expand_nested(stage)
            if expand_nested and isinstance(stage, (dict, list))
            else [stage]
            for stage in stages
        ]

        # Step 2: Compute Cartesian product -> complete pipelines
        all_pipelines = [list(combo) for combo in product(*expanded_stages)]

        # Step 3: Apply pick or arrange selection
        if pick is not None:
            result = self._apply_pick(all_pipelines, pick)
        elif arrange is not None:
            result = self._apply_arrange(all_pipelines, arrange)
        else:
            # No selection - return all pipelines as-is
            result = all_pipelines

        # Step 4: Apply constraints if specified
        if mutex_groups or requires_groups or exclude_combos:
            result = self._apply_constraints(
                result, mutex_groups, requires_groups, exclude_combos
            )

        # Step 5: Apply count limit (only sample when there is a surplus)
        if count is not None and len(result) > count:
            result = sample_with_seed(result, count, seed=node_seed)

        return result

    def _select_groups(
        self,
        pipelines: List[List[Any]],
        spec: SizeSpec,
        chooser: Callable[..., Any],
    ) -> ExpandedResult:
        """Collect every group of pipelines produced by ``chooser`` per size.

        Shared implementation of pick and arrange: ``chooser`` is
        ``itertools.combinations`` or ``itertools.permutations``.

        Args:
            pipelines: List of complete pipelines.
            spec: Size specification (int or (from, to) pair, inclusive).
            chooser: Callable ``chooser(pipelines, size)`` yielding tuples
                of pipelines.

        Returns:
            List of groups, each group being a list of pipelines. Size 0
            contributes a single empty group.
        """
        from_size, to_size = self._normalize_spec(spec)
        available = len(pipelines)

        result: List[Any] = []
        for size in range(from_size, to_size + 1):
            # Sizes are visited in ascending order, so once one exceeds the
            # number of pipelines all remaining ones do too.
            if size > available:
                break
            # For size 0 itertools yields exactly one empty tuple, which
            # becomes the single empty group.
            result.extend(list(group) for group in chooser(pipelines, size))
        return result

    def _apply_pick(
        self,
        pipelines: List[List[Any]],
        pick_spec: SizeSpec,
    ) -> ExpandedResult:
        """Apply pick (combinations) to the list of pipelines.

        Args:
            pipelines: List of complete pipelines.
            pick_spec: Size specification (int or tuple).

        Returns:
            List of pipeline combinations.
        """
        return self._select_groups(pipelines, pick_spec, combinations)

    def _apply_arrange(
        self,
        pipelines: List[List[Any]],
        arrange_spec: SizeSpec,
    ) -> ExpandedResult:
        """Apply arrange (permutations) to the list of pipelines.

        Args:
            pipelines: List of complete pipelines.
            arrange_spec: Size specification (int or tuple).

        Returns:
            List of pipeline permutations.
        """
        return self._select_groups(pipelines, arrange_spec, permutations)

    def _normalize_spec(self, spec: SizeSpec) -> Tuple[int, int]:
        """Normalize size specification to (from, to) tuple.

        Args:
            spec: Either a single int ``n`` (meaning exactly ``n``) or a
                2-element tuple/list ``(from, to)``.

        Returns:
            The inclusive ``(from, to)`` bounds.

        Raises:
            ValueError: If ``spec`` has any other shape.
        """
        if isinstance(spec, int):
            return (spec, spec)
        if isinstance(spec, (tuple, list)) and len(spec) == 2:
            return (spec[0], spec[1])
        raise ValueError(f"Invalid size spec: {spec}. Must be int or (from, to).")

    def _apply_constraints(
        self,
        results: ExpandedResult,
        mutex_groups: List[List[Any]],
        requires_groups: List[List[Any]],
        exclude_combos: List[List[Any]],
    ) -> ExpandedResult:
        """Apply constraint filters to expanded results.

        Note: For _cartesian_, constraints apply to the selected pipelines,
        not the items within pipelines.

        Args:
            results: Expanded results to filter.
            mutex_groups: Value of the mutex constraint key.
            requires_groups: Value of the requires constraint key.
            exclude_combos: Value of the exclude constraint key.

        Returns:
            ``results`` unchanged when empty, otherwise whatever
            ``apply_all_constraints`` returns for it.
        """
        # Imported lazily, as in the original implementation.
        from ..constraints import apply_all_constraints

        if not results:
            return results

        return apply_all_constraints(
            results,
            mutex_groups=mutex_groups,
            requires_groups=requires_groups,
            exclude_combos=exclude_combos,
        )

    def count(
        self,
        node: GeneratorNode,
        count_nested: Optional[Callable[..., int]] = None,
    ) -> int:
        """Count cartesian combinations without generating them.

        Constraint keys are not taken into account, so the value is an
        upper bound whenever constraints are present on the node.

        Args:
            node: Cartesian specification node.
            count_nested: Callback to count nested nodes. It is called with
                a single dict/list stage; without it every stage counts as
                one option.

        Returns:
            Number of pipeline combinations. 0 when ``_cartesian_`` is not
            a list, 1 when the stage list is empty.
        """
        stages = node[CARTESIAN_KEYWORD]
        if not isinstance(stages, list):
            return 0
        if not stages:
            return 1

        pick = node.get(PICK_KEYWORD)
        arrange = node.get(ARRANGE_KEYWORD)
        count_limit = node.get(COUNT_KEYWORD)

        # Count total pipelines (Cartesian product)
        total_pipelines = 1
        for stage in stages:
            if count_nested and isinstance(stage, (dict, list)):
                total_pipelines *= count_nested(stage)
            # Any other stage is a single option: multiplying by 1 is a no-op.

        # Apply pick/arrange selection count (pick wins, mirroring expand)
        if pick is not None:
            total = self._count_pick(total_pipelines, pick)
        elif arrange is not None:
            total = self._count_arrange(total_pipelines, arrange)
        else:
            total = total_pipelines

        # Apply count limit
        if count_limit is not None:
            return min(count_limit, total)
        return total

    def _count_pick(self, n: int, pick_spec: SizeSpec) -> int:
        """Count pick combinations.

        Args:
            n: Number of complete pipelines to choose from.
            pick_spec: Size specification (int or tuple).

        Returns:
            Sum of C(n, size) over every size in the spec not exceeding n.
        """
        from_size, to_size = self._normalize_spec(pick_spec)
        return sum(
            comb(n, size)
            for size in range(from_size, to_size + 1)
            if size <= n
        )

    def _count_arrange(self, n: int, arrange_spec: SizeSpec) -> int:
        """Count arrange permutations.

        Args:
            n: Number of complete pipelines to choose from.
            arrange_spec: Size specification (int or tuple).

        Returns:
            Sum of P(n, size) over every size in the spec not exceeding n.
        """
        from_size, to_size = self._normalize_spec(arrange_spec)
        # math.perm avoids building the two full factorials and, like
        # math.comb, raises ValueError on a negative size.
        return sum(
            perm(n, size)
            for size in range(from_size, to_size + 1)
            if size <= n
        )

    def validate(self, node: GeneratorNode) -> List[str]:
        """Validate cartesian node specification.

        Checks that ``_cartesian_`` is present and is a list, that pick and
        arrange are not both given, that each size spec is an int or a
        2-element sequence of ints with non-negative, ordered bounds, and
        that count is a non-negative integer.

        Args:
            node: Cartesian node to validate.

        Returns:
            List of error messages. Empty if valid.
        """
        errors: List[str] = []

        stages = node.get(CARTESIAN_KEYWORD)
        if stages is None:
            errors.append("Missing _cartesian_ key")
            return errors

        if not isinstance(stages, list):
            errors.append(
                f"_cartesian_ must be a list, got {type(stages).__name__}"
            )
            return errors

        # Validate pick/arrange
        pick = node.get(PICK_KEYWORD)
        arrange = node.get(ARRANGE_KEYWORD)

        if pick is not None and arrange is not None:
            errors.append("Cannot specify both 'pick' and 'arrange'")

        for key, spec in ((PICK_KEYWORD, pick), (ARRANGE_KEYWORD, arrange)):
            if spec is None:
                continue
            if not self._is_valid_size_spec(spec):
                errors.append(
                    f"{key} must be int or tuple (from, to), "
                    f"got {type(spec).__name__}"
                )
                continue
            # The shape is valid here, so the bounds can be inspected safely.
            from_size, to_size = self._normalize_spec(spec)
            if from_size < 0 or to_size < 0:
                # A negative size makes itertools raise during expand.
                errors.append(f"{key} sizes must be non-negative, got {spec!r}")
            elif from_size > to_size:
                # A reversed range would silently select nothing.
                errors.append(
                    f"{key} range must satisfy from <= to, got {spec!r}"
                )

        # Validate count
        count = node.get(COUNT_KEYWORD)
        if count is not None:
            if not isinstance(count, int):
                errors.append(
                    f"count must be an integer, got {type(count).__name__}"
                )
            elif count < 0:
                errors.append(f"count must be non-negative, got {count}")

        return errors

    def _is_valid_size_spec(self, spec: Any) -> bool:
        """Check if a size specification is valid.

        Args:
            spec: Candidate size specification.

        Returns:
            True for an int, or for a 2-element tuple/list whose items are
            both ints; False otherwise. Only the shape is checked, not the
            sign or ordering of the values.
        """
        if isinstance(spec, int):
            return True
        if isinstance(spec, (tuple, list)) and len(spec) == 2:
            return all(isinstance(bound, int) for bound in spec)
        return False