Source code for nirs4all.pipeline.config._generator.strategies.grid_strategy

"""Grid strategy for grid-search style parameter expansion.

This module handles _grid_ nodes that generate Cartesian products of
parameter spaces - useful for hyperparameter grid search.

Syntax:
    {"_grid_": {"param1": [...], "param2": [...]}}

Examples:
    {"_grid_": {"lr": [0.01, 0.1], "batch_size": [16, 32]}}
    -> [{"lr": 0.01, "batch_size": 16}, {"lr": 0.01, "batch_size": 32},
        {"lr": 0.1, "batch_size": 16}, {"lr": 0.1, "batch_size": 32}]
"""

from itertools import product
from typing import Any, Dict, FrozenSet, List, Optional

from .base import ExpansionStrategy, GeneratorNode, ExpandedResult
from .registry import register_strategy
from ..keywords import GRID_KEYWORD, COUNT_KEYWORD, SEED_KEYWORD, PURE_GRID_KEYS
from ..utils.sampling import sample_with_seed


@register_strategy
class GridStrategy(ExpansionStrategy):
    """Strategy for handling _grid_ nodes.

    Generates all combinations (Cartesian product) of parameter values.
    Similar to sklearn's ParameterGrid.

    Supported formats:
        - Dict: {"param1": [v1, v2], "param2": [v3, v4]}
        - With count: Limits output to at most n sampled combinations

    Attributes:
        keywords: Keys a pure grid node may contain (PURE_GRID_KEYS).
        priority: 30 (checked early due to specific structure)
    """

    keywords: FrozenSet[str] = PURE_GRID_KEYS
    priority: int = 30  # High priority

    @classmethod
    def handles(cls, node: GeneratorNode) -> bool:
        """Check if node is a pure grid node.

        Args:
            node: Node to check.

        Returns:
            True if node is a dict that contains _grid_ and whose keys are
            all members of PURE_GRID_KEYS.
        """
        if not isinstance(node, dict):
            return False
        return GRID_KEYWORD in node and set(node).issubset(PURE_GRID_KEYS)

    @staticmethod
    def _expand_values(values: Any, expand_nested: Optional[callable]) -> Any:
        """Normalise one grid parameter's spec into its candidate values.

        Args:
            values: Raw value of a single grid parameter.
            expand_nested: Callback to expand nested generator nodes, or None.

        Returns:
            The candidate values for that parameter.
        """
        if expand_nested and isinstance(values, dict):
            # A dict value is handed to the callback as a nested node.
            return expand_nested(values)

        if isinstance(values, list):
            # Lists are taken as the parameter values directly unless they
            # contain dicts that the callback can expand.
            if not expand_nested or not any(isinstance(v, dict) for v in values):
                return values
            expanded: List[Any] = []
            for item in values:
                if isinstance(item, dict):
                    expanded.extend(expand_nested(item))
                else:
                    expanded.append(item)
            return expanded

        # Anything else (scalars, or dicts when no callback is available)
        # is treated as a single candidate value.
        return [values]

    @staticmethod
    def _count_values(values: Any, count_nested: Optional[callable]) -> int:
        """Count one grid parameter's candidate values without expanding them.

        Follows the same branching as ``_expand_values`` so the number
        reported by ``count`` agrees with what ``expand`` produces.

        Args:
            values: Raw value of a single grid parameter.
            count_nested: Callback to count nested nodes, or None.

        Returns:
            Number of candidate values for that parameter.
        """
        if isinstance(values, list):
            if not count_nested:
                return len(values)
            # Dict items inside a list are expanded by expand(), so each one
            # contributes its nested count rather than 1.
            return sum(
                count_nested(item) if isinstance(item, dict) else 1
                for item in values
            )
        if count_nested and isinstance(values, dict):
            return count_nested(values)
        return 1

    def expand(
        self,
        node: GeneratorNode,
        seed: Optional[int] = None,
        expand_nested: Optional[callable] = None
    ) -> ExpandedResult:
        """Expand a grid node to list of parameter combinations.

        Args:
            node: Grid specification node.
            seed: Fallback seed for sampling when count is used; a seed key
                on the node itself takes precedence.
            expand_nested: Callback to expand nested generator nodes.

        Returns:
            List of dicts with all parameter combinations. When the node has
            a count smaller than the number of combinations, the list is
            reduced to that many entries via ``sample_with_seed``.

        Raises:
            ValueError: If the _grid_ value is not a dict.

        Examples:
            >>> strategy.expand({"_grid_": {"x": [1, 2], "y": ["A", "B"]}})
            [{"x": 1, "y": "A"}, {"x": 1, "y": "B"},
             {"x": 2, "y": "A"}, {"x": 2, "y": "B"}]
        """
        grid_spec = node[GRID_KEYWORD]
        count = node.get(COUNT_KEYWORD)
        node_seed = node.get(SEED_KEYWORD, seed)

        if not isinstance(grid_spec, dict):
            raise ValueError(
                f"_grid_ must be a dict of param: values, got {type(grid_spec).__name__}"
            )

        # Handle empty grid: a single empty combination
        if not grid_spec:
            return [{}]

        # Expand nested generators in values first
        expanded_grid = {
            key: self._expand_values(values, expand_nested)
            for key, values in grid_spec.items()
        }

        # Generate Cartesian product
        keys = list(expanded_grid)
        results = [
            dict(zip(keys, combo))
            for combo in product(*expanded_grid.values())
        ]

        # Apply count limit if specified
        if count is not None and len(results) > count:
            results = sample_with_seed(results, count, seed=node_seed)

        return results

    def count(self, node: GeneratorNode, count_nested: Optional[callable] = None) -> int:
        """Count grid combinations without generating them.

        Args:
            node: Grid specification node.
            count_nested: Callback to count nested nodes.

        Returns:
            Number of parameter combinations, capped by the node's count
            when one is given. Returns 0 if the _grid_ value is not a dict.
        """
        grid_spec = node[GRID_KEYWORD]
        count_limit = node.get(COUNT_KEYWORD)

        if not isinstance(grid_spec, dict):
            return 0

        if not grid_spec:
            return 1  # Empty grid produces one empty result

        # Count total combinations
        total = 1
        for values in grid_spec.values():
            total *= self._count_values(values, count_nested)

        # Apply count limit
        if count_limit is not None:
            return min(count_limit, total)
        return total

    def validate(self, node: GeneratorNode) -> List[str]:
        """Validate grid node specification.

        Args:
            node: Grid node to validate.

        Returns:
            List of error messages. Empty if valid.
        """
        errors: List[str] = []

        grid_spec = node.get(GRID_KEYWORD)
        if grid_spec is None:
            errors.append("Missing _grid_ key")
            return errors

        if not isinstance(grid_spec, dict):
            errors.append(
                f"_grid_ must be a dict, got {type(grid_spec).__name__}"
            )
            return errors

        # Validate each parameter name. Values are not checked: lists, dicts
        # and scalars (treated as single-item lists) are all accepted.
        for key in grid_spec:
            if not isinstance(key, str):
                errors.append(f"Grid keys must be strings, got {type(key).__name__}")

        # Validate count
        count = node.get(COUNT_KEYWORD)
        if count is not None:
            if not isinstance(count, int):
                errors.append(f"count must be an integer, got {type(count).__name__}")
            elif count < 0:
                # A negative limit would yield a negative count() and an
                # invalid sample size in expand().
                errors.append(f"count must be non-negative, got {count}")

        return errors