Source code for nirs4all.pipeline.config._generator.strategies.or_strategy

"""OR strategy for choice-based expansion.

This module handles _or_ nodes that define choices with various selection modes:
- Basic choice: Pick one from alternatives
- pick: Unordered selection (combinations)
- arrange: Ordered arrangement (permutations)
- size: Legacy alias for pick
- Second-order: then_pick, then_arrange, or [outer, inner] syntax
- Constraints: _mutex_, _requires_, _exclude_ for filtering combinations

Syntax examples:
    {"_or_": ["A", "B", "C"]}                    -> "A", "B", "C"
    {"_or_": ["A", "B", "C"], "pick": 2}         -> ["A", "B"], ["A", "C"], ["B", "C"]
    {"_or_": ["A", "B", "C"], "arrange": 2}      -> ["A", "B"], ["B", "A"], ...
    {"_or_": ["A", "B", "C"], "pick": (1, 2)}    -> Pick 1 or 2 items
    {"_or_": [...], "pick": 1, "then_pick": 2}   -> Second-order selection
    {"_or_": [...], "pick": 2, "_mutex_": [["A", "B"]]} -> A and B can't be together
"""

from itertools import combinations, permutations, product
from math import comb, factorial, perm
from typing import Any, Callable, FrozenSet, Iterator, List, Optional, Tuple, Union

from .base import ExpansionStrategy, GeneratorNode, ExpandedResult, SizeSpec
from .registry import register_strategy
from ..keywords import (
    OR_KEYWORD, SIZE_KEYWORD, COUNT_KEYWORD,
    PICK_KEYWORD, ARRANGE_KEYWORD,
    THEN_PICK_KEYWORD, THEN_ARRANGE_KEYWORD,
    MUTEX_KEYWORD, REQUIRES_KEYWORD, EXCLUDE_KEYWORD,
    PURE_OR_KEYS
)
from ..utils.sampling import sample_with_seed


