Source code for nirs4all.pipeline.config._generator.core

"""Core expansion logic for the generator module.

This module provides the main expansion and counting functions using the
Strategy pattern for node-specific handling. It serves as the orchestration
layer that dispatches to appropriate strategies.

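Main Functions:
    expand_spec(node, seed): Expand a configuration node into all variants
    expand_spec_with_choices(node, seed): Like expand_spec, but also return
        the generator choices that produced each variant
    count_combinations(node): Count variants without generating them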

The core module handles:
    - Recursive expansion of nested structures
    - Delegation to strategies for pure generator nodes (_or_, _range_, etc.)
    - Cartesian product expansion for dicts and lists
    - Mixed nodes (dict with both _or_ and other keys)
"""

from collections.abc import Mapping
from itertools import product
from typing import Any, Dict, List, Optional, Tuple, Union

from .strategies import get_strategy
from .strategies.base import ExpandedResult
from .keywords import (
    OR_KEYWORD,
    SIZE_KEYWORD,
    COUNT_KEYWORD,
    PICK_KEYWORD,
    ARRANGE_KEYWORD,
    THEN_PICK_KEYWORD,
    THEN_ARRANGE_KEYWORD,
    RANGE_KEYWORD,
    has_or_keyword,
)
from .utils.sampling import sample_with_seed

# Type alias for any node accepted by the expansion and counting functions:
# a dict (possibly a generator node), a list, or a plain scalar / None.
GeneratorNode = Union[Dict[str, Any], List[Any], str, int, float, bool, None]

# Result of expansion with choice tracking: a list of (config, choices) tuples,
# where ``choices`` is the ordered list of ``{keyword: selected_value}`` dicts.
ExpandedWithChoices = List[Tuple[Any, List[Dict[str, Any]]]]


def expand_spec(node: GeneratorNode, seed: Optional[int] = None) -> ExpandedResult:
    """Expand a specification node into every variant it describes.

    This is the public entry point for configuration expansion. Dicts are
    expanded per key, or handed to a strategy when they are generator nodes.
    Lists become the Cartesian product of their expanded elements. Scalars
    come back wrapped in a single-element list.

    Args:
        node: Configuration node to expand (dict, list or scalar).
        seed: Optional random seed that makes generation reproducible when
            'count' is used to limit the number of results.

    Returns:
        List of expanded variants.

    Examples:
        >>> expand_spec({"_or_": ["A", "B"]})
        ['A', 'B']
        >>> expand_spec({"_range_": [1, 3]})
        [1, 2, 3]
        >>> expand_spec({"x": {"_or_": [1, 2]}, "y": "fixed"})
        [{'x': 1, 'y': 'fixed'}, {'x': 2, 'y': 'fixed'}]
    """
    variants = _expand_internal(node, seed)
    return variants
def expand_spec_with_choices(
    node: GeneratorNode, seed: Optional[int] = None
) -> List[tuple]:
    """Expand a specification node and record the generator choices behind each variant.

    Works like ``expand_spec``, but each expanded variant is paired with the
    selections made at generator nodes (``_or_``, ``_range_``, ...). This lets
    callers trace which values produced a given pipeline configuration.

    Args:
        node: Configuration node to expand.
        seed: Optional random seed for reproducible generation.

    Returns:
        List of ``(expanded_config, generator_choices)`` tuples. Each
        ``generator_choices`` is a list of dicts such as
        ``[{"_or_": selected_value}, {"_range_": 18}, ...]``, in the order the
        generator nodes were encountered during expansion.

    Examples:
        >>> results = expand_spec_with_choices({"_or_": ["A", "B"]})
        >>> results
        [('A', [{'_or_': 'A'}]), ('B', [{'_or_': 'B'}])]
        >>> results = expand_spec_with_choices({"x": {"_or_": [1, 2]}, "y": 3})
        >>> results
        [({'x': 1, 'y': 3}, [{'_or_': 1}]), ({'x': 2, 'y': 3}, [{'_or_': 2}])]
    """
    tracked = _expand_with_choices_internal(node, seed)
    return tracked
def _expand_internal(node: GeneratorNode, seed: Optional[int] = None) -> ExpandedResult:
    """Recursively expand ``node``, propagating ``seed`` to every level.

    Args:
        node: Node to expand.
        seed: Random seed forwarded to strategies and nested expansions.

    Returns:
        List of expanded variants.
    """
    # Lists: Cartesian product of the expanded elements.
    if isinstance(node, list):
        return _expand_list(node, seed)

    # Anything that is neither a list nor a mapping is a scalar leaf.
    if not isinstance(node, Mapping):
        return [node]

    # Pure generator nodes are delegated to their strategy.
    strategy = get_strategy(node)
    if strategy:
        return strategy.expand(
            node, seed=seed, expand_nested=lambda n: _expand_internal(n, seed)
        )

    # Dict mixing _or_ (and its modifiers) with ordinary keys.
    if has_or_keyword(node):
        return _expand_mixed_or_node(node, seed)

    # Ordinary dict: Cartesian product over its values. This also covers a
    # dict that merely contains a _range_ key but was not claimed by a
    # strategy (the former dedicated branch did exactly the same thing).
    return _expand_dict(node, seed)


def _expand_list(node: list, seed: Optional[int]) -> ExpandedResult:
    """Expand a list by taking the Cartesian product of its elements.

    Args:
        node: List to expand.
        seed: Random seed.

    Returns:
        List of expanded list combinations.

    Examples:
        >>> _expand_list([{"_or_": ["A", "B"]}, "C"], None)
        [['A', 'C'], ['B', 'C']]
    """
    if not node:
        return [[]]  # Empty list -> single empty result

    if len(node) == 1:
        # Expand the sole element exactly once. The previous code expanded it
        # a second time on the fall-through path, which doubled the work and,
        # with unseeded sampling, could return a result that differs from the
        # one just inspected.
        element_result = _expand_internal(node[0], seed)
        # An element that already expands to lists (combinations) is
        # returned as-is rather than being wrapped again.
        if element_result and isinstance(element_result[0], list):
            return element_result
        return [[value] for value in element_result]

    expanded_elements = [_expand_internal(element, seed) for element in node]
    return [list(combo) for combo in product(*expanded_elements)]


# Keys that belong to the _or_ generator in a mixed dict; every other key is
# treated as ordinary configuration.
# NOTE(review): these literals presumably mirror OR_KEYWORD, SIZE_KEYWORD,
# COUNT_KEYWORD, PICK_KEYWORD, ARRANGE_KEYWORD, THEN_PICK_KEYWORD and
# THEN_ARRANGE_KEYWORD from .keywords -- confirm and switch to the constants.
_OR_MODIFIER_KEYS = frozenset(
    {"_or_", "size", "count", "pick", "arrange", "then_pick", "then_arrange"}
)


def _split_mixed_or_node(
    node: Mapping,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a mixed node into ``(base, or_node)``, keeping the node's key order.

    Iterating the node itself, rather than the modifier-key set, keeps the
    order of the keys in ``or_node`` stable. Set iteration order for strings
    depends on hash randomization and varies between processes.
    """
    base: Dict[str, Any] = {}
    or_node: Dict[str, Any] = {}
    for key, value in node.items():
        target = or_node if key in _OR_MODIFIER_KEYS else base
        target[key] = value
    return base, or_node


def _expand_mixed_or_node(node: Dict[str, Any], seed: Optional[int]) -> ExpandedResult:
    """Expand a dict that has _or_ mixed with other keys.

    The strategy is:
        1. Separate the base dict (non-generator keys) from the OR node
        2. Expand both independently
        3. Merge each base variant with each OR variant

    Args:
        node: Dict containing _or_ and other keys.
        seed: Random seed.

    Returns:
        List of merged dict variants.

    Raises:
        ValueError: If an OR choice is not a mapping and so cannot be merged
            into the base dict.
    """
    base, or_node = _split_mixed_or_node(node)

    base_expanded = _expand_internal(base, seed)  # list[dict]
    choice_expanded = _expand_internal(or_node, seed)  # list[dict or scalar]

    results = []
    for b in base_expanded:
        for c in choice_expanded:
            if not isinstance(c, Mapping):
                # A scalar choice has no key under which it could be merged.
                raise ValueError(
                    "Top-level '_or_' choices in a mixed dict must be dicts, "
                    f"not {type(c).__name__}. Got: {c}"
                )
            results.append({**b, **c})
    return results


def _expand_dict(node: Dict[str, Any], seed: Optional[int]) -> ExpandedResult:
    """Expand a regular dict by taking the Cartesian product of its values.

    Args:
        node: Dict to expand.
        seed: Random seed.

    Returns:
        List of dict variants.

    Examples:
        >>> _expand_dict({"a": {"_or_": [1, 2]}, "b": 3}, None)
        [{'a': 1, 'b': 3}, {'a': 2, 'b': 3}]
    """
    if not node:
        return [{}]

    keys = list(node)
    value_options = [_expand_value(node[key], seed) for key in keys]
    return [dict(zip(keys, combo)) for combo in product(*value_options)]


def _expand_value(v: Any, seed: Optional[int]) -> ExpandedResult:
    """Expand a value sitting in a dict position.

    Mappings and lists may contain nested generator nodes and are expanded
    recursively. Any other value is a scalar and yields ``[v]``, which is
    exactly what ``_expand_internal`` returns for it.

    Args:
        v: Value to expand.
        seed: Random seed.

    Returns:
        List of expanded values.
    """
    return _expand_internal(v, seed)


# =============================================================================
# Expansion with Choice Tracking
# =============================================================================

# A single expansion result paired with the generator choices that produced it.
ResultWithChoices = Tuple[Any, List[Dict[str, Any]]]


def _unzip_with_choices(
    combo: Tuple[ResultWithChoices, ...],
) -> Tuple[List[Any], List[Dict[str, Any]]]:
    """Split a product combo of ``(value, choices)`` pairs into values and merged choices.

    Choices are concatenated in element order, so the result preserves the
    order in which generator nodes were encountered.
    """
    values = [value for value, _ in combo]
    merged_choices = [choice for _, choices in combo for choice in choices]
    return values, merged_choices


def _expand_with_choices_internal(
    node: GeneratorNode, seed: Optional[int] = None
) -> List[ResultWithChoices]:
    """Recursively expand ``node`` while tracking generator choices.

    Args:
        node: Node to expand.
        seed: Random seed.

    Returns:
        List of (expanded_value, choices_list) tuples.
    """
    # Lists: Cartesian product of elements with merged choices.
    if isinstance(node, list):
        return _expand_list_with_choices(node, seed)

    # Scalars carry no choices.
    if not isinstance(node, Mapping):
        return [(node, [])]

    # Pure generator nodes are delegated to their strategy.
    strategy = get_strategy(node)
    if strategy:
        return _expand_strategy_with_choices(node, strategy, seed)

    # Dict mixing _or_ (and its modifiers) with ordinary keys.
    if has_or_keyword(node):
        return _expand_mixed_or_with_choices(node, seed)

    # Ordinary dict: Cartesian product over its values with merged choices.
    return _expand_dict_with_choices(node, seed)


def _expand_strategy_with_choices(
    node: Dict[str, Any], strategy: Any, seed: Optional[int]
) -> List[ResultWithChoices]:
    """Expand a generator node with its strategy and record each selection.

    Nested content is expanded with ``_expand_internal``, so only this node's
    own selection is recorded as a choice.

    Args:
        node: Generator node (_or_, _range_, etc.).
        strategy: The strategy to use for expansion.
        seed: Random seed.

    Returns:
        List of (value, [{keyword: value}]) tuples.
    """
    keyword = _get_generator_keyword(node)
    expanded = strategy.expand(
        node, seed=seed, expand_nested=lambda n: _expand_internal(n, seed)
    )
    return [(value, [{keyword: value}]) for value in expanded]


def _get_generator_keyword(node: Dict[str, Any]) -> str:
    """Return the primary generator keyword present in ``node``.

    Keywords are checked in a fixed priority order; the first one found wins.

    Args:
        node: Generator node.

    Returns:
        The keyword string (e.g. "_or_", "_range_"), or "_unknown_" when none
        of the listed generator keywords is present.
    """
    # These keywords are not in the module-level import list.
    from .keywords import (
        LOG_RANGE_KEYWORD,
        GRID_KEYWORD,
        ZIP_KEYWORD,
        CHAIN_KEYWORD,
        SAMPLE_KEYWORD,
        CARTESIAN_KEYWORD,
    )

    keyword_priority = (
        OR_KEYWORD,
        RANGE_KEYWORD,
        LOG_RANGE_KEYWORD,
        GRID_KEYWORD,
        ZIP_KEYWORD,
        CHAIN_KEYWORD,
        SAMPLE_KEYWORD,
        CARTESIAN_KEYWORD,
    )
    return next((kw for kw in keyword_priority if kw in node), "_unknown_")


def _expand_list_with_choices(
    node: list, seed: Optional[int]
) -> List[ResultWithChoices]:
    """Expand a list with choice tracking.

    Args:
        node: List to expand.
        seed: Random seed.

    Returns:
        List of (list_value, merged_choices) tuples.
    """
    if not node:
        return [([], [])]

    if len(node) == 1:
        # Expand the sole element exactly once (see _expand_list).
        element_results = _expand_with_choices_internal(node[0], seed)
        # An element that already expands to lists is returned as-is.
        if element_results and isinstance(element_results[0][0], list):
            return element_results
        return [([value], list(choices)) for value, choices in element_results]

    expanded_elements = [
        _expand_with_choices_internal(element, seed) for element in node
    ]
    return [_unzip_with_choices(combo) for combo in product(*expanded_elements)]


def _expand_mixed_or_with_choices(
    node: Dict[str, Any], seed: Optional[int]
) -> List[ResultWithChoices]:
    """Expand a mixed OR node with choice tracking.

    Args:
        node: Dict containing _or_ and other keys.
        seed: Random seed.

    Returns:
        List of (dict_value, merged_choices) tuples; base-dict choices come
        before the OR choices.

    Raises:
        ValueError: If an OR choice is not a mapping.
    """
    base, or_node = _split_mixed_or_node(node)

    base_expanded = _expand_dict_with_choices(base, seed)
    choice_expanded = _expand_with_choices_internal(or_node, seed)

    results = []
    for b_val, b_choices in base_expanded:
        for c_val, c_choices in choice_expanded:
            if not isinstance(c_val, Mapping):
                raise ValueError(
                    "Top-level '_or_' choices in a mixed dict must be dicts, "
                    f"not {type(c_val).__name__}. Got: {c_val}"
                )
            results.append(({**b_val, **c_val}, b_choices + c_choices))
    return results


def _expand_dict_with_choices(
    node: Dict[str, Any], seed: Optional[int]
) -> List[ResultWithChoices]:
    """Expand a dict with choice tracking.

    Args:
        node: Dict to expand.
        seed: Random seed.

    Returns:
        List of (dict_value, merged_choices) tuples.
    """
    if not node:
        return [({}, [])]

    keys = list(node)
    value_options = [_expand_value_with_choices(node[key], seed) for key in keys]

    results = []
    for combo in product(*value_options):
        values, merged_choices = _unzip_with_choices(combo)
        results.append((dict(zip(keys, values)), merged_choices))
    return results


def _expand_value_with_choices(
    v: Any, seed: Optional[int]
) -> List[ResultWithChoices]:
    """Expand a value in a dict position with choice tracking.

    Mappings and lists are expanded recursively. A scalar yields
    ``[(v, [])]``, which is what ``_expand_with_choices_internal`` returns
    for it.

    Args:
        v: Value to expand.
        seed: Random seed.

    Returns:
        List of (value, choices) tuples.
    """
    return _expand_with_choices_internal(v, seed)


# =============================================================================
# Counting Functions
# =============================================================================
def count_combinations(node: GeneratorNode) -> int:
    """Calculate total number of combinations without generating them.

    This is more efficient than generating all combinations when you only
    need to know the count.

    Args:
        node: Configuration node to count.

    Returns:
        Number of variants that expand_spec would produce.

    Examples:
        >>> count_combinations({"_or_": ["A", "B", "C"]})
        3
        >>> count_combinations({"_or_": ["A", "B", "C"], "pick": 2})  # C(3, 2)
        3
        >>> count_combinations({"_range_": [1, 10]})
        10
    """
    # The explanatory comment lives on the ``>>>`` line above: doctest compares
    # the output line verbatim, so "3 # C(3,2)" could never match.
    return _count_internal(node)
def _count_internal(node: GeneratorNode) -> int:
    """Recursively count the variants ``node`` expands to.

    Args:
        node: Node to count.

    Returns:
        Number of variants.
    """
    # Lists: product of element counts (an empty list yields one empty result).
    if isinstance(node, list):
        total = 1
        for element in node:
            total *= _count_internal(element)
        return total

    # Scalars expand to exactly one variant.
    if not isinstance(node, Mapping):
        return 1

    # Pure generator nodes are delegated to their strategy.
    strategy = get_strategy(node)
    if strategy:
        return strategy.count(node, count_nested=_count_internal)

    # Dict mixing _or_ (and its modifiers) with ordinary keys.
    if has_or_keyword(node):
        return _count_mixed_or_node(node)

    # Ordinary dict: product over key values.
    return _count_dict(node)


def _count_mixed_or_node(node: Dict[str, Any]) -> int:
    """Count the variants of a dict mixing _or_ with other keys.

    The base part and the OR part are counted independently; every base
    variant is merged with every OR variant, so the counts multiply.

    Args:
        node: Dict containing _or_ and other keys.

    Returns:
        Number of variants.
    """
    or_modifier_keys = (
        "_or_", "size", "count", "pick", "arrange", "then_pick", "then_arrange"
    )
    # Both parts are built by iterating the node itself, so their key order
    # follows the node. The previous version iterated a set of strings, whose
    # order varies with hash randomization.
    base = {k: v for k, v in node.items() if k not in or_modifier_keys}
    or_node = {k: v for k, v in node.items() if k in or_modifier_keys}

    return _count_internal(base) * _count_internal(or_node)


def _count_dict(node: Dict[str, Any]) -> int:
    """Count a regular dict as the product of its per-value counts.

    Args:
        node: Dict to count.

    Returns:
        Number of variants (1 for an empty dict).
    """
    total = 1
    for v in node.values():
        total *= _count_value(v)
    return total


def _count_value(v: Any) -> int:
    """Count value-position combinations.

    Mappings and lists are counted recursively. A scalar counts as 1, which
    is what ``_count_internal`` returns for it.

    Args:
        v: Value to count.

    Returns:
        Number of variants.
    """
    return _count_internal(v)