# Source code for nirs4all.pipeline.config._generator.presets

"""Preset system for named configuration templates.

This module provides a preset registry for storing and retrieving
named configuration templates. Presets allow users to define reusable
configuration patterns and reference them by name.

Usage:
    # Register a preset
    register_preset("spectral_transforms", {
        "_or_": ["SNV", "MSC", "Detrend"],
        "pick": 2
    })

    # Use preset in configuration
    config = {"transforms": {"_preset_": "spectral_transforms"}}

    # Expand with preset resolution
    expanded = expand_spec(config)  # Automatically resolves presets

Keywords:
    _preset_: Reference to a named preset configuration

Examples:
    # Define preprocessing options
    register_preset("standard_preprocessing", {
        "_or_": [
            {"class": "StandardScaler"},
            {"class": "MinMaxScaler"},
            {"class": "RobustScaler"}
        ]
    })

    # Define model options
    register_preset("regression_models", {
        "_or_": [
            {"class": "PLSRegression", "n_components": {"_range_": [5, 20]}},
            {"class": "RandomForestRegressor", "n_estimators": 100}
        ]
    })

    # Use in pipeline
    pipeline_spec = {
        "preprocessing": {"_preset_": "standard_preprocessing"},
        "model": {"_preset_": "regression_models"}
    }
"""

from copy import deepcopy
from typing import Any, Dict, List, Optional, Set

# Dict key that marks a configuration node as a reference to a named preset.
PRESET_KEYWORD: str = "_preset_"

# Process-wide registry mapping preset name -> info dict
# ({'spec': ..., 'description': ..., 'tags': [...]}), filled by register_preset.
_PRESET_REGISTRY: Dict[str, Any] = {}


def register_preset(
    name: str,
    spec: Any,
    description: Optional[str] = None,
    tags: Optional[List[str]] = None,
    overwrite: bool = False,
) -> None:
    """Register a named preset configuration.

    The spec is deep-copied and the tags are copied into a new list, so later
    mutation of the caller's objects cannot alter the stored preset.

    Args:
        name: Unique name for the preset.
        spec: Configuration specification (dict, list, or scalar).
        description: Optional human-readable description.
        tags: Optional list of tags for categorization.
        overwrite: If True, overwrite existing preset with same name.

    Raises:
        ValueError: If preset name already exists and overwrite=False.

    Examples:
        >>> register_preset("my_models", {"_or_": ["PLS", "RF"]})
        >>> register_preset("my_models", {"_or_": ["SVM"]}, overwrite=True)
    """
    if name in _PRESET_REGISTRY and not overwrite:
        raise ValueError(
            f"Preset '{name}' already exists. Use overwrite=True to replace."
        )
    _PRESET_REGISTRY[name] = {
        'spec': deepcopy(spec),
        'description': description,
        # Copy rather than alias: storing the caller's list directly would let
        # later mutations of it silently change this preset's tags.
        'tags': list(tags) if tags else [],
    }
def unregister_preset(name: str) -> bool:
    """Delete a preset from the registry if it is present.

    Args:
        name: Name of the preset to delete.

    Returns:
        True when an entry was deleted, False when no preset had that name.
    """
    # Private sentinel so a single pop() both removes and reports presence.
    absent = object()
    return _PRESET_REGISTRY.pop(name, absent) is not absent
def get_preset(name: str) -> Any:
    """Look up a preset by name and return an independent copy of its spec.

    Args:
        name: Registered preset name.

    Returns:
        A deep copy of the stored specification, safe for the caller to mutate.

    Raises:
        KeyError: If no preset is registered under ``name``; the message lists
            the available preset names.
    """
    try:
        entry = _PRESET_REGISTRY[name]
    except KeyError:
        raise KeyError(
            f"Preset '{name}' not found. Available: {list_presets()}"
        ) from None
    return deepcopy(entry['spec'])
def get_preset_info(name: str) -> Dict[str, Any]:
    """Return the full registry entry (spec plus metadata) for a preset.

    Args:
        name: Registered preset name.

    Returns:
        A deep copy of the entry dict with keys 'spec', 'description', 'tags'.

    Raises:
        KeyError: If no preset is registered under ``name``.
    """
    # Registry values are always info dicts (never None), so None means absent.
    entry = _PRESET_REGISTRY.get(name)
    if entry is None:
        raise KeyError(f"Preset '{name}' not found.")
    return deepcopy(entry)
def list_presets(tags: Optional[List[str]] = None) -> List[str]:
    """Return the names of registered presets, optionally filtered by tag.

    Args:
        tags: When given, keep only presets carrying at least one of these
            tags. When None, every registered name is returned.

    Returns:
        Preset names in registry insertion order.
    """
    if tags is None:
        return list(_PRESET_REGISTRY)
    wanted = set(tags)
    return [
        preset_name
        for preset_name, meta in _PRESET_REGISTRY.items()
        if wanted & set(meta.get('tags', []))
    ]
def clear_presets() -> int:
    """Empty the preset registry.

    Returns:
        How many presets were registered before clearing.
    """
    removed_names = list(_PRESET_REGISTRY)
    _PRESET_REGISTRY.clear()
    return len(removed_names)
def has_preset(name: str) -> bool:
    """Report whether a preset is registered under ``name``.

    Args:
        name: Preset name to look for.

    Returns:
        True if the registry contains ``name``, otherwise False.
    """
    found = name in _PRESET_REGISTRY
    return found
def is_preset_reference(node: Any) -> bool:
    """Tell whether a configuration node refers to a named preset.

    Args:
        node: Any configuration value.

    Returns:
        True only for dicts that contain the ``_preset_`` key.
    """
    if not isinstance(node, dict):
        return False
    return PRESET_KEYWORD in node
