Source code for nirs4all.pipeline.config.component_serialization

import inspect
from enum import Enum
from typing import Any, get_type_hints, get_origin, get_args, Annotated, Union
import importlib
import json

# Simple alias dictionary for common transformations
build_aliases = {
    # Add common aliases here if needed
}


def _is_meta_estimator(obj) -> bool:
    """Check if object is a stacking/voting meta-estimator.

    Args:
        obj: Object to check.

    Returns:
        True if object is a meta-estimator (has estimators and final_estimator).
    """
    try:
        from sklearn.ensemble import (
            StackingRegressor, StackingClassifier,
            VotingRegressor, VotingClassifier
        )
        meta_types = (StackingRegressor, StackingClassifier,
                      VotingRegressor, VotingClassifier)
        return isinstance(obj, meta_types)
    except ImportError:
        return False


def _is_meta_estimator_class(cls) -> bool:
    """Check if class is a meta-estimator type.

    Args:
        cls: Class to check.

    Returns:
        True if class is a meta-estimator type.
    """
    try:
        from sklearn.ensemble import (
            StackingRegressor, StackingClassifier,
            VotingRegressor, VotingClassifier
        )
        return cls in (StackingRegressor, StackingClassifier,
                       VotingRegressor, VotingClassifier)
    except ImportError:
        return False


def _serialize_meta_estimator(obj) -> dict:
    """Serialize a stacking/voting meta-estimator with nested estimators.

    Handles StackingRegressor, StackingClassifier, VotingRegressor, VotingClassifier
    by recursively serializing their base estimators and final_estimator.

    Args:
        obj: Meta-estimator instance to serialize.

    Returns:
        Dictionary with class path, estimators, final_estimator, and other params.
    """
    result = {
        "class": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}",
        "params": {}
    }

    # Serialize each base estimator as (name, serialized_estimator) pairs
    if hasattr(obj, 'estimators') and obj.estimators is not None:
        result["params"]["estimators"] = [
            [name, serialize_component(est)]
            for name, est in obj.estimators
        ]

    # Serialize final_estimator (for stacking models)
    if hasattr(obj, 'final_estimator') and obj.final_estimator is not None:
        result["params"]["final_estimator"] = serialize_component(obj.final_estimator)

    # Add other changed params (cv, n_jobs, passthrough, etc.)
    other_params = _changed_kwargs(obj)
    for key in ['estimators', 'final_estimator']:
        other_params.pop(key, None)  # Already handled above
    if other_params:
        result["params"].update(serialize_component(other_params))

    return result


def _deserialize_meta_estimator(cls, params: dict) -> Any:
    """Reconstruct a meta-estimator with nested estimators.

    Args:
        cls: Meta-estimator class (StackingRegressor, etc.).
        params: Serialized parameters including estimators and final_estimator.

    Returns:
        Instantiated meta-estimator with deserialized nested estimators.
    """
    deserialized_params = {}

    # Handle estimators list of [name, estimator_config] tuples
    if "estimators" in params:
        deserialized_params["estimators"] = [
            (name, deserialize_component(est_config))
            for name, est_config in params["estimators"]
        ]

    # Handle final_estimator
    if "final_estimator" in params:
        deserialized_params["final_estimator"] = deserialize_component(
            params["final_estimator"]
        )

    # Deserialize other params normally
    for key, value in params.items():
        if key not in ["estimators", "final_estimator"]:
            deserialized_params[key] = deserialize_component(value)

    return cls(**deserialized_params)

