Source code for nirs4all.pipeline.config._generator.strategies.sample_strategy

"""Sample strategy for statistical sampling.

This module handles _sample_ nodes that generate values using statistical
distributions - useful for random hyperparameter search.

Syntax:
    {"_sample_": {"distribution": "uniform", "from": 0, "to": 1, "num": 10}}
    {"_sample_": {"distribution": "log_uniform", "from": 0.001, "to": 1, "num": 10}}
    {"_sample_": {"distribution": "normal", "mean": 0, "std": 1, "num": 10}}
    {"_sample_": {"distribution": "choice", "values": [...], "num": 5}}

Examples:
    {"_sample_": {"distribution": "uniform", "from": 0.1, "to": 1.0, "num": 5}}
    -> [0.234, 0.567, 0.891, 0.123, 0.456]  (5 random uniform values)
"""

import math
import random
from typing import Any, Dict, FrozenSet, List, Optional

from .base import ExpansionStrategy, GeneratorNode, ExpandedResult
from .registry import register_strategy
from ..keywords import SAMPLE_KEYWORD, COUNT_KEYWORD, SEED_KEYWORD, PURE_SAMPLE_KEYS


[docs] @register_strategy class SampleStrategy(ExpansionStrategy): """Strategy for handling _sample_ nodes. Generates values using statistical sampling from various distributions. Supports uniform, log-uniform, normal, and choice distributions. Supported distributions: - uniform: Uniform distribution between from and to - log_uniform: Log-uniform distribution (common for learning rates) - normal/gaussian: Normal distribution with mean and std - choice: Random selection from a list of values Attributes: keywords: {_sample_, count, seed} priority: 24 (between log_range and range) """ keywords: FrozenSet[str] = PURE_SAMPLE_KEYS priority: int = 24 SUPPORTED_DISTRIBUTIONS = {"uniform", "log_uniform", "normal", "gaussian", "choice"}
[docs] @classmethod def handles(cls, node: GeneratorNode) -> bool: """Check if node is a pure sample node. Args: node: Dictionary node to check. Returns: True if node contains _sample_ and only sample-related keys. """ if not isinstance(node, dict): return False return SAMPLE_KEYWORD in node and set(node.keys()).issubset(PURE_SAMPLE_KEYS)
[docs] def expand( self, node: GeneratorNode, seed: Optional[int] = None, expand_nested: Optional[callable] = None ) -> ExpandedResult: """Expand a sample node to list of sampled values. Args: node: Sample specification node. seed: Optional seed for reproducible sampling. expand_nested: Not typically used for sample nodes. Returns: List of sampled values. Examples: >>> strategy.expand({"_sample_": {"distribution": "uniform", "from": 0, "to": 1, "num": 3}}, seed=42) [0.6394267984578837, 0.025010755222666936, 0.27502931836911926] """ sample_spec = node[SAMPLE_KEYWORD] count_limit = node.get(COUNT_KEYWORD) node_seed = node.get(SEED_KEYWORD, seed) if not isinstance(sample_spec, dict): raise ValueError( f"_sample_ must be a dict, got {type(sample_spec).__name__}" ) # Initialize RNG with seed if node_seed is not None: rng = random.Random(node_seed) else: rng = random distribution = sample_spec.get("distribution", "uniform") num = sample_spec.get("num", 10) # Apply count limit to num if specified if count_limit is not None: num = min(num, count_limit) # Generate samples based on distribution if distribution == "uniform": results = self._sample_uniform(sample_spec, num, rng) elif distribution == "log_uniform": results = self._sample_log_uniform(sample_spec, num, rng) elif distribution in ("normal", "gaussian"): results = self._sample_normal(sample_spec, num, rng) elif distribution == "choice": results = self._sample_choice(sample_spec, num, rng) else: raise ValueError( f"Unknown distribution: {distribution}. " f"Supported: {self.SUPPORTED_DISTRIBUTIONS}" ) return results
def _sample_uniform( self, spec: Dict[str, Any], num: int, rng: random.Random ) -> List[float]: """Sample from uniform distribution.""" low = spec.get("from", 0) high = spec.get("to", 1) return [round(rng.uniform(low, high), 10) for _ in range(num)] def _sample_log_uniform( self, spec: Dict[str, Any], num: int, rng: random.Random ) -> List[float]: """Sample from log-uniform distribution.""" low = spec.get("from", 0.001) high = spec.get("to", 1) if low <= 0 or high <= 0: raise ValueError("Log-uniform requires positive from and to values") log_low = math.log(low) log_high = math.log(high) results = [] for _ in range(num): log_val = rng.uniform(log_low, log_high) results.append(round(math.exp(log_val), 10)) return results def _sample_normal( self, spec: Dict[str, Any], num: int, rng: random.Random ) -> List[float]: """Sample from normal distribution.""" mean = spec.get("mean", 0) std = spec.get("std", 1) return [round(rng.gauss(mean, std), 10) for _ in range(num)] def _sample_choice( self, spec: Dict[str, Any], num: int, rng: random.Random ) -> List[Any]: """Sample from a list of choices.""" values = spec.get("values", []) if not values: return [] return [rng.choice(values) for _ in range(num)]
[docs] def count(self, node: GeneratorNode, count_nested: Optional[callable] = None) -> int: """Count sample results (simply returns num). Args: node: Sample specification node. count_nested: Not used. Returns: Number of samples to generate. """ sample_spec = node[SAMPLE_KEYWORD] count_limit = node.get(COUNT_KEYWORD) if not isinstance(sample_spec, dict): return 0 num = sample_spec.get("num", 10) # Apply count limit if count_limit is not None: return min(count_limit, num) return num
[docs] def validate(self, node: GeneratorNode) -> List[str]: """Validate sample node specification. Args: node: Sample node to validate. Returns: List of error messages. Empty if valid. """ errors = [] sample_spec = node.get(SAMPLE_KEYWORD) if sample_spec is None: errors.append("Missing _sample_ key") return errors if not isinstance(sample_spec, dict): errors.append(f"_sample_ must be a dict, got {type(sample_spec).__name__}") return errors distribution = sample_spec.get("distribution", "uniform") if distribution not in self.SUPPORTED_DISTRIBUTIONS: errors.append( f"Unknown distribution: {distribution}. " f"Supported: {self.SUPPORTED_DISTRIBUTIONS}" ) # Validate distribution-specific parameters if distribution == "uniform": for key in ("from", "to"): if key in sample_spec and not isinstance(sample_spec[key], (int, float)): errors.append(f"uniform '{key}' must be numeric") elif distribution == "log_uniform": for key in ("from", "to"): if key in sample_spec: val = sample_spec[key] if not isinstance(val, (int, float)): errors.append(f"log_uniform '{key}' must be numeric") elif val <= 0: errors.append(f"log_uniform '{key}' must be positive") elif distribution in ("normal", "gaussian"): for key in ("mean", "std"): if key in sample_spec and not isinstance(sample_spec[key], (int, float)): errors.append(f"normal '{key}' must be numeric") if "std" in sample_spec and sample_spec["std"] < 0: errors.append("normal 'std' must be non-negative") elif distribution == "choice": values = sample_spec.get("values") if values is not None and not isinstance(values, list): errors.append("choice 'values' must be a list") # Validate num num = sample_spec.get("num") if num is not None and (not isinstance(num, int) or num < 0): errors.append("num must be a non-negative integer") # Validate count count = node.get(COUNT_KEYWORD) if count is not None and not isinstance(count, int): errors.append(f"count must be an integer, got {type(count).__name__}") return errors