# Source code for nirs4all.pipeline.config.pipeline_config

"""
PipelineConfigs.py
"""

import json
import logging
from pathlib import Path
from typing import List, Any, Dict, Union
import yaml

from .component_serialization import serialize_component
from .generator import expand_spec, expand_spec_with_choices, count_combinations


class _ShortNameFormatter(logging.Formatter):
    """Formatter that strips the 'nirs4all.' prefix from logger names.

    The shortened name is only used while the record is being rendered; the
    record's ``name`` attribute is restored afterwards so that other handlers
    or filters receiving the same record still see the full logger name.
    """

    _PREFIX = "nirs4all."

    def format(self, record: logging.LogRecord) -> str:
        """Render *record* with its logger name shortened.

        Args:
            record: The log record to format.

        Returns:
            The formatted message, with a leading ``nirs4all.`` removed from
            the logger name if present.
        """
        full_name = record.name
        if full_name.startswith(self._PREFIX):
            record.name = full_name[len(self._PREFIX):]
        try:
            return super().format(record)
        finally:
            # LogRecord objects are shared between handlers: undo the
            # temporary rename so this formatter has no lasting side effect.
            record.name = full_name


# Configure logging with simplified module names
# Install a stream handler on the root logger whose formatter prints
# shortened logger names, and make INFO the root logging threshold.
# NOTE: this runs at import time and therefore affects the whole process.
_handler = logging.StreamHandler()
_handler.setFormatter(
    _ShortNameFormatter(fmt="%(levelname)s: %(name)s: %(message)s")
)
logging.getLogger().addHandler(_handler)
logging.getLogger().setLevel(logging.INFO)

# Module-level logger for this file.
logger = logging.getLogger(__name__)