[docs] def serialize_component(obj: Any) -> Any: """ Return something that json.dumps can handle. Normalizes all syntaxes to canonical form for hash-based uniqueness. All instances serialize to their internal module paths with only non-default parameters. """ if obj is None or isinstance(obj, (bool, int, float)): return obj if isinstance(obj, str): # Normalize string module paths to internal module paths for hash consistency # e.g., "sklearn.preprocessing.StandardScaler" → "sklearn.preprocessing._data.StandardScaler" if "." in obj and not obj.endswith(('.pkl', '.h5', '.keras', '.joblib', '.pt', '.pth')): try: # Try to import and get canonical internal module path mod_name, _, cls_name = obj.rpartition(".") mod = importlib.import_module(mod_name) cls = getattr(mod, cls_name) # Return canonical form (internal module path) return f"{cls.__module__}.{cls.__qualname__}" except (ImportError, AttributeError): # If import fails, pass through as-is (e.g., controller names, invalid paths) pass return obj # Handle Enum instances - serialize as class path + value if isinstance(obj, Enum): return { "enum": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "value": obj.value } if isinstance(obj, dict): return {k: serialize_component(v) for k, v in obj.items()} if isinstance(obj, list): return [serialize_component(x) for x in obj] if isinstance(obj, tuple): # Convert tuples to lists for YAML/JSON compatibility # Hyperparameter range specifications like ('int', min, max) become ['int', min, max] return [serialize_component(x) for x in obj] if inspect.isclass(obj): return f"{obj.__module__}.{obj.__qualname__}" # Special handling for stacking/ensemble meta-estimators # Must be checked BEFORE generic instance serialization if _is_meta_estimator(obj): return _serialize_meta_estimator(obj) # Handle numpy arrays and other array-like objects # Convert to list for JSON/YAML serialization if hasattr(obj, '__array__') or (hasattr(obj, 'tolist') and hasattr(obj, 'shape')): try: return obj.tolist() except (AttributeError, TypeError): pass params = _changed_kwargs(obj) if inspect.isfunction(obj) or inspect.isbuiltin(obj): func_serialized = { "function": f"{obj.__module__}.{obj.__name__}" } if params: func_serialized["params"] = serialize_component(params) # Store framework as string (not runtime instance) for JSON serialization if hasattr(obj, 'framework'): func_serialized["framework"] = obj.framework return func_serialized def_serialized = f"{obj.__class__.__module__}.{obj.__class__.__qualname__}" if params: def_serialized = { "class": def_serialized, "params": serialize_component(params), } return def_serialized
[docs] def deserialize_component(blob: Any, infer_type: Any = None) -> Any: """Turn the output of serialize_component back into live objects.""" # --- trivial cases ------------------------------------------------------ # if blob is None or isinstance(blob, (bool, int, float)): # Type validation - int and float are considered compatible for numeric values if infer_type is not None and infer_type is not type(None): if not isinstance(blob, infer_type): # Allow int/float cross-compatibility for numeric types if not (isinstance(blob, (int, float)) and infer_type in (int, float)): # Debug-level info only - the value is still returned as-is pass # Removed verbose warning - type mismatch is handled gracefully return blob if isinstance(blob, str): if blob in build_aliases: blob = build_aliases[blob] try: # try to import the module and get the class or function # Safety check for empty or invalid strings if not blob or "." not in blob: return blob mod_name, _, cls_or_func_name = blob.rpartition(".") # Safety check for empty module name if not mod_name: return blob mod = importlib.import_module(mod_name) cls_or_func = getattr(mod, cls_or_func_name) # Try to instantiate without parameters try: return cls_or_func() except TypeError as e: # If instantiation fails due to missing required parameters, # check if there are required parameters without defaults if inspect.isclass(cls_or_func): sig = inspect.signature(cls_or_func.__init__) required_params = [ name for name, param in sig.parameters.items() if name != "self" and param.default is inspect._empty ] if required_params: raise TypeError( f"Cannot deserialize {blob} from string representation: " f"class requires parameters {required_params} but none were provided. " f"This usually means the serialization failed to capture required parameters. " f"Original error: {e}" ) from e raise except (ImportError, AttributeError): return blob if isinstance(blob, list): if infer_type is not None and isinstance(infer_type, type): if issubclass(infer_type, tuple): return tuple(deserialize_component(x) for x in blob) # Handle numpy array deserialization try: import numpy as np is_numpy_type = (infer_type is np.ndarray or (hasattr(infer_type, '__module__') and infer_type.__module__ == 'numpy' and infer_type.__name__ == 'ndarray')) if is_numpy_type: return np.array(blob) except ImportError: pass return [deserialize_component(x) for x in blob] if isinstance(blob, dict): # Handle Enum deserialization if "enum" in blob and "value" in blob: enum_path = blob["enum"] mod_name, _, enum_name = enum_path.rpartition(".") try: mod = importlib.import_module(mod_name) enum_cls = getattr(mod, enum_name) return enum_cls(blob["value"]) except (ImportError, AttributeError, ValueError) as e: print(f"Failed to deserialize enum {enum_path}: {e}") return blob if any(key in blob for key in ("class", "function", "instance")): key = "class" if "class" in blob else "function" if "function" in blob else "instance" # Safety check for empty or None values if not blob[key] or not isinstance(blob[key], str): print(f"Invalid {key} value in blob: {blob[key]}") return blob mod_name, _, cls_or_func_name = blob[key].rpartition(".") # Safety check for empty module name if not mod_name: print(f"Empty module name for {key}: {blob[key]}") return blob try: mod = importlib.import_module(mod_name) cls_or_func = getattr(mod, cls_or_func_name) except (ImportError, AttributeError): print(f"Failed to import {blob[key]}") return blob # Special handling for meta-estimators (stacking/voting) # Must deserialize nested estimators properly if key == "class" and _is_meta_estimator_class(cls_or_func) and "params" in blob: return _deserialize_meta_estimator(cls_or_func, blob["params"]) params = {} if "params" in blob: # print(blob) for k, v in blob["params"].items(): # resolved_type = _resolve_type(cls_or_func, k) # print(k, v, resolved_type) params[k] = deserialize_component(v, _resolve_type(cls_or_func, k)) try: # Special handling for model factory functions with @framework decorator # These need dataset-dependent parameters (like input_shape) so we return # them as dict for controllers to instantiate if key == "function" and hasattr(cls_or_func, 'framework'): # Return dict for controller instantiation (no runtime instance) return { "type": "function", "func": cls_or_func, "framework": cls_or_func.framework, "params": params } if key == "class" or key == "instance" or key == "function": return cls_or_func(**params) # Fallback for other cases if len(params) == 0: return cls_or_func else: return { key: cls_or_func, "params": params } except TypeError: print(f"Failed to instantiate {cls_or_func} with params {params}") sig = inspect.signature(cls_or_func) allowed = {n for n in sig.parameters if n != "self"} filtered = {k: v for k, v in params.items() if k in allowed} # Check again if this is a model factory function if hasattr(cls_or_func, 'framework'): # Return dict for controller instantiation return { "type": "function", "func": cls_or_func, "framework": cls_or_func.framework, "params": filtered } return cls_or_func(**filtered) return {k: deserialize_component(v) for k, v in blob.items()} # should not reach here return blob
def _changed_kwargs(obj): """Return {param: value} for every __init__ param whose current value differs from its default.""" sig = inspect.signature(obj.__class__.__init__) out = {} # Check if object is a Flax module to skip internal fields like 'parent' is_flax_module = False try: import flax.linen as nn if isinstance(obj, nn.Module): is_flax_module = True except ImportError: pass # Get params dict if available (standard sklearn API) obj_params = {} if hasattr(obj, 'get_params'): try: obj_params = obj.get_params(deep=False) except Exception: pass for name, param in sig.parameters.items(): if name == "self": continue if is_flax_module and name == 'parent': continue default = param.default if param.default is not inspect._empty else None try: current = getattr(obj, name) except AttributeError: # Try to get from get_params() if available if name in obj_params: current = obj_params[name] else: # fall back to what's in cvargs if it exists current = obj.__dict__.get("cvargs", {}).get(name, default) # Handle comparison with numpy arrays and other array-like objects try: is_different = current != default # For numpy arrays and similar, convert boolean array to single boolean if hasattr(is_different, '__iter__') and not isinstance(is_different, str): # any(is_different) means at least one element differs is_different = any(is_different) if hasattr(is_different, '__len__') else True except (ValueError, TypeError): # If comparison fails (e.g., array vs None), consider them different is_different = True if is_different: if isinstance(current, tuple): current = list(current) # out[name] = (current, current_type) out[name] = current return out def _resolve_type(obj_or_cls: Any, name: str) -> Union[type, Any, None]: """Resolve the type of a parameter in a class or function based on its signature or type hints. If the parameter is not found, return None. If the parameter has a default value, return its type. If the parameter has no default value, return the type of the attribute with the same name in the class or instance If the parameter is not found in the signature or type hints, return None. """ if obj_or_cls is None: return None cls = obj_or_cls if inspect.isclass(obj_or_cls) else obj_or_cls.__class__ sig = inspect.signature(cls.__init__) if name in sig.parameters: if sig.parameters[name].default is inspect._empty: if sig.parameters[name].annotation is not inspect._empty: # print(f"Using annotation for {name}: {sig.parameters[name].annotation}") ann = sig.parameters[name].annotation while get_origin(ann) is Annotated: ann = get_args(ann)[0] origin = get_origin(ann) if origin is not None: return origin else: return ann else: if hasattr(obj_or_cls, name): return type(getattr(obj_or_cls, name)) else: return None else: return type(sig.parameters[name].default) class_hints = get_type_hints(cls, include_extras=True) if name in class_hints: return class_hints[name] init_hints = get_type_hints(cls.__init__, include_extras=True) init_hints.pop('return', None) if name in init_hints: return init_hints[name] if not inspect.isclass(obj_or_cls) and hasattr(obj_or_cls, name): return type(getattr(obj_or_cls, name)) return None