"""
Model Identifier Generator - Generate consistent model identifiers
This component centralizes all model naming and identification logic.
Extracted from launch_training() lines 329-345 to improve maintainability.
Generates:
- classname: from model config or instance.__class__.__name__
- name: custom name from config or classname
- model_id: name + operation counter (unique for run)
- display_name: model_id with fold suffix if applicable
"""
from dataclasses import dataclass
from typing import Any, Dict, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from nirs4all.pipeline.runner import PipelineRunner
from nirs4all.pipeline.config.context import ExecutionContext
[docs]
@dataclass
class ModelIdentifiers:
"""Container for all model identifiers."""
classname: str # Class name of the model (e.g., "RandomForestRegressor")
name: str # User-provided name or classname
model_id: str # name + operation counter (e.g., "MyModel_10")
display_name: str # model_id with fold suffix (e.g., "MyModel_10_fold0")
operation_counter: int # Operation counter from runner
step_id: int # Pipeline step index
fold_idx: Optional[int] # Fold index if applicable
[docs]
class ModelIdentifierGenerator:
"""Generates consistent model identifiers for training and persistence.
This component extracts and centralizes all the naming logic that was
previously scattered in launch_training().
Example:
>>> generator = ModelIdentifierGenerator()
>>> identifiers = generator.generate(
... model_config={'name': 'MyPLS', 'class': 'sklearn.cross_decomposition.PLSRegression'},
... runner=runner,
... context={'step_id': 5},
... fold_idx=0
... )
>>> identifiers.model_id
'MyPLS_10'
>>> identifiers.display_name
'MyPLS_10_fold0'
"""
def __init__(self, helper=None):
"""Initialize identifier generator.
Args:
helper: Deprecated parameter, kept for backwards compatibility.
"""
# Helper is deprecated - methods are now implemented directly in this class
def _get_model_instance_from_config(self, model_config: Dict[str, Any]) -> Any:
"""Helper to extract model instance from various config formats.
Args:
model_config: Model configuration dictionary.
Returns:
Model instance or None.
"""
if isinstance(model_config, dict):
# Direct model_instance
if 'model_instance' in model_config:
return model_config['model_instance']
# Nested model structure
elif 'model' in model_config:
model_obj = model_config['model']
if isinstance(model_obj, dict):
if 'model' in model_obj:
return model_obj['model']
else:
return model_obj
else:
return model_obj
else:
return model_config
return None
def _get_model_class_name(self, model: Any) -> str:
"""Get the class name of a model.
Args:
model: Model object, class, function, or string representation.
Returns:
str: Class or function name.
"""
import inspect
# Handle dict wrapper from deserialize_component for @framework functions
if isinstance(model, dict) and 'func' in model:
func = model['func']
if callable(func) and hasattr(func, '__name__'):
return func.__name__
if inspect.isclass(model):
return f"{model.__qualname__}"
if inspect.isfunction(model) or inspect.isbuiltin(model):
return f"{model.__name__}"
# Handle string representation of functions/classes from deserialization
if isinstance(model, str):
if model.startswith("<function ") and " at 0x" in model:
# Extract function name
return model.split("<function ")[1].split(" at ")[0]
elif model.startswith("<class '") and "' at 0x" in model:
# Extract class name
return model.split("<class '")[1].split("' at ")[0].split(".")[-1]
else:
return str(type(model).__name__)
[docs]
def generate(
self,
model_config: Dict[str, Any],
runner: 'PipelineRunner',
context: 'ExecutionContext',
fold_idx: Optional[int] = None
) -> ModelIdentifiers:
"""Generate all model identifiers from configuration and context.
Args:
model_config: Model configuration dictionary
runner: Pipeline runner for operation counter
context: Execution context with step_number
fold_idx: Optional fold index for cross-validation
Returns:
ModelIdentifiers: Container with all generated identifiers
"""
# Extract base information
classname = self.extract_classname_from_config(model_config)
name = self.extract_core_name(model_config)
# Get operation counter and step info
operation_counter = runner.next_op()
# Use step_number (int) not step_id (str) for the numeric step index
step_id = context.state.step_number
# Build model_id and display_name
model_id = f"{name}_{operation_counter}"
display_name = model_id
if fold_idx is not None:
display_name += f"_fold{fold_idx}"
return ModelIdentifiers(
classname=classname,
name=name,
model_id=model_id,
display_name=display_name,
operation_counter=operation_counter,
step_id=step_id,
fold_idx=fold_idx
)
[docs]
def generate_binary_key(
self,
model_id: str,
fold_idx: Optional[int] = None
) -> str:
"""Generate the binary storage key for a model.
Args:
model_id: Base model identifier (e.g., "MyModel_10")
fold_idx: Optional fold index
Returns:
Binary key string (e.g., "MyModel_10" or "MyModel_10_fold0")
"""
if fold_idx is not None:
return f"{model_id}_fold{fold_idx}"
return model_id