class PipelineConfigs:
    """Hold the configuration(s) of a pipeline.

    A definition (list of steps, dict with a ``'pipeline'`` key, file path or
    JSON/YAML string) is loaded, normalised, serialized and, when it contains
    generator keys, expanded into several concrete configurations.

    Attributes set by ``__init__``:
        description: Free-text description passed by the caller.
        steps: List of concrete configurations (always a list, even when the
            definition yields a single configuration).
        has_configurations: True when generator expansion produced more than
            one configuration.
        generator_choices: For each configuration, the generator choices
            reported by ``expand_spec_with_choices`` (``[[]]`` when nothing
            was expanded).
        names: One name per configuration, ``"<name>_<6 hash chars>"``.
    """

    def __init__(
        self,
        definition: Union[Dict, List[Any], str],
        name: str = "",
        description: str = "No description provided",
        max_generation_count: int = 10000,
    ):
        """Initialize the pipeline configuration.

        Args:
            definition: List of steps, dict with a ``'pipeline'`` key, path to
                a JSON/YAML file, or a JSON/YAML string.
            name: Base name of the configurations; ``"config"`` when empty.
            description: Human-readable description.
            max_generation_count: Upper bound on the number of configurations
                that generator expansion may produce.

        Raises:
            ValueError: If expansion would exceed ``max_generation_count`` or
                the definition is invalid.
            TypeError: If the definition has an unsupported type.
            FileNotFoundError: If a referenced configuration file is missing.
        """
        # Parse / format / validate the configuration.
        self.description = description
        self.steps = self._load_steps(definition)
        self.steps = self._preprocess_steps(self.steps)
        self.steps = serialize_component(self.steps)

        # Generation.
        self.has_configurations = False
        # Generator choices made for each expanded pipeline.
        self.generator_choices: List[List[Dict[str, Any]]] = []
        was_expanded = False
        if self._has_gen_keys(self.steps):
            count = count_combinations(self.steps)
            if count > max_generation_count:
                raise ValueError(f"Configuration expansion would generate {count} configurations, exceeding the limit of {max_generation_count}. Please simplify your configuration.")
            # Always expand generator syntax, even if count == 1: the
            # generator keys must be replaced with actual values.
            if count >= 1:
                self.has_configurations = count > 1
                # expand_spec_with_choices also reports which choices led to
                # each configuration.
                expanded_with_choices = expand_spec_with_choices(self.steps)
                self.steps = [config for config, _ in expanded_with_choices]
                self.generator_choices = [choices for _, choices in expanded_with_choices]
                was_expanded = True

        if not was_expanded:
            self.steps = [self.steps]  # Wrap the single configuration in a list
            self.generator_choices = [[]]  # No choices for a single config

        # Name and hash.
        if name == "":
            name = "config"
        self.names = [name + "_" + self.get_hash(steps)[0:6] for steps in self.steps]

    @staticmethod
    def _has_gen_keys(obj: Any, skip_branch: bool = True) -> bool:
        """Recursively check whether the configuration contains generator keys.

        Only the ``"_or_"`` and ``"_range_"`` keys are recognised here.

        Args:
            obj: Configuration object to check.
            skip_branch: If True, the content of a ``'branch'`` key is not
                inspected (per the original note, those are handled by
                BranchController at runtime).

        Returns:
            True if generator keys are found at the pipeline level.
        """
        if isinstance(obj, list):
            return any(PipelineConfigs._has_gen_keys(item, skip_branch) for item in obj)
        if not isinstance(obj, dict):
            return False

        if skip_branch and "branch" in obj:
            # Inspect the sibling keys but not the branch content itself.
            return any(
                PipelineConfigs._has_gen_keys(value, skip_branch)
                for key, value in obj.items()
                if key != "branch"
            )
        if "_or_" in obj or "_range_" in obj:
            return True
        return any(PipelineConfigs._has_gen_keys(value, skip_branch) for value in obj.values())

    @staticmethod
    def _preprocess_steps(steps: Any) -> Any:
        """Merge ``X``/``X_params`` pairs and normalise bare classes.

        * ``{"X": value, "X_params": params}`` becomes
          ``{"X": {"class": value, "params": params}}``.
        * A class object stored under any key other than ``"class"``,
          ``"params"`` or ``*_params`` becomes ``{"class": <class>}``.

        Lists and dicts are processed recursively; other values are returned
        unchanged. The input is not mutated (dicts are shallow-copied).
        Non-string dict keys are left untouched.
        """
        if isinstance(steps, list):
            return [PipelineConfigs._preprocess_steps(step) for step in steps]
        if not isinstance(steps, dict):
            return steps

        import inspect

        suffix = "_params"
        result = steps.copy()

        # Keys may be non-strings (e.g. integers coming from YAML), hence the
        # isinstance guard before calling str.endswith.
        params_keys = [k for k in result if isinstance(k, str) and k.endswith(suffix)]
        for params_key in params_keys:
            base_key = params_key[:-len(suffix)]
            if base_key in result:
                # Convert to the standard {"class": ..., "params": ...} format
                # and drop the now-redundant *_params entry.
                params_value = result.pop(params_key)
                result[base_key] = {"class": result[base_key], "params": params_value}

        # Normalise bare classes, e.g. {"y_processing": MinMaxScaler}, so they
        # are serialized consistently. "class"/"params"/"*_params" keep their
        # special meaning and are left alone.
        for key, value in list(result.items()):
            if not inspect.isclass(value):
                continue
            if key in ("class", "params"):
                continue
            if isinstance(key, str) and key.endswith(suffix):
                continue
            result[key] = {"class": value}

        return {key: PipelineConfigs._preprocess_steps(value) for key, value in result.items()}

    @staticmethod
    def _load_steps(definition: Union[Dict, List[Any], str]) -> List[Any]:
        """Load steps from a definition which can be a dict, list, or string.

        Raises:
            ValueError: If a dict definition has no ``'pipeline'`` key.
            TypeError: If the definition is not a list, dict or string.
        """
        if isinstance(definition, str):
            return PipelineConfigs._load_str_steps(definition)
        if isinstance(definition, list):
            return definition
        if isinstance(definition, dict):
            if "pipeline" in definition:
                return definition["pipeline"]
            raise ValueError("Invalid pipeline definition format. Expected a list, dict with 'pipeline' key, or string.")
        raise TypeError("Pipeline definition must be a list, dict, or string.")

    @staticmethod
    def _load_str_steps(definition: str) -> List[Any]:
        """Load steps from a file path or from an inline JSON/YAML string.

        A string ending in ``.json``, ``.yaml`` or ``.yml`` is treated as a
        file path; anything else is parsed as JSON and, failing that, as YAML.
        The parsed content then goes through ``_load_steps`` in every case, so
        inline YAML is validated exactly like inline JSON and file content.

        Args:
            definition: Path to a JSON/YAML file, or a JSON/YAML string.

        Returns:
            List of pipeline steps.

        Raises:
            FileNotFoundError: If the config file doesn't exist.
            ValueError: If the content has invalid JSON/YAML syntax, is empty,
                or parses to a plain string.
        """
        if definition.endswith(('.json', '.yaml', '.yml')):
            if not Path(definition).is_file():
                raise FileNotFoundError(
                    f"Configuration file does not exist: {definition}\n"
                    f"Please check the file path and try again."
                )
            if definition.endswith('.json'):
                pipeline_definition = PipelineConfigs._read_json_file(definition)
            else:
                pipeline_definition = PipelineConfigs._read_yaml_file(definition)
        else:
            pipeline_definition = PipelineConfigs._parse_inline_definition(definition)

        if not pipeline_definition:
            raise ValueError(
                f"Pipeline definition is empty or invalid.\n"
                f"The configuration file must contain a 'pipeline' key with a list of steps."
            )
        if isinstance(pipeline_definition, str):
            # A parsed plain string would be sent straight back to this
            # function by _load_steps; reject it explicitly instead.
            raise ValueError(
                "Invalid pipeline definition: expected a list of steps or a "
                "mapping with a 'pipeline' key, but the definition parsed to "
                "a plain string."
            )
        return PipelineConfigs._load_steps(pipeline_definition)

    @staticmethod
    def _read_json_file(path: str) -> Any:
        """Parse the JSON file at *path*, raising ValueError on syntax errors."""
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except json.JSONDecodeError as exc:
            # Provide a detailed error message with the line number.
            raise ValueError(
                f"Invalid JSON in {path}\n"
                f"Error at line {exc.lineno}, column {exc.colno}:\n"
                f" {exc.msg}\n\n"
                f"Common JSON issues:\n"
                f" - Missing or extra commas\n"
                f" - Unquoted strings\n"
                f" - Trailing commas (not allowed in JSON)\n"
                f" - Single quotes instead of double quotes"
            ) from exc

    @staticmethod
    def _read_yaml_file(path: str) -> Any:
        """Parse the YAML file at *path*, raising ValueError on syntax errors."""
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f)
        except yaml.YAMLError as exc:
            # Extract the error position from the YAML error if available.
            mark = getattr(exc, 'problem_mark', None)
            if mark:
                line_num = mark.line + 1
                col_num = mark.column + 1
                problem = getattr(exc, 'problem', 'Unknown error')
                context = getattr(exc, 'context', '')
                error_details = f" {problem}"
                if context:
                    error_details += f" ({context})"
                raise ValueError(
                    f"Invalid YAML in {path}\n"
                    f"Error at line {line_num}, column {col_num}:\n"
                    f"{error_details}\n\n"
                    f"Common YAML issues:\n"
                    f" - Incorrect indentation (use spaces, not tabs)\n"
                    f" - Missing colon after key names\n"
                    f" - Unescaped special characters\n"
                    f" - Mixing tabs and spaces"
                ) from exc
            raise ValueError(
                f"Invalid YAML in {path}:\n"
                f" {exc}\n\n"
                f"Please check your YAML syntax."
            ) from exc

    @staticmethod
    def _parse_inline_definition(definition: str) -> Any:
        """Parse an inline definition as JSON, falling back to YAML."""
        try:
            return json.loads(definition)
        except json.JSONDecodeError as exc:
            try:
                return yaml.safe_load(definition)
            except yaml.YAMLError as exc2:
                raise ValueError(
                    "Invalid pipeline definition string.\n"
                    "Must be a valid JSON or YAML format.\n\n"
                    f"JSON error: {exc.msg} at line {exc.lineno}\n"
                    f"YAML error: {exc2}"
                ) from exc2

    @staticmethod
    def get_hash(steps) -> str:
        """Return an 8-character hash of a pipeline configuration.

        The configuration is dumped to JSON with sorted keys, so dict key
        order does not influence the result. ``steps`` must be fully
        JSON-serializable; ``json.dumps`` raises ``TypeError`` otherwise.
        MD5 is used as a fingerprint only, not for security.
        """
        import hashlib

        serializable = json.dumps(steps, sort_keys=True).encode('utf-8')
        return hashlib.md5(serializable).hexdigest()[0:8]

    @staticmethod
    def _short_name(value: Any) -> str:
        """Return the last dotted component of a name, or a class' ``__name__``."""
        if isinstance(value, str):
            return value.split('.')[-1]
        return getattr(value, "__name__", type(value).__name__)

    @staticmethod
    def _format_params(params: Any) -> str:
        """Render a params mapping as ``"k1=v1, k2=v2"``; tolerate non-dicts."""
        if isinstance(params, dict):
            return ", ".join(f"{k}={v}" for k, v in params.items())
        if params is None:
            return ""
        return str(params)

    @staticmethod
    def _get_step_description(step: Any) -> str:
        """Get a human-readable description of a step.

        Returns:
            * ``"No operation"`` for ``None``;
            * the single key for a one-key dict;
            * ``"Name"`` / ``"Name(params)"`` for a dict with a ``"class"`` key;
            * ``"train ..."`` / ``"(finetune) ..."`` for a dict with a
              ``"model"`` key (``"(finetune)"`` when ``"finetune_params"`` is
              present), using the custom ``"name"`` when given;
            * ``"Dict with N keys"`` for any other dict;
            * ``"Sub-pipeline (N steps)"`` for a list;
            * the string itself for a string;
            * the type name otherwise.
        """
        if step is None:
            return "No operation"
        if isinstance(step, list):
            return f"Sub-pipeline ({len(step)} steps)"
        if isinstance(step, str):
            return step
        if not isinstance(step, dict):
            return str(type(step).__name__)

        if len(step) == 1:
            return f"{next(iter(step))}"

        if "class" in step:
            class_name = PipelineConfigs._short_name(step["class"])
            if "params" in step:
                return f"{class_name}({PipelineConfigs._format_params(step['params'])})"
            return class_name

        if "model" in step:
            actions = "(finetune)" if "finetune_params" in step else "train"
            # A custom model name takes precedence over the class name.
            if "name" in step:
                return f"{actions} {step['name']}"

            model = step["model"]
            key = "unknown_model"
            params_str = ""
            # A non-dict model (e.g. a plain string) cannot be introspected.
            if isinstance(model, dict):
                if "class" in model:
                    key = PipelineConfigs._short_name(model["class"])
                elif "function" in model:
                    key = PipelineConfigs._short_name(model["function"])
                if "params" in model:
                    params_str = PipelineConfigs._format_params(model["params"])
            return f"{actions} {key}({params_str})"

        return f"Dict with {len(step)} keys"

    @classmethod
    def value_of(cls, obj, key) -> List[str]:
        """Recursively collect all values of a key in a nested object.

        Dicts and lists are traversed depth-first. Every value stored under
        ``key`` is converted with ``str`` and, if it is itself a container,
        also searched for nested occurrences.

        Returns:
            The list of stringified values, in traversal order.
        """
        values: List[str] = []
        if isinstance(obj, dict):
            for k, v in obj.items():
                if k == key:
                    values.append(str(v))
                values.extend(cls.value_of(v, key))
        elif isinstance(obj, list):
            for item in obj:
                values.extend(cls.value_of(item, key))
        return values

    @classmethod
    def value_of_str(cls, obj, key) -> str:
        """Return all values found for ``key`` as one string joined by ``", "``."""
        return ", ".join(cls.value_of(obj, key))