@register_strategy
class OrStrategy(ExpansionStrategy):
    """Strategy for handling _or_ nodes with selection semantics.

    Supports:
        - Basic choice expansion (each alternative becomes a variant)
        - pick: Unordered selection using combinations
        - arrange: Ordered arrangement using permutations
        - size: Legacy alias for pick (backward compatibility)
        - Second-order selection via then_pick/then_arrange
        - count: Limit number of generated variants
        - Constraints: _mutex_, _requires_, _exclude_ for filtering (Phase 4)

    Size specifications are either a single int (exact size) or a
    two-element tuple/list, which this class always reads as an inclusive
    ``(from, to)`` range.

    Attributes:
        keywords: {_or_, size, count, pick, arrange, then_pick, then_arrange,
            _mutex_, _requires_, _exclude_}
        priority: 10 (standard priority)
    """

    keywords: FrozenSet[str] = PURE_OR_KEYS
    priority: int = 10

    @classmethod
    def handles(cls, node: GeneratorNode) -> bool:
        """Check if node is a pure OR node.

        A pure OR node contains _or_ and only OR-related modifier keys.

        Args:
            node: Dictionary node to check.

        Returns:
            True if node is a pure OR node.
        """
        if not isinstance(node, dict):
            return False
        return OR_KEYWORD in node and set(node.keys()).issubset(PURE_OR_KEYS)

    def expand(
        self,
        node: GeneratorNode,
        seed: Optional[int] = None,
        expand_nested: Optional[Callable[[Any], List[Any]]] = None
    ) -> ExpandedResult:
        """Expand an OR node to list of variants.

        Args:
            node: OR specification node.
            seed: Optional seed for random sampling (used for the count limit).
            expand_nested: Callback to expand nested generator nodes.

        Returns:
            List of expanded variants.

        Raises:
            ValueError: If a size specification is malformed or negative.
        """
        choices = node[OR_KEYWORD]
        size = node.get(SIZE_KEYWORD)
        pick = node.get(PICK_KEYWORD)
        arrange = node.get(ARRANGE_KEYWORD)
        then_pick = node.get(THEN_PICK_KEYWORD)
        then_arrange = node.get(THEN_ARRANGE_KEYWORD)
        limit = node.get(COUNT_KEYWORD)

        # Extract constraint specifications (Phase 4)
        mutex_groups = node.get(MUTEX_KEYWORD, [])
        requires_groups = node.get(REQUIRES_KEYWORD, [])
        exclude_combos = node.get(EXCLUDE_KEYWORD, [])

        # Selection mode precedence: arrange > pick > size (backward compat)
        if arrange is not None:
            result = self._expand_with_arrange(
                choices, arrange, then_pick, then_arrange, expand_nested, seed
            )
        elif pick is not None:
            result = self._expand_with_pick(
                choices, pick, then_pick, then_arrange, expand_nested, seed
            )
        elif size is not None:
            # Legacy size behaves like pick (combinations)
            result = self._expand_with_pick(
                choices, size, then_pick, then_arrange, expand_nested, seed
            )
        else:
            # Basic expansion: each choice becomes a variant
            result = self._expand_basic(choices, expand_nested)

        # Constraints are applied before the count limit so that sampling
        # only draws from variants that survived filtering.
        if mutex_groups or requires_groups or exclude_combos:
            result = self._apply_constraints(
                result, mutex_groups, requires_groups, exclude_combos
            )

        if limit is not None and len(result) > limit:
            result = sample_with_seed(result, limit, seed=seed)

        return result

    def count(
        self,
        node: GeneratorNode,
        count_nested: Optional[Callable[[Any], int]] = None
    ) -> int:
        """Count OR node variants without generating them.

        The constraint keywords (_mutex_, _requires_, _exclude_) are not read
        here, so for a constrained node the returned value is the number of
        variants before constraint filtering.

        Args:
            node: OR specification node.
            count_nested: Callback to count nested nodes.

        Returns:
            Number of variants, capped by the node's count limit if present.

        Raises:
            ValueError: If a size specification is malformed or negative.
        """
        choices = node[OR_KEYWORD]
        size = node.get(SIZE_KEYWORD)
        pick = node.get(PICK_KEYWORD)
        arrange = node.get(ARRANGE_KEYWORD)
        then_pick = node.get(THEN_PICK_KEYWORD)
        then_arrange = node.get(THEN_ARRANGE_KEYWORD)
        count_limit = node.get(COUNT_KEYWORD)

        # Same precedence as expand(): arrange > pick > size
        if arrange is not None:
            total = self._count_with_arrange(
                choices, arrange, then_pick, then_arrange, count_nested
            )
        elif pick is not None:
            total = self._count_with_pick(
                choices, pick, then_pick, then_arrange, count_nested
            )
        elif size is not None:
            total = self._count_with_pick(
                choices, size, then_pick, then_arrange, count_nested
            )
        elif count_nested:
            # Basic count: sum of each choice's own variant count
            total = sum(count_nested(choice) for choice in choices)
        else:
            total = len(choices)

        if count_limit is not None:
            return min(count_limit, total)
        return total

    def validate(self, node: GeneratorNode) -> List[str]:
        """Validate OR node specification.

        Args:
            node: OR node to validate.

        Returns:
            List of error messages. Empty if valid.
        """
        errors: List[str] = []

        choices = node.get(OR_KEYWORD)
        if choices is None:
            errors.append("Missing _or_ key")
            return errors

        if not isinstance(choices, list):
            errors.append(f"_or_ must be a list, got {type(choices).__name__}")

        # Primary selection specs are checked whenever the key is present.
        for key in (PICK_KEYWORD, ARRANGE_KEYWORD, SIZE_KEYWORD):
            if key in node:
                errors.extend(self._size_spec_errors(key, node[key]))

        # Second-order specs: expand()/count() treat an explicit None as
        # "absent", so only non-None values are checked.
        for key in (THEN_PICK_KEYWORD, THEN_ARRANGE_KEYWORD):
            spec = node.get(key)
            if spec is not None:
                errors.extend(self._size_spec_errors(key, spec))

        # Validate count
        count = node.get(COUNT_KEYWORD)
        if count is not None and not isinstance(count, int):
            errors.append(f"count must be an integer, got {type(count).__name__}")

        return errors

    def _size_spec_errors(self, key: str, spec: Any) -> List[str]:
        """Return validation messages for one size specification."""
        if not self._is_valid_size_spec(spec):
            return [
                f"{key} must be int, tuple (from, to), or list [outer, inner], "
                f"got {type(spec).__name__}"
            ]
        bounds = (spec,) if isinstance(spec, int) else tuple(spec)
        if any(bound < 0 for bound in bounds):
            return [f"{key} sizes must be non-negative, got {spec!r}"]
        return []

    # -------------------------------------------------------------------------
    # Basic Expansion
    # -------------------------------------------------------------------------

    def _expand_basic(
        self,
        choices: List[Any],
        expand_nested: Optional[Callable[[Any], List[Any]]]
    ) -> ExpandedResult:
        """Expand basic OR (each choice is a variant).

        Args:
            choices: List of choice values.
            expand_nested: Callback to expand nested nodes.

        Returns:
            List of expanded variants.
        """
        if not expand_nested:
            return list(choices)
        result = []
        for choice in choices:
            result.extend(expand_nested(choice))
        return result

    # -------------------------------------------------------------------------
    # Pick Expansion (Combinations)
    # -------------------------------------------------------------------------

    def _expand_with_pick(
        self,
        choices: List[Any],
        pick_spec: SizeSpec,
        then_pick: Optional[SizeSpec],
        then_arrange: Optional[SizeSpec],
        expand_nested: Optional[Callable[[Any], List[Any]]],
        seed: Optional[int]
    ) -> ExpandedResult:
        """Expand using pick (combinations).

        When a second-order spec is given, nested expansion is not applied;
        then_pick takes precedence over then_arrange.

        Args:
            choices: List of choices.
            pick_spec: Size specification for pick.
            then_pick: Optional second-order pick.
            then_arrange: Optional second-order arrange.
            expand_nested: Callback for nested expansion.
            seed: Random seed (currently unused by this method).

        Returns:
            List of combinations.
        """
        if then_pick is not None:
            return self._handle_pick_then_pick(choices, pick_spec, then_pick)
        if then_arrange is not None:
            return self._handle_pick_then_arrange(choices, pick_spec, then_arrange)
        return self._expand_first_order(
            choices, pick_spec, expand_nested, ordered=False
        )

    def _count_with_pick(
        self,
        choices: List[Any],
        pick_spec: SizeSpec,
        then_pick: Optional[SizeSpec],
        then_arrange: Optional[SizeSpec],
        count_nested: Optional[Callable[[Any], int]]
    ) -> int:
        """Count pick (combinations) variants.

        Mirrors _expand_with_pick: every combination contributes the product
        of the nested counts of its members, matching the Cartesian product
        taken during expansion.

        Args:
            choices: List of choices.
            pick_spec: Size specification.
            then_pick: Optional second-order pick.
            then_arrange: Optional second-order arrange.
            count_nested: Callback for nested counting.

        Returns:
            Number of combinations.
        """
        if then_pick is not None:
            return self._count_pick_then_pick(len(choices), pick_spec, then_pick)
        if then_arrange is not None:
            return self._count_pick_then_arrange(
                len(choices), pick_spec, then_arrange
            )

        low, high = self._normalize_spec(pick_spec)
        sums = self._subset_product_sums(choices, count_nested, high)
        return sum(sums[size] for size in range(low, len(sums)))

    # -------------------------------------------------------------------------
    # Arrange Expansion (Permutations)
    # -------------------------------------------------------------------------

    def _expand_with_arrange(
        self,
        choices: List[Any],
        arrange_spec: SizeSpec,
        then_pick: Optional[SizeSpec],
        then_arrange: Optional[SizeSpec],
        expand_nested: Optional[Callable[[Any], List[Any]]],
        seed: Optional[int]
    ) -> ExpandedResult:
        """Expand using arrange (permutations).

        When a second-order spec is given, nested expansion is not applied;
        then_pick takes precedence over then_arrange.

        Args:
            choices: List of choices.
            arrange_spec: Size specification for arrange.
            then_pick: Optional second-order pick.
            then_arrange: Optional second-order arrange.
            expand_nested: Callback for nested expansion.
            seed: Random seed (currently unused by this method).

        Returns:
            List of permutations.
        """
        if then_pick is not None:
            return self._handle_arrange_then_pick(choices, arrange_spec, then_pick)
        if then_arrange is not None:
            return self._handle_arrange_then_arrange(
                choices, arrange_spec, then_arrange
            )
        return self._expand_first_order(
            choices, arrange_spec, expand_nested, ordered=True
        )

    def _count_with_arrange(
        self,
        choices: List[Any],
        arrange_spec: SizeSpec,
        then_pick: Optional[SizeSpec],
        then_arrange: Optional[SizeSpec],
        count_nested: Optional[Callable[[Any], int]]
    ) -> int:
        """Count arrange (permutations) variants.

        Mirrors _expand_with_arrange: each subset of size s appears in s!
        orders, and every ordering contributes the product of the nested
        counts of its members.

        Args:
            choices: List of choices.
            arrange_spec: Size specification.
            then_pick: Optional second-order pick.
            then_arrange: Optional second-order arrange.
            count_nested: Callback for nested counting.

        Returns:
            Number of permutations.
        """
        if then_pick is not None:
            return self._count_arrange_then_pick(
                len(choices), arrange_spec, then_pick
            )
        if then_arrange is not None:
            return self._count_arrange_then_arrange(
                len(choices), arrange_spec, then_arrange
            )

        low, high = self._normalize_spec(arrange_spec)
        sums = self._subset_product_sums(choices, count_nested, high)
        # Without nesting this is factorial(s) * comb(n, s) == P(n, s).
        return sum(factorial(size) * sums[size] for size in range(low, len(sums)))

    # -------------------------------------------------------------------------
    # Second-Order (then_pick / then_arrange)
    # -------------------------------------------------------------------------

    def _handle_pick_then_pick(
        self,
        choices: List[Any],
        primary_spec: SizeSpec,
        then_spec: SizeSpec
    ) -> ExpandedResult:
        """Pick from choices, then pick from results."""
        return self._expand_second_order(
            choices, primary_spec, then_spec,
            primary_ordered=False, then_ordered=False
        )

    def _count_pick_then_pick(
        self,
        n: int,
        primary_spec: SizeSpec,
        then_spec: SizeSpec
    ) -> int:
        """Count pick-then-pick combinations."""
        return self._count_second_order(
            n, primary_spec, then_spec,
            primary_ordered=False, then_ordered=False
        )

    def _handle_pick_then_arrange(
        self,
        choices: List[Any],
        primary_spec: SizeSpec,
        then_spec: SizeSpec
    ) -> ExpandedResult:
        """Pick from choices, then arrange results."""
        return self._expand_second_order(
            choices, primary_spec, then_spec,
            primary_ordered=False, then_ordered=True
        )

    def _count_pick_then_arrange(
        self,
        n: int,
        primary_spec: SizeSpec,
        then_spec: SizeSpec
    ) -> int:
        """Count pick-then-arrange permutations."""
        return self._count_second_order(
            n, primary_spec, then_spec,
            primary_ordered=False, then_ordered=True
        )

    def _handle_arrange_then_pick(
        self,
        choices: List[Any],
        primary_spec: SizeSpec,
        then_spec: SizeSpec
    ) -> ExpandedResult:
        """Arrange from choices, then pick from results."""
        return self._expand_second_order(
            choices, primary_spec, then_spec,
            primary_ordered=True, then_ordered=False
        )

    def _count_arrange_then_pick(
        self,
        n: int,
        primary_spec: SizeSpec,
        then_spec: SizeSpec
    ) -> int:
        """Count arrange-then-pick combinations."""
        return self._count_second_order(
            n, primary_spec, then_spec,
            primary_ordered=True, then_ordered=False
        )

    def _handle_arrange_then_arrange(
        self,
        choices: List[Any],
        primary_spec: SizeSpec,
        then_spec: SizeSpec
    ) -> ExpandedResult:
        """Arrange from choices, then arrange results."""
        return self._expand_second_order(
            choices, primary_spec, then_spec,
            primary_ordered=True, then_ordered=True
        )

    def _count_arrange_then_arrange(
        self,
        n: int,
        primary_spec: SizeSpec,
        then_spec: SizeSpec
    ) -> int:
        """Count arrange-then-arrange permutations."""
        return self._count_second_order(
            n, primary_spec, then_spec,
            primary_ordered=True, then_ordered=True
        )

    def _expand_second_order(
        self,
        choices: List[Any],
        primary_spec: SizeSpec,
        then_spec: SizeSpec,
        primary_ordered: bool,
        then_ordered: bool
    ) -> ExpandedResult:
        """Select from choices, then select again from the primary results.

        Args:
            choices: List of choices.
            primary_spec: Size specification of the first selection.
            then_spec: Size specification of the second selection.
            primary_ordered: True to use permutations for the first step.
            then_ordered: True to use permutations for the second step.

        Returns:
            List of second-order selections, each a list of primary items.
        """
        # Both specs are normalized up front so a bad spec fails before any
        # selection is generated.
        primary_range = self._normalize_spec(primary_spec)
        then_range = self._normalize_spec(then_spec)

        # Step 1: primary selections; a single-element selection is unwrapped
        # to the bare item, anything else becomes a list.
        primary_items = [
            group[0] if len(group) == 1 else list(group)
            for group in self._select(choices, primary_range, primary_ordered)
        ]

        # Step 2: select among the primary items.
        return [
            list(group)
            for group in self._select(primary_items, then_range, then_ordered)
        ]

    def _count_second_order(
        self,
        n: int,
        primary_spec: SizeSpec,
        then_spec: SizeSpec,
        primary_ordered: bool,
        then_ordered: bool
    ) -> int:
        """Count the selections produced by _expand_second_order.

        Args:
            n: Number of choices.
            primary_spec: Size specification of the first selection.
            then_spec: Size specification of the second selection.
            primary_ordered: True if the first step uses permutations.
            then_ordered: True if the second step uses permutations.

        Returns:
            Number of second-order selections.
        """
        primary_range = self._normalize_spec(primary_spec)
        then_range = self._normalize_spec(then_spec)
        total_primary = self._count_selections(n, primary_range, primary_ordered)
        return self._count_selections(total_primary, then_range, then_ordered)

    # -------------------------------------------------------------------------
    # Helper Methods
    # -------------------------------------------------------------------------

    @staticmethod
    def _select(
        items: List[Any],
        size_range: Tuple[int, int],
        ordered: bool
    ) -> Iterator[Tuple[Any, ...]]:
        """Yield every selection of each size in the inclusive range.

        Sizes larger than the number of items are skipped. Permutations are
        used when ``ordered`` is true, combinations otherwise.
        """
        chooser = permutations if ordered else combinations
        low, high = size_range
        for size in range(low, min(high, len(items)) + 1):
            yield from chooser(items, size)

    @staticmethod
    def _count_selections(
        n: int,
        size_range: Tuple[int, int],
        ordered: bool
    ) -> int:
        """Count the selections _select would yield for ``n`` items."""
        counter = perm if ordered else comb
        low, high = size_range
        return sum(counter(n, size) for size in range(low, min(high, n) + 1))

    @staticmethod
    def _subset_product_sums(
        choices: List[Any],
        count_nested: Optional[Callable[[Any], int]],
        max_size: int
    ) -> List[int]:
        """Sum, per subset size, the products of nested counts.

        Entry ``s`` of the result is the sum over all size-``s`` subsets of
        ``choices`` of the product of ``count_nested(item)`` for the items in
        the subset. Sizes run from 0 to ``min(max_size, len(choices))``.
        Without a callback every item counts once, so entry ``s`` is
        ``comb(n, s)``.
        """
        n = len(choices)
        top = max(0, min(max_size, n))
        if not count_nested:
            return [comb(n, size) for size in range(top + 1)]

        # Incremental build of the per-size sums: adding one item with weight
        # w turns sums[s] into sums[s] + w * sums[s - 1]. Iterating sizes
        # downward keeps sums[s - 1] at its pre-update value.
        sums = [1] + [0] * top
        for choice in choices:
            weight = count_nested(choice)
            for size in range(top, 0, -1):
                sums[size] += sums[size - 1] * weight
        return sums

    def _expand_first_order(
        self,
        choices: List[Any],
        spec: SizeSpec,
        expand_nested: Optional[Callable[[Any], List[Any]]],
        ordered: bool
    ) -> ExpandedResult:
        """Expand a plain pick/arrange, applying nested expansion per item."""
        size_range = self._normalize_spec(spec)
        result = []
        # A size of 0 yields one empty selection, which expands to [[]].
        for group in self._select(choices, size_range, ordered):
            result.extend(self._expand_combination(group, expand_nested))
        return result

    def _normalize_spec(self, spec: SizeSpec) -> Tuple[int, int]:
        """Normalize size specification to (from, to) tuple.

        Args:
            spec: Size specification (int, tuple, or list).

        Returns:
            Tuple of (from_size, to_size).

        Raises:
            ValueError: If specification is invalid or contains a negative
                size.
        """
        if isinstance(spec, int):
            bounds = (spec, spec)
        elif isinstance(spec, (tuple, list)) and len(spec) == 2:
            bounds = (spec[0], spec[1])
        else:
            raise ValueError(f"Invalid size spec: {spec}. Must be int or (from, to).")

        # Fail here with a clear message rather than deep inside itertools or
        # math, and keep expand() and count() in agreement on bad input.
        if any(isinstance(bound, int) and bound < 0 for bound in bounds):
            raise ValueError(
                f"Invalid size spec: {spec}. Sizes must be non-negative."
            )
        return bounds

    def _is_valid_size_spec(self, spec: Any) -> bool:
        """Check if a size specification is valid."""
        if isinstance(spec, int):
            return True
        if isinstance(spec, (tuple, list)) and len(spec) == 2:
            return all(isinstance(x, int) for x in spec)
        return False

    def _expand_combination(
        self,
        combo: tuple,
        expand_nested: Optional[Callable[[Any], List[Any]]]
    ) -> List[List[Any]]:
        """Expand a combination by taking Cartesian product of expanded elements.

        Args:
            combo: Tuple of items from combinations/permutations.
            expand_nested: Callback to expand nested nodes.

        Returns:
            List of expanded combinations.
        """
        if expand_nested:
            expanded_elements = [expand_nested(item) for item in combo]
        else:
            expanded_elements = [[item] for item in combo]

        # Take Cartesian product
        return [list(expanded_combo) for expanded_combo in product(*expanded_elements)]

    # -------------------------------------------------------------------------
    # Constraint Handling (Phase 4)
    # -------------------------------------------------------------------------

    def _apply_constraints(
        self,
        results: ExpandedResult,
        mutex_groups: List[List[Any]],
        requires_groups: List[List[Any]],
        exclude_combos: List[List[Any]]
    ) -> ExpandedResult:
        """Apply constraint filters to expanded results.

        Filtering is delegated to ``apply_all_constraints`` from the
        constraints module. Results are returned unchanged when the list is
        empty or its first element is not a list.

        Args:
            results: List of expanded combinations.
            mutex_groups: Mutual exclusion groups - items that can't appear together.
            requires_groups: Dependency pairs - if A is present, B must be too.
            exclude_combos: Specific combinations to exclude.

        Returns:
            Filtered results satisfying all constraints.

        Examples:
            >>> # Mutex: A and B can't be together
            >>> self._apply_constraints(
            ...     [["A","B"], ["A","C"], ["B","C"]],
            ...     mutex_groups=[["A","B"]], requires_groups=[], exclude_combos=[]
            ... )
            [["A","C"], ["B","C"]]
        """
        from ..constraints import apply_all_constraints

        # Only filter list results (combinations/permutations)
        # Basic single-choice results don't need constraint filtering
        if not results or not isinstance(results[0], list):
            return results

        return apply_all_constraints(
            results,
            mutex_groups=mutex_groups,
            requires_groups=requires_groups,
            exclude_combos=exclude_combos
        )