# Source code for nirs4all.controllers.models.stacking.serialization

"""
Meta-Model Serialization - Artifact persistence for meta-model stacking.

This module provides dataclasses and utilities for persisting meta-model
artifacts with complete source model dependency tracking.

Phase 3 Implementation - Key components:
1. SourceModelReference: Reference to a source model with feature mapping
2. MetaModelArtifact: Complete artifact for meta-model persistence
3. MetaModelSerializer: Handles serialization/deserialization

The meta-model serialization captures:
- The trained meta-learner itself (via artifact_registry)
- Ordered references to source models (for feature column alignment)
- Stacking configuration (coverage strategy, aggregation, etc.)
- Branch context (for validation during prediction)
"""

from collections import Counter
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
import json
import warnings

from nirs4all.operators.models.meta import (
    StackingConfig,
    CoverageStrategy,
    TestAggregation,
    BranchScope,
)

if TYPE_CHECKING:
    from nirs4all.pipeline.storage.artifacts.types import ArtifactRecord, MetaModelConfig
    from nirs4all.operators.models.meta import MetaModel
    from nirs4all.operators.models.selection import ModelCandidate
    from nirs4all.pipeline.config.context import ExecutionContext
    from .reconstructor import ReconstructionResult


@dataclass
class SourceModelReference:
    """Pointer to one source model that feeds a stacking meta-model.

    Holds the information used to locate and validate the source model
    again in prediction mode.

    Attributes:
        model_name: Display name of the model (e.g., "PLSRegression").
        model_classname: Full class name
            (e.g., "sklearn.cross_decomposition.PLSRegression").
        step_idx: Pipeline step index where the model was trained.
        artifact_id: Unique artifact ID for loading the model binary.
        feature_index: Column index in the meta-features matrix.
        fold_id: Optional fold ID if the reference is fold-specific.
        branch_id: Branch ID where the model was trained.
        branch_name: Branch name where the model was trained.
        branch_path: Full branch path for nested branches.
        val_score: Validation score for weighted averaging.
        metric: Metric used for scoring (e.g., "r2", "rmse").

    Example:
        >>> ref = SourceModelReference(
        ...     model_name="PLSRegression",
        ...     model_classname="sklearn.cross_decomposition.PLSRegression",
        ...     step_idx=3,
        ...     artifact_id="0001:3:all",
        ...     feature_index=0,
        ...     branch_id=None,
        ...     val_score=0.92,
        ...     metric="r2"
        ... )
    """

    model_name: str
    model_classname: str
    step_idx: int
    artifact_id: str
    feature_index: int
    fold_id: Optional[str] = None
    branch_id: Optional[int] = None
    branch_name: Optional[str] = None
    branch_path: Optional[List[int]] = None
    val_score: Optional[float] = None
    metric: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the reference as a plain dict for JSON/YAML serialization.

        Values are passed through untouched (no copies are made) and the
        keys are emitted in field-declaration order.
        """
        exported = (
            "model_name",
            "model_classname",
            "step_idx",
            "artifact_id",
            "feature_index",
            "fold_id",
            "branch_id",
            "branch_name",
            "branch_path",
            "val_score",
            "metric",
        )
        return {key: getattr(self, key) for key in exported}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SourceModelReference":
        """Build a reference from a dict produced by :meth:`to_dict`.

        Missing identification keys fall back to empty strings / ``0``;
        every other missing key falls back to ``None``.
        """
        fallbacks = {
            "model_name": "",
            "model_classname": "",
            "step_idx": 0,
            "artifact_id": "",
            "feature_index": 0,
            "fold_id": None,
            "branch_id": None,
            "branch_name": None,
            "branch_path": None,
            "val_score": None,
            "metric": None,
        }
        kwargs = {key: data.get(key, default) for key, default in fallbacks.items()}
        return cls(**kwargs)
@dataclass
class MetaModelArtifact:
    """Persistable description of a trained stacking meta-model.

    Bundles what is needed to:
    - Reload the meta-model and its dependencies
    - Reconstruct feature columns in the correct order
    - Validate branch context during prediction
    - Apply the same stacking configuration

    Attributes:
        meta_model_type: Type identifier ("MetaModel").
        meta_model_name: Display name of the meta-model.
        meta_learner_class: Class name of the meta-learner (e.g., "Ridge").
        source_models: Ordered list of source model references.
        feature_columns: Feature column names in order.
        stacking_config: Serialized stacking configuration.
        selector_config: Configuration of the model selector used.
        branch_context: Branch context during training.
        use_proba: Whether probability features were used.
        n_folds: Number of cross-validation folds.
        coverage_ratio: OOF coverage ratio achieved during training.
        artifact_id: The artifact ID for the meta-model itself.
        training_timestamp: ISO timestamp of training (defaults to the
            current UTC time when the artifact is constructed).
        task_type: "regression", "binary_classification" or
            "multiclass_classification".
        n_classes: Number of classes for classification tasks.
        feature_to_model_mapping: Feature name -> source model name.

    Example:
        >>> artifact = MetaModelArtifact(
        ...     meta_model_type="MetaModel",
        ...     meta_model_name="MetaModel_Ridge",
        ...     meta_learner_class="Ridge",
        ...     source_models=[ref1, ref2],
        ...     feature_columns=["PLS_pred", "RF_pred"],
        ...     stacking_config=stacking_config_dict,
        ...     branch_context={"branch_id": None},
        ...     use_proba=False,
        ...     n_folds=5,
        ...     coverage_ratio=1.0,
        ...     artifact_id="0001:5:all",
        ...     training_timestamp="2024-12-12T14:30:22Z"
        ... )
    """

    meta_model_type: str
    meta_model_name: str
    meta_learner_class: str
    source_models: List[SourceModelReference]
    feature_columns: List[str]
    stacking_config: Dict[str, Any]
    selector_config: Optional[Dict[str, Any]] = None
    branch_context: Optional[Dict[str, Any]] = None
    use_proba: bool = False
    n_folds: int = 0
    coverage_ratio: float = 1.0
    artifact_id: str = ""
    training_timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    # Phase 5: Classification support fields
    task_type: str = "regression"  # "regression", "binary_classification", "multiclass_classification"
    n_classes: Optional[int] = None  # Number of classes for classification tasks
    feature_to_model_mapping: Optional[Dict[str, str]] = None  # Feature name -> source model name

    def to_dict(self) -> Dict[str, Any]:
        """Return the artifact as a plain dict for JSON/YAML serialization.

        Source model references are converted with their own ``to_dict``;
        every other value is passed through unchanged.
        """
        exported = (
            "meta_model_type",
            "meta_model_name",
            "meta_learner_class",
            "source_models",
            "feature_columns",
            "stacking_config",
            "selector_config",
            "branch_context",
            "use_proba",
            "n_folds",
            "coverage_ratio",
            "artifact_id",
            "training_timestamp",
            # Phase 5: Classification support
            "task_type",
            "n_classes",
            "feature_to_model_mapping",
        )
        payload: Dict[str, Any] = {key: getattr(self, key) for key in exported}
        # Re-assigning an existing key keeps its position in the dict, so
        # the serialized key order is unaffected by this substitution.
        payload["source_models"] = [ref.to_dict() for ref in self.source_models]
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MetaModelArtifact":
        """Build an artifact from a dict produced by :meth:`to_dict`.

        Missing keys fall back to the defaults listed below; note that a
        missing ``training_timestamp`` becomes ``""`` rather than "now".
        """
        refs = [
            SourceModelReference.from_dict(entry)
            for entry in data.get("source_models", [])
        ]
        fallbacks: Dict[str, Any] = {
            "meta_model_type": "MetaModel",
            "meta_model_name": "",
            "meta_learner_class": "",
            "feature_columns": [],
            "stacking_config": {},
            "selector_config": None,
            "branch_context": None,
            "use_proba": False,
            "n_folds": 0,
            "coverage_ratio": 1.0,
            "artifact_id": "",
            "training_timestamp": "",
            # Phase 5: Classification support
            "task_type": "regression",
            "n_classes": None,
            "feature_to_model_mapping": None,
        }
        kwargs = {key: data.get(key, default) for key, default in fallbacks.items()}
        return cls(source_models=refs, **kwargs)

    def to_json(self) -> str:
        """Serialize the artifact to a JSON string (2-space indent)."""
        as_dict = self.to_dict()
        return json.dumps(as_dict, indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> "MetaModelArtifact":
        """Rebuild an artifact from a JSON string made by :meth:`to_json`."""
        decoded = json.loads(json_str)
        return cls.from_dict(decoded)

    def get_source_artifact_ids(self) -> List[str]:
        """Return the artifact IDs of the source models.

        Returns:
            Artifact IDs in the order of ``source_models``.
        """
        ids: List[str] = []
        for ref in self.source_models:
            ids.append(ref.artifact_id)
        return ids

    def get_source_by_index(self, index: int) -> Optional[SourceModelReference]:
        """Find the source model whose ``feature_index`` equals ``index``.

        Args:
            index: Feature column index.

        Returns:
            The first matching SourceModelReference, or None if no source
            model carries that feature index.
        """
        matches = (ref for ref in self.source_models if ref.feature_index == index)
        return next(matches, None)

    def validate_feature_alignment(self) -> bool:
        """Check that feature columns and source models line up.

        Returns:
            True when there are as many feature columns as source models
            and each source model's ``feature_index`` equals its position
            in ``source_models``.
        """
        if len(self.feature_columns) != len(self.source_models):
            return False
        return all(
            ref.feature_index == position
            for position, ref in enumerate(self.source_models)
        )
def stacking_config_to_dict(config: StackingConfig) -> Dict[str, Any]:
    """Flatten a StackingConfig into a serializable dictionary.

    Args:
        config: StackingConfig instance.

    Returns:
        Dictionary in which the three enum-valued settings are replaced by
        their ``.value`` and the remaining settings are copied as-is.
    """
    enum_settings = ("coverage_strategy", "test_aggregation", "branch_scope")
    serialized: Dict[str, Any] = {
        name: getattr(config, name).value for name in enum_settings
    }
    serialized["allow_no_cv"] = config.allow_no_cv
    serialized["min_coverage_ratio"] = config.min_coverage_ratio
    return serialized
def stacking_config_from_dict(data: Dict[str, Any]) -> StackingConfig:
    """Rebuild a StackingConfig from its dictionary form.

    Missing keys fall back to: "strict" coverage, "mean" test aggregation,
    "current_only" branch scope, ``allow_no_cv=False`` and
    ``min_coverage_ratio=1.0``.

    Args:
        data: Dictionary with config values (enum settings as strings).

    Returns:
        StackingConfig instance.
    """
    coverage = CoverageStrategy(data.get("coverage_strategy", "strict"))
    aggregation = TestAggregation(data.get("test_aggregation", "mean"))
    scope = BranchScope(data.get("branch_scope", "current_only"))
    no_cv_allowed = data.get("allow_no_cv", False)
    min_ratio = data.get("min_coverage_ratio", 1.0)
    return StackingConfig(
        coverage_strategy=coverage,
        test_aggregation=aggregation,
        branch_scope=scope,
        allow_no_cv=no_cv_allowed,
        min_coverage_ratio=min_ratio,
    )
class MetaModelSerializer:
    """Handles serialization and deserialization of meta-model artifacts.

    Provides methods to:
    - Build MetaModelArtifact from training context
    - Convert to MetaModelConfig for the artifact registry
    - Validate artifact completeness

    Example:
        >>> serializer = MetaModelSerializer()
        >>> artifact = serializer.build_artifact(
        ...     meta_operator=meta_model_op,
        ...     source_models=selected_sources,
        ...     reconstruction_result=result,
        ...     context=execution_context
        ... )
        >>> config = serializer.to_meta_model_config(artifact)
    """

    def build_artifact(
        self,
        meta_operator: 'MetaModel',
        source_models: List['ModelCandidate'],
        reconstruction_result: Optional['ReconstructionResult'] = None,
        context: Optional['ExecutionContext'] = None,
        artifact_id: str = "",
    ) -> 'MetaModelArtifact':
        """Build MetaModelArtifact from training context.

        Source candidates are de-duplicated by a "unique name": the plain
        model name, or ``<name>_br<branch_id>`` (``<name>_br_none`` when the
        branch is None) if the same model name occurs in several branches.
        The first candidate seen for each unique name is kept, and its
        position becomes the ``feature_index``.

        Args:
            meta_operator: The MetaModel operator being trained.
            source_models: List of selected source model candidates.
            reconstruction_result: Optional result from TrainingSetReconstructor.
            context: Optional execution context for branch info.
            artifact_id: The artifact ID for this meta-model.

        Returns:
            MetaModelArtifact ready for persistence.
        """
        # Count the distinct branches each model name appears in; a name
        # seen in more than one branch needs a branch suffix to stay unique.
        name_branch_pairs = {
            (candidate.model_name, candidate.branch_id)
            for candidate in source_models
        }
        branch_count_per_name = Counter(name for name, _ in name_branch_pairs)
        needs_branch_suffix = {
            name for name, count in branch_count_per_name.items() if count > 1
        }

        # Build unique source list with branch-aware deduplication
        seen_unique = set()
        unique_sources: List[Tuple['ModelCandidate', str]] = []  # (candidate, unique_name)
        for candidate in source_models:
            model_name = candidate.model_name
            branch_id = candidate.branch_id
            if model_name in needs_branch_suffix:
                if branch_id is not None:
                    unique_name = f"{model_name}_br{branch_id}"
                else:
                    unique_name = f"{model_name}_br_none"
            else:
                unique_name = model_name
            if unique_name not in seen_unique:
                seen_unique.add(unique_name)
                unique_sources.append((candidate, unique_name))

        # Build source model references and the default feature column names
        source_refs = []
        feature_columns = []
        for idx, (candidate, unique_name) in enumerate(unique_sources):
            ref = SourceModelReference(
                model_name=unique_name,  # Use unique name for persistence
                model_classname=candidate.model_classname,
                step_idx=candidate.step_idx,
                artifact_id=self._generate_source_artifact_id(candidate, context),
                feature_index=idx,
                fold_id=candidate.fold_id,
                branch_id=candidate.branch_id,
                branch_name=candidate.branch_name,
                val_score=candidate.val_score,
                metric=candidate.metric,
            )
            source_refs.append(ref)
            feature_columns.append(f"{unique_name}_pred")

        # Get stacking config
        stacking_config_dict = stacking_config_to_dict(meta_operator.stacking_config)

        # Get selector config
        selector_config = None
        if meta_operator.selector is not None:
            selector_class = meta_operator.selector.__class__.__name__
            selector_config = {
                "type": selector_class,
                # Selectors without get_params() are recorded with no params
                "params": getattr(meta_operator.selector, 'get_params', lambda: {})()
            }
        elif meta_operator.source_models == "all":
            selector_config = {"type": "AllPreviousModelsSelector", "params": {}}
        elif isinstance(meta_operator.source_models, list):
            selector_config = {
                "type": "ExplicitModelSelector",
                "params": {"model_names": meta_operator.source_models}
            }

        # Get branch context
        branch_context = None
        if context is not None:
            branch_context = {
                "branch_id": getattr(context.selector, 'branch_id', None),
                "branch_name": getattr(context.selector, 'branch_name', None),
                "branch_path": getattr(context.selector, 'branch_path', None),
            }

        # Get reconstruction info and classification info
        n_folds = 0
        coverage_ratio = 1.0
        task_type = "regression"
        n_classes = None
        feature_to_model_mapping = None
        if reconstruction_result is not None:
            n_folds = reconstruction_result.n_folds
            coverage_ratio = reconstruction_result.coverage_ratio

            # Phase 5: Extract classification info from reconstruction result
            classification_info = getattr(reconstruction_result, 'classification_info', None)
            if classification_info is not None:
                task_type = classification_info.task_type.value
                n_classes = classification_info.n_classes

            # Extract feature to model mapping
            meta_feature_info = getattr(reconstruction_result, 'meta_feature_info', None)
            if meta_feature_info is not None:
                feature_to_model_mapping = meta_feature_info.feature_to_model
                # Update feature_columns from reconstruction result
                # NOTE(review): nesting of this assignment under the
                # meta_feature_info guard is assumed - confirm against the
                # upstream source.
                feature_columns = reconstruction_result.feature_names

        return MetaModelArtifact(
            meta_model_type="MetaModel",
            meta_model_name=meta_operator.name,
            meta_learner_class=type(meta_operator.model).__name__,
            source_models=source_refs,
            feature_columns=feature_columns,
            stacking_config=stacking_config_dict,
            selector_config=selector_config,
            branch_context=branch_context,
            use_proba=meta_operator.use_proba,
            n_folds=n_folds,
            coverage_ratio=coverage_ratio,
            artifact_id=artifact_id,
            # Phase 5: Classification support
            task_type=task_type,
            n_classes=n_classes,
            feature_to_model_mapping=feature_to_model_mapping,
        )

    def _generate_source_artifact_id(
        self,
        candidate: 'ModelCandidate',
        context: Optional['ExecutionContext'] = None
    ) -> str:
        """Get or generate artifact ID for a source model.

        First tries to look up the actual artifact ID from the registry.
        Falls back to generating a V2-style ID for compatibility.

        Args:
            candidate: Source model candidate.
            context: Optional execution context for pipeline info.

        Returns:
            Artifact ID string (V3 format if found in registry, V2 fallback
            otherwise). The fallback is ``pipeline_id:step:fold`` or
            ``pipeline_id:branch:step:fold``, where ``fold`` is ``"all"``
            when the candidate has no fold ID.
        """
        # Get pipeline_id from context if available
        pipeline_id = "pipeline"
        runtime_context = None
        if context is not None:
            # Try to get from custom context
            runtime_context = context.custom.get('_runtime_context')
            if runtime_context is not None and hasattr(runtime_context, 'saver'):
                saver = runtime_context.saver
                if saver is not None:
                    pipeline_id = getattr(saver, 'pipeline_id', 'pipeline')

        # Build branch_path from candidate (shared by lookup and fallback)
        branch_path = [candidate.branch_id] if candidate.branch_id is not None else []

        # Try to look up actual artifact ID from registry (V3 approach).
        # getattr keeps a runtime context without a registry on the fallback
        # path instead of raising AttributeError.
        registry = getattr(runtime_context, 'artifact_registry', None)
        if registry is not None:
            # Parse fold_id (may be "avg", "w_avg", or numeric)
            fold_id = None
            if candidate.fold_id is not None:
                if isinstance(candidate.fold_id, int):
                    fold_id = candidate.fold_id
                elif isinstance(candidate.fold_id, str) and candidate.fold_id.isdigit():
                    fold_id = int(candidate.fold_id)
                # "avg" and "w_avg" remain as None fold_id

            # Look up artifacts for this step
            step_artifacts = registry.get_artifacts_for_step(
                pipeline_id=pipeline_id,
                step_index=candidate.step_idx
            )
            for record in step_artifacts:
                # Match by branch and fold first
                record_branch = record.branch_path or []
                if record_branch != branch_path or record.fold_id != fold_id:
                    continue

                # Match by class name OR custom_name (model_name)
                # Custom name match handles MetaModel case where:
                # - record.class_name = "Ridge" (underlying meta-learner)
                # - candidate.model_classname = "MetaModel" (operator class)
                # - record.custom_name = "Ridge_MetaModel" (model_name)
                # - candidate.model_name = "Ridge_MetaModel" (from predictions)
                class_match = record.class_name == candidate.model_classname
                name_match = record.custom_name and record.custom_name == candidate.model_name
                if class_match or name_match:
                    return record.artifact_id

        # Fallback: generate V2-style ID for compatibility
        branch_str = ':'.join(str(b) for b in branch_path)
        # Only a missing/empty fold ID means "all folds"; an integer fold 0
        # is a real fold and must not collapse into "all".
        has_fold = candidate.fold_id is not None and candidate.fold_id != ""
        fold_str = str(candidate.fold_id) if has_fold else "all"
        if branch_str:
            return f"{pipeline_id}:{branch_str}:{candidate.step_idx}:{fold_str}"
        return f"{pipeline_id}:{candidate.step_idx}:{fold_str}"

    def to_meta_model_config(self, artifact: 'MetaModelArtifact') -> 'MetaModelConfig':
        """Convert MetaModelArtifact to MetaModelConfig for registry.

        The ArtifactRegistry uses MetaModelConfig to track source model
        dependencies. This method creates the appropriate config.

        Args:
            artifact: MetaModelArtifact to convert.

        Returns:
            MetaModelConfig for artifact registry.
        """
        # Imported lazily; at module level this name is only available
        # under TYPE_CHECKING.
        from nirs4all.pipeline.storage.artifacts.types import MetaModelConfig

        source_models = [
            {
                "artifact_id": ref.artifact_id,
                "feature_index": ref.feature_index,
                "model_name": ref.model_name,
            }
            for ref in artifact.source_models
        ]
        return MetaModelConfig(
            source_models=source_models,
            feature_columns=artifact.feature_columns
        )

    def validate_artifact(self, artifact: 'MetaModelArtifact') -> List[str]:
        """Validate artifact completeness and consistency.

        Args:
            artifact: MetaModelArtifact to validate.

        Returns:
            List of validation error messages (empty if valid).
        """
        errors = []

        # Check required fields
        if not artifact.meta_model_name:
            errors.append("Missing meta_model_name")
        if not artifact.meta_learner_class:
            errors.append("Missing meta_learner_class")
        if not artifact.source_models:
            errors.append("No source models defined")
        if not artifact.feature_columns:
            errors.append("No feature columns defined")

        # Check feature alignment
        if len(artifact.feature_columns) != len(artifact.source_models):
            errors.append(
                f"Feature column count ({len(artifact.feature_columns)}) "
                f"doesn't match source model count ({len(artifact.source_models)})"
            )

        # Check feature indices are sequential (a permutation of 0..n-1)
        if artifact.source_models:
            indices = [ref.feature_index for ref in artifact.source_models]
            expected = list(range(len(indices)))
            if sorted(indices) != expected:
                errors.append(f"Non-sequential feature indices: {indices}")

        # Check all source models have artifact IDs
        errors.extend(
            f"Source model {ref.model_name} missing artifact_id"
            for ref in artifact.source_models
            if not ref.artifact_id
        )

        return errors