Source code for nirs4all.pipeline.config._generator.strategies.zip_strategy

"""Zip strategy for parallel iteration.

This module handles _zip_ nodes that iterate over multiple parameter lists
in parallel (like Python's zip function).

Syntax:
    {"_zip_": {"param1": [v1, v2], "param2": [v3, v4]}}

Examples:
    {"_zip_": {"lr": [0.01, 0.1], "batch_size": [16, 32]}}
    -> [{"lr": 0.01, "batch_size": 16}, {"lr": 0.1, "batch_size": 32}]

Unlike _grid_, which generates all combinations, _zip_ pairs values by position
and stops at the shortest list.
"""

from typing import Any, Callable, Dict, FrozenSet, List, Optional

from .base import ExpansionStrategy, GeneratorNode, ExpandedResult
from .registry import register_strategy
from ..keywords import ZIP_KEYWORD, COUNT_KEYWORD, SEED_KEYWORD, PURE_ZIP_KEYS
from ..utils.sampling import sample_with_seed


@register_strategy
class ZipStrategy(ExpansionStrategy):
    """Strategy for handling _zip_ nodes.

    Generates configurations by pairing values at the same index from
    multiple parameter lists (like Python's zip).

    Supported formats:
        - Dict: {"param1": [v1, v2], "param2": [v3, v4]}
        - With count: Limits output to n random samples

    Attributes:
        keywords: PURE_ZIP_KEYS, the only keys a node may contain for this
            strategy to handle it (the _zip_ keyword and its modifiers).
        priority: 28 (between grid and log_range)
    """

    keywords: FrozenSet[str] = PURE_ZIP_KEYS
    priority: int = 28

    @classmethod
    def handles(cls, node: GeneratorNode) -> bool:
        """Check if node is a pure zip node.

        Args:
            node: Candidate node to check.

        Returns:
            True if node is a dict that contains _zip_ and nothing outside
            PURE_ZIP_KEYS; False otherwise (including for non-dict nodes).
        """
        if not isinstance(node, dict):
            return False
        return ZIP_KEYWORD in node and set(node).issubset(PURE_ZIP_KEYS)

    def expand(
        self,
        node: GeneratorNode,
        seed: Optional[int] = None,
        expand_nested: Optional[Callable[..., Any]] = None,
    ) -> ExpandedResult:
        """Expand a zip node to a list of paired parameter values.

        Args:
            node: Zip specification node.
            seed: Fallback seed for sampling when count is used; a seed
                stored on the node itself takes precedence.
            expand_nested: Callback used to expand dict values (nested
                generator nodes) into their list of values. When omitted,
                a dict value is treated as a single plain value.

        Returns:
            List of dicts with paired parameter values. An empty _zip_ dict,
            or one whose shortest value list is empty, yields ``[{}]``.

        Raises:
            KeyError: If the node has no _zip_ key.
            ValueError: If the _zip_ value is not a dict.

        Examples:
            >>> strategy.expand({"_zip_": {"x": [1, 2, 3], "y": ["A", "B", "C"]}})
            [{"x": 1, "y": "A"}, {"x": 2, "y": "B"}, {"x": 3, "y": "C"}]
        """
        zip_spec = node[ZIP_KEYWORD]
        count_limit = node.get(COUNT_KEYWORD)
        node_seed = node.get(SEED_KEYWORD, seed)

        if not isinstance(zip_spec, dict):
            raise ValueError(
                f"_zip_ must be a dict of param: values, got {type(zip_spec).__name__}"
            )

        # Handle empty zip
        if not zip_spec:
            return [{}]

        # Normalize every parameter to a sequence of candidate values:
        # nested generators are expanded, lists are used as-is and scalars
        # become one-element lists.
        expanded_zip = {}
        for key, values in zip_spec.items():
            if expand_nested and isinstance(values, dict):
                expanded_zip[key] = expand_nested(values)
            elif isinstance(values, list):
                expanded_zip[key] = values
            else:
                expanded_zip[key] = [values]

        # zip() stops at the shortest sequence, which is exactly the pairing
        # rule of this strategy.
        keys = list(expanded_zip)
        results = [dict(zip(keys, row)) for row in zip(*expanded_zip.values())]

        # An empty value list leaves nothing to pair; emit one empty config.
        if not results:
            return [{}]

        # Apply count limit if specified
        if count_limit is not None and len(results) > count_limit:
            results = sample_with_seed(results, count_limit, seed=node_seed)

        return results

    def count(
        self,
        node: GeneratorNode,
        count_nested: Optional[Callable[..., int]] = None,
    ) -> int:
        """Count zip pairs without generating them.

        The result mirrors the number of items ``expand`` produces for the
        same node.

        Args:
            node: Zip specification node.
            count_nested: Callback used to count dict values (nested
                generator nodes). When omitted, a dict value counts as a
                single plain value.

        Returns:
            Number of zipped pairs (minimum list length), capped by the
            node's count limit when one is set. Returns 0 when the _zip_
            value is not a dict.

        Raises:
            KeyError: If the node has no _zip_ key.
        """
        zip_spec = node[ZIP_KEYWORD]
        count_limit = node.get(COUNT_KEYWORD)

        if not isinstance(zip_spec, dict):
            return 0

        if not zip_spec:
            return 1

        # Count based on shortest list
        lengths = []
        for values in zip_spec.values():
            if count_nested and isinstance(values, dict):
                lengths.append(count_nested(values))
            elif isinstance(values, list):
                lengths.append(len(values))
            else:
                lengths.append(1)

        total = min(lengths)
        if total == 0:
            # expand() emits a single empty configuration ([{}]) when the
            # shortest list is empty, so that case counts as one.
            total = 1

        # Apply count limit
        if count_limit is not None:
            return min(count_limit, total)

        return total

    def validate(self, node: GeneratorNode) -> List[str]:
        """Validate zip node specification.

        Args:
            node: Zip node to validate.

        Returns:
            List of error messages. Empty if valid.
        """
        errors: List[str] = []

        zip_spec = node.get(ZIP_KEYWORD)
        if zip_spec is None:
            errors.append("Missing _zip_ key")
            return errors

        if not isinstance(zip_spec, dict):
            errors.append(f"_zip_ must be a dict, got {type(zip_spec).__name__}")
            return errors

        # Value lists of different lengths are deliberately not reported:
        # expansion simply stops at the shortest list, so this is at most a
        # warning and never an error.

        # Validate count
        count_limit = node.get(COUNT_KEYWORD)
        if count_limit is not None and not isinstance(count_limit, int):
            errors.append(
                f"count must be an integer, got {type(count_limit).__name__}"
            )

        return errors