def resolve_preset(node: Dict[str, Any]) -> Any:
    """Replace a single ``{"_preset_": name}`` node with the named spec.

    Only the ``_preset_`` value is consulted; nested references inside the
    returned spec are left untouched.

    Args:
        node: Dict holding the ``_preset_`` key.

    Returns:
        A deep copy of the referenced preset's specification.

    Raises:
        KeyError: If the named preset is not registered.
        ValueError: If the ``_preset_`` value is missing or not a string.
    """
    preset_name = node.get(PRESET_KEYWORD)
    if isinstance(preset_name, str):
        return get_preset(preset_name)
    raise ValueError(
        f"_preset_ must be a string, got {type(preset_name).__name__}"
    )
def resolve_presets_recursive(node: Any, resolved: Optional[Set[str]] = None) -> Any:
    """Recursively resolve all preset references in a configuration.

    Walks dicts and lists, replacing every ``{"_preset_": name}`` node with the
    (recursively resolved) named spec. Cycle detection is path-based: a preset
    may be used by several siblings, but a preset that (directly or
    indirectly) references itself is rejected.

    Args:
        node: Configuration node (dict, list, or scalar).
        resolved: Names of presets already being expanded on the current path
            (for cycle detection). This set is never mutated.

    Returns:
        Node with all preset references resolved.

    Raises:
        ValueError: If a ``_preset_`` value is not a string, or a circular
            preset reference is detected.
        KeyError: If a referenced preset does not exist.
    """
    if resolved is None:
        resolved = set()

    # Handle preset reference
    if is_preset_reference(node):
        preset_name = node[PRESET_KEYWORD]
        # Validate before the membership test: an unhashable value (e.g. a
        # list) would otherwise raise TypeError from ``in`` instead of the
        # documented ValueError.
        if not isinstance(preset_name, str):
            raise ValueError(
                f"_preset_ must be a string, got {type(preset_name).__name__}"
            )
        if preset_name in resolved:
            raise ValueError(
                f"Circular preset reference detected: {preset_name}"
            )
        preset_spec = resolve_preset(node)
        # Build a new set for this path instead of mutating the caller's.
        return resolve_presets_recursive(preset_spec, resolved | {preset_name})

    # Children share ``resolved`` safely because it is never mutated.
    if isinstance(node, dict):
        return {k: resolve_presets_recursive(v, resolved) for k, v in node.items()}

    if isinstance(node, list):
        return [resolve_presets_recursive(item, resolved) for item in node]

    # Scalar - return as-is
    return node
def export_presets() -> Dict[str, Any]:
    """Snapshot the whole registry for serialization.

    Returns:
        A deep copy mapping each preset name to its info dict
        ('spec', 'description', 'tags'); mutating it does not affect the
        registry.
    """
    snapshot = deepcopy(_PRESET_REGISTRY)
    return snapshot
def import_presets(
    presets: Dict[str, Any],
    overwrite: bool = False
) -> int:
    """Import presets from a dict.

    Values that are dicts containing a 'spec' key are treated as full info
    dicts (the shape produced by ``export_presets``); any other value is
    registered directly as the spec.

    When ``overwrite`` is False, all names are checked for conflicts before
    anything is registered, so a conflict leaves the registry unchanged
    instead of half-imported.

    Args:
        presets: Dict mapping preset names to info dicts or specs.
        overwrite: If True, overwrite existing presets.

    Returns:
        Number of presets imported.

    Raises:
        ValueError: If ``overwrite`` is False and any name is already
            registered (reported for the first conflicting name).
    """
    if not overwrite:
        # Fail up front so the import is all-or-nothing on name conflicts.
        for name in presets:
            if name in _PRESET_REGISTRY:
                raise ValueError(
                    f"Preset '{name}' already exists. Use overwrite=True to replace."
                )

    count = 0
    for name, value in presets.items():
        if isinstance(value, dict) and 'spec' in value:
            # Full info dict
            register_preset(
                name,
                value['spec'],
                description=value.get('description'),
                tags=value.get('tags'),
                overwrite=overwrite,
            )
        else:
            # Direct spec
            register_preset(name, value, overwrite=overwrite)
        count += 1
    return count
# =============================================================================
# Built-in Presets
# =============================================================================
def register_builtin_presets() -> None:
    """Install the bundled preset templates into the registry.

    Every built-in is registered with ``overwrite=True``, so calling this
    replaces any existing presets that use the same names.
    """
    # (name, spec, description, tags), registered in this order.
    builtin_table = (
        (
            "standard_scalers",
            {
                "_or_": [
                    {"class": "sklearn.preprocessing.StandardScaler"},
                    {"class": "sklearn.preprocessing.MinMaxScaler"},
                    {"class": "sklearn.preprocessing.RobustScaler"},
                    None,  # No scaling option
                ]
            },
            "Common sklearn scalers including no-scaling option",
            ["preprocessing", "sklearn"],
        ),
        (
            "pls_components",
            {"_range_": [2, 20]},
            "Common range for PLS n_components",
            ["hyperparameter", "pls"],
        ),
        (
            "learning_rates",
            {"_log_range_": [0.0001, 0.1, 10]},
            "Logarithmic range of learning rates",
            ["hyperparameter", "deep_learning"],
        ),
    )
    for preset_name, preset_spec, summary, labels in builtin_table:
        register_preset(
            preset_name,
            preset_spec,
            description=summary,
            tags=labels,
            overwrite=True,
        )