Source code for nirs4all.visualization.pipeline_diagram

"""
Pipeline Diagram - DAG visualization for pipeline execution structure.

This module provides visualization tools for displaying the complete
pipeline structure as a directed acyclic graph (DAG).

The diagram shows:
- All pipeline steps with operator names
- Dataset shape at each step (samples × processings × features)
- Branching and merging points
- Model training steps
- Cross-validation splitters

Shape notation: S×P×F
- S = samples
- P = processings (preprocessing views)
- F = features (wavelengths/columns)

Example:
    >>> from nirs4all.visualization.pipeline_diagram import PipelineDiagram
    >>> diagram = PipelineDiagram(pipeline_steps, predictions)
    >>> fig = diagram.render()
    >>> fig.savefig('pipeline_diagram.png')
"""

from typing import Any, Dict, List, Optional, Tuple, Union
from collections import defaultdict
from dataclasses import dataclass, field
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from matplotlib.patches import FancyBboxPatch, Rectangle, FancyArrowPatch
import matplotlib.patches as mpatches
import numpy as np


[docs] @dataclass class PipelineNode: """Represents a node in the pipeline DAG. Attributes: id: Unique node identifier step_index: Pipeline step index (1-based) label: Display label for the node node_type: Type of node (preprocessing, model, splitter, branch, merge, etc.) shape_before: Dataset shape before this step (samples, processings, features) shape_after: Dataset shape after this step input_layout_shape: 2D layout shape before step (samples, features) output_layout_shape: 2D layout shape after step (samples, features) features_shape: List of 3D per-source shapes (samples, processings, features) branch_id: Branch ID if inside a branch (None if not) branch_name: Branch name if inside a branch substep_index: Index within a branch's substeps parent_ids: List of parent node IDs children_ids: List of child node IDs duration_ms: Execution duration in milliseconds (from trace) metadata: Additional node metadata """ id: str step_index: int label: str node_type: str = "preprocessing" shape_before: Optional[Tuple[int, int, int]] = None shape_after: Optional[Tuple[int, int, int]] = None input_layout_shape: Optional[Tuple[int, int]] = None output_layout_shape: Optional[Tuple[int, int]] = None features_shape: Optional[List[Tuple[int, int, int]]] = None branch_id: Optional[int] = None branch_name: str = "" substep_index: Optional[int] = None parent_ids: List[str] = field(default_factory=list) children_ids: List[str] = field(default_factory=list) duration_ms: float = 0.0 metadata: Dict[str, Any] = field(default_factory=dict)
[docs] class PipelineDiagram: """Create DAG visualization for pipeline execution structure. Renders a visual diagram showing the complete pipeline topology, including all steps, shapes, branches, and models. Attributes: pipeline_steps: List of pipeline step definitions predictions: Optional Predictions object with execution data execution_trace: Optional ExecutionTrace with actual runtime shapes config: Optional dict for customization """ # Professional color palette (Fill, Border) NODE_STYLES = { 'preprocessing': ('#E3F2FD', '#1976D2'), # Blue 50 / 700 'feature_augmentation': ('#E0F2F1', '#00796B'), # Teal 50 / 700 'sample_augmentation': ('#E8F5E9', '#388E3C'), # Green 50 / 700 'concat_transform': ('#F3E5F5', '#7B1FA2'), # Purple 50 / 700 'y_processing': ('#FFF8E1', '#FFA000'), # Amber 50 / 700 'splitter': ('#F3E5F5', '#7B1FA2'), # Purple 50 / 700 'branch': ('#E0F2F1', '#00796B'), # Teal 50 / 700 'merge': ('#E0F2F1', '#00796B'), # Teal 50 / 700 'source_branch': ('#E0F2F1', '#00796B'), # Teal 50 / 700 'merge_sources': ('#E0F2F1', '#00796B'), # Teal 50 / 700 'model': ('#FFEBEE', '#D32F2F'), # Red 50 / 700 'input': ('#FAFAFA', '#616161'), # Gray 50 / 700 'output': ('#FAFAFA', '#616161'), # Gray 50 / 700 'default': ('#ECEFF1', '#455A64'), # Blue Grey 50 / 700 } def __init__( self, pipeline_steps: Optional[List[Any]] = None, predictions: Any = None, execution_trace: Any = None, config: Optional[Dict[str, Any]] = None ): """Initialize PipelineDiagram. Args: pipeline_steps: List of pipeline step definitions predictions: Optional Predictions object with execution data execution_trace: Optional ExecutionTrace with runtime shapes config: Optional dict for customization: - figsize: Tuple for figure size - fontsize: Base font size - node_width: Width of nodes - node_height: Height of nodes - show_shapes: Whether to show shape info - compact: Use compact node labels """ self.pipeline_steps = pipeline_steps or [] self.predictions = predictions self.execution_trace = execution_trace self.config = config or {} # Default configuration - optimized for publication self._figsize = self.config.get('figsize', (12, 8)) self._fontsize = self.config.get('fontsize', 9) self._node_width = self.config.get('node_width', 2.5) self._node_height = self.config.get('node_height', 0.8) self._show_shapes = self.config.get('show_shapes', True) self._compact = self.config.get('compact', False) # Build the DAG self.nodes: Dict[str, PipelineNode] = {} self.edges: List[Tuple[str, str]] = []
[docs] def render( self, show_shapes: Optional[bool] = None, figsize: Optional[Tuple[float, float]] = None, title: Optional[str] = None, initial_shape: Optional[Tuple[int, int, int]] = None ) -> Figure: """Render the pipeline diagram. Args: show_shapes: Override config's show_shapes setting figsize: Override figure size title: Optional title for the diagram initial_shape: Initial dataset shape (samples, processings, features) Returns: matplotlib Figure object """ # Apply overrides effective_show_shapes = show_shapes if show_shapes is not None else self._show_shapes effective_figsize = figsize if figsize is not None else self._figsize # Only build DAG if not already built (e.g., from_trace already built it) if not self.nodes: self._build_dag(initial_shape=initial_shape) if not self.nodes: # No steps - show simple message fig, ax = plt.subplots(figsize=(6, 2)) ax.text(0.5, 0.5, 'No pipeline steps to visualize', ha='center', va='center', fontsize=14) ax.axis('off') return fig # Calculate layout layout = self._compute_layout() # Create figure fig, ax = plt.subplots(figsize=effective_figsize) # Draw the diagram self._draw_edges(ax, layout) self._draw_nodes(ax, layout, effective_show_shapes) # Configure axes ax.set_aspect('equal') ax.axis('off') # Set title if title is None: n_steps = len([n for n in self.nodes.values() if n.node_type != 'input']) title = f"Pipeline Structure ({n_steps} steps)" ax.set_title(title, fontsize=self._fontsize + 4, fontweight='bold', pad=20, color='#263238') # Adjust limits x_min, x_max, y_min, y_max = self._get_bounds(layout) padding = 1.0 ax.set_xlim(x_min - padding, x_max + padding) ax.set_ylim(y_min - padding, y_max + padding) # Add legend self._add_legend(ax) plt.tight_layout() return fig
[docs] @classmethod def from_trace( cls, execution_trace: Any, predictions: Any = None, config: Optional[Dict[str, Any]] = None ) -> 'PipelineDiagram': """Create a PipelineDiagram from an ExecutionTrace. This builds the diagram using actual runtime data including measured shapes at each step. Args: execution_trace: ExecutionTrace object from pipeline execution predictions: Optional Predictions object to enrich nodes with scores config: Optional configuration dict Returns: PipelineDiagram instance ready for rendering Example: >>> from nirs4all.visualization import PipelineDiagram >>> diagram = PipelineDiagram.from_trace(trace) >>> fig = diagram.render(title="Execution Trace") """ diagram = cls( execution_trace=execution_trace, predictions=predictions, config=config ) diagram._build_dag_from_trace() return diagram
def _build_dag_from_trace(self) -> None: """Build the DAG from an ExecutionTrace object with actual shapes.""" if not self.execution_trace: return self.nodes.clear() self.edges.clear() steps = getattr(self.execution_trace, 'steps', []) if not steps: return # Pre-process: collect shapes from steps for inheritance step_shapes = {} # step_index -> (input_features, output_features) for step in steps: idx = getattr(step, 'step_index', 0) input_f = getattr(step, 'input_features_shape', None) output_f = getattr(step, 'output_features_shape', None) if input_f or output_f: step_shapes[idx] = (input_f, output_f) # Create input node from first step's input shape first_step = steps[0] if steps else None input_layout = getattr(first_step, 'input_shape', None) if first_step else None input_features = getattr(first_step, 'input_features_shape', None) if first_step else None input_node = PipelineNode( id="input", step_index=0, label="Dataset", node_type="input", output_layout_shape=input_layout, features_shape=input_features, ) self.nodes["input"] = input_node # Track edges by step current_node_ids = ["input"] branch_stacks: Dict[tuple, List[str]] = {} # branch_path -> node_ids in_branch_mode = False # Track if we're inside a branch in_source_branch_mode = False # Track if we're inside source_branch (before merge) source_branch_node_ids: List[str] = [] # Track source branch node IDs n_sources = 0 # Number of sources in source_branch mode last_pre_branch_node = "input" # Track the node before entering branches last_known_features_shape = input_features # Track last known shape for inheritance for step_i, step in enumerate(steps): step_index = getattr(step, 'step_index', 0) operator_class = getattr(step, 'operator_class', '') or '' operator_type = getattr(step, 'operator_type', '') or '' branch_path = tuple(getattr(step, 'branch_path', []) or []) branch_name = getattr(step, 'branch_name', '') or '' duration_ms = getattr(step, 'duration_ms', 0.0) substep_idx = getattr(step, 'substep_index', None) operator_config = getattr(step, 'operator_config', {}) or {} # Get shapes from trace input_layout = getattr(step, 'input_shape', None) output_layout = getattr(step, 'output_shape', None) input_features = getattr(step, 'input_features_shape', None) output_features = getattr(step, 'output_features_shape', None) # Check for source_branch and expand it is_source_branch = operator_config.get('source_branch', False) if is_source_branch: # Get number of sources from operator_config n_sources = operator_config.get('n_sources', 0) # Create expanded nodes for source_branch self._expand_source_branch( step, step_index, steps, step_i, current_node_ids, branch_stacks, last_known_features_shape ) # Update current_node_ids to point to the source branch nodes source_branch_ids = [nid for nid in self.nodes.keys() if nid.startswith(f"step_{step_index}_src")] source_branch_ids.sort() # Ensure consistent ordering if source_branch_ids: current_node_ids = source_branch_ids source_branch_node_ids = source_branch_ids.copy() in_branch_mode = True in_source_branch_mode = True continue # Check if this step runs on multiple sources (inside source_branch mode) # and should be expanded to show one node per source artifacts = getattr(step, 'artifacts', None) by_source = {} if artifacts: if hasattr(artifacts, 'by_source'): by_source = artifacts.by_source or {} elif isinstance(artifacts, dict): by_source = artifacts.get('by_source', {}) op_type_lower = operator_type.lower() if operator_type else '' # Check if this is a merge step that exits source_branch mode if in_source_branch_mode and op_type_lower == 'merge' and not branch_path: # This is the source branch merge - exit source_branch mode self._create_merge_node_for_source_branch( step, step_index, operator_class, current_node_ids, input_layout, output_layout, input_features, output_features ) # Exit source_branch mode in_source_branch_mode = False source_branch_node_ids = [] current_node_ids = [f"step_{step_index}"] in_branch_mode = False branch_stacks.clear() continue # Expand step if in source_branch mode and step runs on multiple sources if in_source_branch_mode and len(by_source) > 1 and not branch_path: # Expand this step into per-source nodes self._expand_step_for_sources( step, step_index, operator_class, operator_type, by_source, current_node_ids, input_features, output_features ) # Update current_node_ids to point to the expanded nodes expanded_ids = [nid for nid in self.nodes.keys() if nid.startswith(f"step_{step_index}_src")] expanded_ids.sort() if expanded_ids: current_node_ids = expanded_ids source_branch_node_ids = expanded_ids.copy() continue # Fallback for features shape if missing but layout exists if output_features is None and output_layout is not None: s, f = output_layout output_features = [(s, 1, f)] # If output_features still missing, try input_features as fallback if output_features is None and input_features is not None: output_features = input_features # Derive model output shape from predictions (model output is prediction, not feature concat) node_type_tmp = self._classify_operator_type(operator_type, operator_class) derived_shape = None if node_type_tmp == 'model' and self.predictions is not None: derived_shape = self._derive_model_input_shape_from_predictions( step_index=step_index, branch_path=branch_path, branch_name=branch_name, operator_class=operator_class, operator_config=operator_config ) # For model nodes, always prefer derived shape (represents prediction output, not input features) if derived_shape and node_type_tmp == 'model': output_features = [derived_shape] output_layout = (derived_shape[0], derived_shape[2]) # For branch substeps with no shape, try to inherit from step shapes if output_features is None and branch_path: # Try to get shape from next non-branch step for future_step in steps[step_i + 1:]: future_bp = tuple(getattr(future_step, 'branch_path', []) or []) if not future_bp: # Non-branch step future_input = getattr(future_step, 'input_features_shape', None) if future_input: output_features = future_input break # If still no shape, use last known if output_features is None and last_known_features_shape: output_features = last_known_features_shape # Same fallback for input shapes if input_features is None and input_layout is not None: s, f = input_layout input_features = [(s, 1, f)] # Update last known shape if we have output if output_features and not branch_path: last_known_features_shape = output_features # Extract metadata including score and custom name metadata = {} if operator_config and 'n_splits' in operator_config: metadata['n_splits'] = operator_config['n_splits'] score = getattr(step, 'score', None) if score is None and hasattr(step, 'metadata'): score = step.metadata.get('score') if score is not None: metadata['best_score'] = score # Try to get custom name from artifacts metadata artifacts = getattr(step, 'artifacts', None) if artifacts and hasattr(artifacts, 'metadata'): art_meta = artifacts.metadata if hasattr(artifacts, 'metadata') else {} if isinstance(art_meta, dict): custom_name = art_meta.get('custom_name', '') if custom_name: metadata['custom_name'] = custom_name # Enrich model nodes with best score from predictions if available if metadata.get('best_score') is None and self.predictions is not None: model_name = metadata.get('custom_name') if not model_name and isinstance(operator_config, dict): model_name = operator_config.get('name') or operator_config.get('model_name') best_score = self._get_best_score_from_predictions( step_index=step_index, branch_path=branch_path, branch_name=branch_name, model_name=model_name, operator_class=operator_class ) if best_score is not None: metadata['best_score'] = best_score # Determine node type from operator info node_type = node_type_tmp # Format the label from operator info label = self._format_trace_label(operator_class, operator_type, step) # Create node ID if branch_path: node_id = f"step_{step_index}_b{'_'.join(map(str, branch_path))}" else: node_id = f"step_{step_index}" # Handle substep index for unique node IDs if substep_idx is not None: node_id += f"_s{substep_idx}" # Determine parent nodes based on branch path and operator type op_type_lower = operator_type.lower() if operator_type else '' if branch_path and substep_idx is not None: # Branch substep - chain within the same branch if substep_idx == 0: # First substep - connect to the pre-branch node (current main path) parent_nodes = current_node_ids.copy() if current_node_ids else [last_pre_branch_node] else: # Chain to previous substep in same branch prev_substep_id = f"step_{step_index}_b{'_'.join(map(str, branch_path))}_s{substep_idx - 1}" if prev_substep_id in self.nodes: parent_nodes = [prev_substep_id] else: parent_nodes = current_node_ids.copy() if current_node_ids else [last_pre_branch_node] in_branch_mode = True elif branch_path: # Branch step without substep index (e.g., post-branch steps like splitter) bp_tuple = branch_path parent_found = False while len(bp_tuple) > 0: if bp_tuple in branch_stacks: parent_nodes = branch_stacks[bp_tuple] parent_found = True break bp_tuple = bp_tuple[:-1] if not parent_found: branch_id = branch_path[0] if branch_path else 0 matching_leaves = [] for bpath, node_ids in branch_stacks.items(): if bpath and bpath[0] == branch_id: matching_leaves.extend(node_ids) parent_nodes = matching_leaves if matching_leaves else current_node_ids in_branch_mode = True elif op_type_lower == 'merge' and in_branch_mode and branch_stacks: # Merge step exiting branch mode - connect to the LEAF nodes of each branch # Group by top-level branch ID and take the most recent (shallowest path length for same branch) branch_leaves_by_id: Dict[int, str] = {} for bpath, node_ids in branch_stacks.items(): if not bpath: continue top_branch_id = bpath[0] # Prefer shorter paths (they are more recent, like MetaModel at bp=[0] vs substep at bp=[0,0]) current_depth = len(bpath) if node_ids: latest_node = node_ids[-1] if top_branch_id not in branch_leaves_by_id: branch_leaves_by_id[top_branch_id] = (current_depth, latest_node) else: existing_depth, _ = branch_leaves_by_id[top_branch_id] if current_depth < existing_depth: # Shorter path = more recent (e.g., MetaModel at [0] vs substep at [0,0]) branch_leaves_by_id[top_branch_id] = (current_depth, latest_node) all_branch_leaves = [node_id for _, node_id in branch_leaves_by_id.values()] parent_nodes = all_branch_leaves if all_branch_leaves else current_node_ids in_branch_mode = False branch_stacks.clear() elif op_type_lower == 'merge' and in_branch_mode: parent_nodes = current_node_ids in_branch_mode = False elif op_type_lower == 'branch': parent_nodes = current_node_ids last_pre_branch_node = current_node_ids[0] if current_node_ids else "input" in_branch_mode = True else: parent_nodes = current_node_ids if not in_branch_mode: last_pre_branch_node = current_node_ids[0] if current_node_ids else "input" # Create the node node = PipelineNode( id=node_id, step_index=step_index, label=label, node_type=node_type, input_layout_shape=input_layout, output_layout_shape=output_layout, features_shape=output_features if output_features else input_features, branch_id=branch_path[-1] if branch_path else None, branch_name=branch_name, substep_index=substep_idx, parent_ids=list(parent_nodes) if parent_nodes else [], duration_ms=duration_ms, metadata=metadata, ) self.nodes[node_id] = node # Add edges from parents for parent_id in (parent_nodes or []): if parent_id in self.nodes: self.edges.append((parent_id, node_id)) # Update tracking if branch_path: branch_stacks[branch_path] = [node_id] else: current_node_ids = [node_id] def _get_best_score_from_predictions( self, step_index: int, branch_path: List[int], branch_name: str, model_name: Optional[str], operator_class: str ) -> Optional[float]: """Return best score (prefer test of best val; fallback to best available).""" if self.predictions is None: return None filter_kwargs = {'step_idx': step_index, 'load_arrays': False} if branch_path: filter_kwargs['branch_id'] = branch_path[-1] if branch_name: filter_kwargs['branch_name'] = branch_name try: preds = self.predictions.filter_predictions(**filter_kwargs) except Exception: preds = [] # Fallback: if nothing matched with branch filters, try without them if not preds: try: slim_kwargs = {'step_idx': step_index, 'load_arrays': False} preds = self.predictions.filter_predictions(**slim_kwargs) except Exception: preds = [] if model_name: preds = [p for p in preds if p.get('model_name') == model_name] # If nothing yet, try class match if not preds and operator_class: preds = [p for p in preds if p.get('model_classname') == operator_class] # Last resort: drop model filters and keep step match only if not preds: try: preds = self.predictions.filter_predictions(step_idx=step_index, load_arrays=False) except Exception: preds = [] if not preds: return None lower_tokens = ( 'rmse', 'mae', 'mse', 'mape', 'msle', 'logloss', 'log_loss', 'loss', 'error', 'hinge' ) def is_higher_better(metric_name: str) -> bool: metric_lower = (metric_name or '').lower() return not any(tok in metric_lower for tok in lower_tokens) # Rank entries by val_score when present; otherwise by test_score/train_score def primary_score(pred: Dict[str, Any]) -> Optional[float]: if pred.get('val_score') is not None: return pred.get('val_score') if pred.get('test_score') is not None: return pred.get('test_score') return pred.get('train_score') best_entry: Optional[Dict[str, Any]] = None higher_is_better: Optional[bool] = None for pred in preds: score_candidate = primary_score(pred) if score_candidate is None: continue if higher_is_better is None: higher_is_better = is_higher_better(pred.get('metric', '')) if best_entry is None: best_entry = pred continue current_best = primary_score(best_entry) if current_best is None: best_entry = pred continue if higher_is_better: if score_candidate > current_best: best_entry = pred else: if score_candidate < current_best: best_entry = pred if best_entry is None: return None # Display priority: test -> val -> train for key in ('test_score', 'val_score', 'train_score'): if best_entry.get(key) is not None: return best_entry.get(key) return None def _get_total_samples_from_trace(self) -> Optional[int]: """Get total sample count from execution trace. Looks for sample count after sample_augmentation or at the splitter step, which reflects the true dataset size including augmented samples. """ if not self.execution_trace: return None steps = getattr(self.execution_trace, 'steps', []) total_samples = None # Look for sample count in order of preference: # 1. splitter step (most reliable - shows full dataset before CV split) # 2. sample_augmentation output (if present) # 3. Any step with features shape for step in steps: op_type = getattr(step, 'operator_type', '') or '' output_features = getattr(step, 'output_features_shape', None) if op_type.lower() == 'splitter' and output_features: # Splitter shows the full dataset size total_samples = output_features[0][0] break # Fallback: look for sample_augmentation or any step with features if total_samples is None: for step in steps: op_type = getattr(step, 'operator_type', '') or '' output_features = getattr(step, 'output_features_shape', None) if op_type.lower() == 'sample_augmentation' and output_features: total_samples = output_features[0][0] break # Final fallback: first step with output features if total_samples is None: for step in steps: output_features = getattr(step, 'output_features_shape', None) if output_features and output_features[0]: total_samples = output_features[0][0] break return total_samples def _derive_model_input_shape_from_predictions( self, step_index: int, branch_path: List[int], branch_name: str, operator_class: str, operator_config: Dict[str, Any], ) -> Optional[Tuple[int, int, int]]: """Infer model INPUT 3D shape (samples, processings, features) from predictions metadata. For models, we want to show the input shape (what the model receives), not the output shape (predictions). The n_features field in predictions stores the input feature count. Sample count is derived from the execution trace to include augmented samples. """ if self.predictions is None: return None filter_kwargs = {'step_idx': step_index, 'load_arrays': False} if branch_path: filter_kwargs['branch_id'] = branch_path[-1] if branch_name: filter_kwargs['branch_name'] = branch_name try: preds = self.predictions.filter_predictions(**filter_kwargs) except Exception: preds = [] if not preds: try: preds = self.predictions.filter_predictions(step_idx=step_index, load_arrays=False) except Exception: preds = [] if not preds: return None # Get feature count from predictions metadata feature_counts: List[int] = [] for pred in preds: n_features = pred.get('n_features') if n_features not in (None, 0): feature_counts.append(int(n_features)) # Use minimum n_features (merged features are typically smaller than concat) feature_count = min(feature_counts) if feature_counts else None if feature_count is None: return None # Get total sample count from trace (includes augmented samples) sample_count = self._get_total_samples_from_trace() # Fallback: calculate from predictions partitions if trace not available if sample_count is None: partition_counts: Dict[str, int] = {} for pred in preds: n_samples = int(pred.get('n_samples') or 0) partition = str(pred.get('partition') or '').lower() if partition and n_samples > 0: partition_counts[partition] = max(partition_counts.get(partition, 0), n_samples) if partition_counts: if 'train' in partition_counts: sample_count = partition_counts['train'] if 'test' in partition_counts: sample_count += partition_counts['test'] elif 'predict' in partition_counts: sample_count += partition_counts['predict'] else: sample_count = sum(partition_counts.values()) if sample_count is None or sample_count == 0: return None # Return as 3D shape (samples, 1 processing, features) return (sample_count, 1, int(feature_count)) def _expand_source_branch( self, step: Any, step_index: int, all_steps: List[Any], step_i: int, current_node_ids: List[str], branch_stacks: Dict[tuple, List[str]], last_known_shape: Optional[List[Tuple[int, int, int]]] ) -> None: """Expand a source_branch step into per-source nodes. Args: step: The source_branch step step_index: Step index all_steps: All execution steps step_i: Index of current step in all_steps current_node_ids: Current parent node IDs branch_stacks: Branch tracking dict last_known_shape: Last known features shape """ artifacts = getattr(step, 'artifacts', None) if not artifacts: return # Get by_chain to understand what transformers are in each source by_chain = {} if hasattr(artifacts, 'by_chain'): by_chain = artifacts.by_chain elif isinstance(artifacts, dict): by_chain = artifacts.get('by_chain', {}) if not by_chain: return # Parse chain keys to group by source # Format: s5.MinMaxScaler[src=0,sub=0], s5.PCA[src=2,sub=8] source_ops: Dict[int, List[Tuple[int, str]]] = {} # source_id -> [(substep, class_name), ...] for chain_key in by_chain.keys(): # Extract source and class from chain key if '[src=' in chain_key: parts = chain_key.split('[') class_part = parts[0].split('.')[-1] # e.g., "MinMaxScaler" src_part = parts[1] if len(parts) > 1 else '' if 'src=' in src_part: src_str = src_part.split('src=')[1].split(',')[0].split(']')[0] try: src_id = int(src_str) except ValueError: continue sub_str = src_part.split('sub=')[1].split(']')[0] if 'sub=' in src_part else '0' try: sub_id = int(sub_str) except ValueError: sub_id = 0 if src_id not in source_ops: source_ops[src_id] = [] source_ops[src_id].append((sub_id, class_part)) if not source_ops: return # Get shapes from next step if available next_step_shapes = None if step_i + 1 < len(all_steps): next_step = all_steps[step_i + 1] next_step_shapes = getattr(next_step, 'input_features_shape', None) # Create a node for each source branch for src_id in sorted(source_ops.keys()): ops_list = source_ops[src_id] # Sort by substep and get unique class names ops_list.sort(key=lambda x: x[0]) seen_classes = [] for _, cls_name in ops_list: if cls_name not in seen_classes: seen_classes.append(cls_name) # Create label if len(seen_classes) <= 2: label = f"Src{src_id}: {' → '.join(seen_classes)}" else: label = f"Src{src_id}: {seen_classes[0]}...{seen_classes[-1]}" # Get shape for this source src_shape = None if next_step_shapes and src_id < len(next_step_shapes): src_shape = [next_step_shapes[src_id]] elif last_known_shape and src_id < len(last_known_shape): src_shape = [last_known_shape[src_id]] node_id = f"step_{step_index}_src{src_id}" node = PipelineNode( id=node_id, step_index=step_index, label=label, node_type='source_branch', branch_id=src_id, branch_name=f"source_{src_id}", features_shape=src_shape, parent_ids=current_node_ids.copy(), ) self.nodes[node_id] = node # Add edges from parents for parent_id in current_node_ids: if parent_id in self.nodes: self.edges.append((parent_id, node_id)) # Track in branch stacks branch_stacks[(src_id,)] = [node_id] def _expand_step_for_sources( self, step: Any, step_index: int, operator_class: str, operator_type: str, by_source: Dict[int, List[str]], current_node_ids: List[str], input_features: Optional[List[Tuple[int, int, int]]], output_features: Optional[List[Tuple[int, int, int]]] ) -> None: """Expand a step that runs on multiple sources into per-source nodes. Args: step: The execution step step_index: Step index operator_class: Class name operator_type: Operator type by_source: Artifacts by source current_node_ids: Current parent node IDs (should be source branch nodes) input_features: Input feature shapes output_features: Output feature shapes """ # Create a node for each source for src_id in sorted(by_source.keys()): # Get shape for this source src_shape = None if output_features and src_id < len(output_features): src_shape = [output_features[src_id]] elif input_features and src_id < len(input_features): src_shape = [input_features[src_id]] # Find the parent node for this source parent_id = None for pid in current_node_ids: if f"_src{src_id}" in pid: parent_id = pid break if parent_id is None and current_node_ids: # Fallback: use the src_id-th parent if available if src_id < len(current_node_ids): parent_id = current_node_ids[src_id] else: parent_id = current_node_ids[0] node_id = f"step_{step_index}_src{src_id}" label = f"Src{src_id}: {operator_class}" node = PipelineNode( id=node_id, step_index=step_index, label=label, node_type=self._classify_operator_type(operator_type, operator_class), branch_id=src_id, branch_name=f"source_{src_id}", features_shape=src_shape, parent_ids=[parent_id] if parent_id else [], ) self.nodes[node_id] = node # Add edge from parent if parent_id and parent_id in self.nodes: self.edges.append((parent_id, node_id)) def _create_merge_node_for_source_branch( self, step: Any, step_index: int, operator_class: str, current_node_ids: List[str], input_layout: Optional[Tuple[int, int]], output_layout: Optional[Tuple[int, int]], input_features: Optional[List[Tuple[int, int, int]]], output_features: Optional[List[Tuple[int, int, int]]] ) -> None: """Create a merge node that collects from all source branches. Args: step: The execution step step_index: Step index operator_class: Class name current_node_ids: Current parent node IDs (source branch nodes) input_layout: Input layout shape output_layout: Output layout shape input_features: Input feature shapes output_features: Output feature shapes """ node_id = f"step_{step_index}" node = PipelineNode( id=node_id, step_index=step_index, label="Merge Sources", node_type='merge_sources', input_layout_shape=input_layout, output_layout_shape=output_layout, features_shape=output_features if output_features else input_features, parent_ids=current_node_ids.copy(), ) self.nodes[node_id] = node # Add edges from all source branch nodes for parent_id in current_node_ids: if parent_id in self.nodes: self.edges.append((parent_id, node_id)) def _format_trace_label( self, operator_class: str, operator_type: str, step: Any ) -> str: """Format a label for display from trace step info. Creates a readable label by preferring operator_class when available, with fallbacks to operator_type or step index. Args: operator_class: Class name from trace (may be 'dict', 'list', etc.) operator_type: Operator type from trace step: The ExecutionStep object Returns: Human-readable label string """ step_index = getattr(step, 'step_index', 0) op_type_lower = operator_type.lower() if operator_type else '' branch_name = getattr(step, 'branch_name', '') or '' branch_path = getattr(step, 'branch_path', []) or [] substep_idx = getattr(step, 'substep_index', None) # Use branch_path to derive branch context if branch_name is empty if not branch_name and branch_path: # Format as "B0", "B1", etc. for compact display branch_name = f"B{branch_path[0]}" # Generic Python types to avoid using directly generic_types = {'dict', 'list', 'tuple', 'str', 'int', 'config', 'NoneType', ''} # Special handling for merge and branch - include the mode/strategy if op_type_lower == 'merge': if operator_class and operator_class.lower() not in generic_types: # operator_class is something like 'predictions' or 'features' return f"Merge ({operator_class})" return "Merge" if op_type_lower == 'branch': if operator_class and operator_class.lower() not in generic_types: return f"Branch: {operator_class}" return "Branch" if op_type_lower == 'source_branch': return "Source Branch" if op_type_lower == 'merge_sources': if operator_class and operator_class.lower() not in generic_types: return f"Merge Sources ({operator_class})" return "Merge Sources" # If operator_class is meaningful (not a generic Python type), use it if operator_class and operator_class.lower() not in generic_types: # Shorten long operator class names if needed class_label = operator_class if len(class_label) > 25: class_label = class_label[:22] + "..." # For branch substeps, prepend abbreviated branch name if branch_name and substep_idx is not None: # Abbreviate branch name for compact display short_branch = branch_name[:12] + ".." if len(branch_name) > 14 else branch_name return f"[{short_branch}] {class_label}" return class_label # Fallback to formatted operator_type type_labels = { 'preprocessing': 'Preprocessing', 'y_processing': 'Y Processing', 'feature_augmentation': 'Feature Aug', 'sample_augmentation': 'Sample Aug', 'concat_transform': 'Concat', 'model': 'Model', 'meta_model': 'Meta Model', 'splitter': 'Splitter', 'branch': 'Branch', 'merge': 'Merge', 'source_branch': 'Source Branch', 'merge_sources': 'Merge Sources', 'transform': 'Transform', 'operator': 'Operator', 'config': 'Config', } if operator_type: label = type_labels.get(op_type_lower, operator_type.title()) # For branch substeps with generic type, still show branch context if branch_name and substep_idx is not None: short_branch = branch_name[:12] + ".." if len(branch_name) > 14 else branch_name return f"[{short_branch}] {label}" return label return f"Step {step_index}" def _classify_operator_type(self, op_type: str, op_class: str) -> str: """Classify operator into a node type for coloring. Args: op_type: Operator type from trace op_class: Operator class name Returns: Node type string for coloring """ op_type_lower = op_type.lower() op_class_lower = op_class.lower() if 'model' in op_type_lower or 'meta_model' in op_type_lower: return 'model' elif 'splitter' in op_type_lower or 'fold' in op_class_lower or 'split' in op_class_lower: return 'splitter' elif 'branch' in op_type_lower: return 'branch' elif 'merge' in op_type_lower: return 'merge' elif 'y_processing' in op_type_lower: return 'y_processing' elif 'feature_augmentation' in op_type_lower: return 'feature_augmentation' elif 'sample_augmentation' in op_type_lower: return 'sample_augmentation' elif 'concat_transform' in op_type_lower: return 'concat_transform' elif 'source_branch' in op_type_lower: return 'source_branch' elif 'merge_sources' in op_type_lower: return 'merge_sources' else: return 'preprocessing' def _build_dag(self, initial_shape: Optional[Tuple[int, int, int]] = None) -> None: """Build the DAG from pipeline steps. Args: initial_shape: Initial dataset shape """ self.nodes.clear() self.edges.clear() if not self.pipeline_steps: # Try to infer from predictions if self.predictions: self._build_dag_from_predictions() return # Default initial shape current_shape = initial_shape or (100, 1, 1000) # Create input node input_node = PipelineNode( id="input", step_index=0, label="Dataset", node_type="input", shape_after=current_shape, features_shape=[current_shape], ) self.nodes["input"] = input_node # Track current node IDs for edge connections current_node_ids = ["input"] branch_stacks: List[List[str]] = [] # Stack of lists of node IDs per branch level step_index = 0 for step in self.pipeline_steps: step_index += 1 step_info = self._parse_step(step, step_index) if step_info is None: continue node_type = step_info['type'] label = step_info['label'] keyword = step_info.get('keyword', '') # Handle branching if node_type in ('branch', 'source_branch'): # Create branch node branch_node = PipelineNode( id=f"step_{step_index}_branch", step_index=step_index, label=label, node_type=node_type, shape_before=current_shape, shape_after=current_shape, features_shape=[current_shape], parent_ids=current_node_ids.copy(), ) self.nodes[branch_node.id] = branch_node # Add edges from current nodes to branch for parent_id in current_node_ids: self.edges.append((parent_id, branch_node.id)) # Create nodes for each branch branches = step_info.get('branches', {}) branch_node_ids = [] for branch_id, (branch_name, branch_steps) in enumerate(branches.items()): # Create branch entry node entry_id = f"step_{step_index}_b{branch_id}_entry" entry_label = branch_name if isinstance(branch_name, str) else f"Branch {branch_id}" entry_node = PipelineNode( id=entry_id, step_index=step_index, label=entry_label, node_type=node_type, branch_id=branch_id, branch_name=entry_label, shape_before=current_shape, shape_after=current_shape, features_shape=[current_shape], parent_ids=[branch_node.id], ) self.nodes[entry_id] = entry_node self.edges.append((branch_node.id, entry_id)) # Process branch substeps branch_current = [entry_id] branch_shape = current_shape for substep_idx, substep in enumerate(branch_steps): substep_info = self._parse_step(substep, step_index) if substep_info: substep_id = f"step_{step_index}_b{branch_id}_s{substep_idx}" new_substep_shape = self._estimate_shape_after(branch_shape, substep_info) sub_metadata = {} if substep_info.get('n_splits'): sub_metadata['n_splits'] = substep_info['n_splits'] substep_node = PipelineNode( id=substep_id, step_index=step_index, label=substep_info['label'], node_type=substep_info['type'], branch_id=branch_id, branch_name=entry_label, substep_index=substep_idx, shape_before=branch_shape, shape_after=new_substep_shape, features_shape=[new_substep_shape], parent_ids=branch_current.copy(), metadata=sub_metadata, ) self.nodes[substep_id] = substep_node for parent in branch_current: self.edges.append((parent, substep_id)) branch_current = [substep_id] branch_shape = substep_node.shape_after branch_node_ids.extend(branch_current) # Push branch context branch_stacks.append(branch_node_ids) current_node_ids = branch_node_ids elif node_type in ('merge', 'merge_sources'): # Create merge node new_merge_shape = self._estimate_merge_shape(current_shape, step_info) merge_node = PipelineNode( id=f"step_{step_index}_merge", step_index=step_index, label=label, node_type=node_type, shape_before=current_shape, shape_after=new_merge_shape, features_shape=[new_merge_shape], parent_ids=current_node_ids.copy(), ) self.nodes[merge_node.id] = merge_node # Add edges from all branch ends to merge for parent_id in current_node_ids: self.edges.append((parent_id, merge_node.id)) # Pop branch context if branch_stacks: branch_stacks.pop() current_node_ids = [merge_node.id] current_shape = merge_node.shape_after else: # Regular step node_id = f"step_{step_index}" new_shape = self._estimate_shape_after(current_shape, step_info) metadata = {'keyword': keyword} if keyword else {} if step_info.get('n_splits'): metadata['n_splits'] = step_info['n_splits'] node = PipelineNode( id=node_id, step_index=step_index, label=label, node_type=node_type, shape_before=current_shape, shape_after=new_shape, features_shape=[new_shape], parent_ids=current_node_ids.copy(), metadata=metadata, ) self.nodes[node_id] = node # Add edges from current nodes for parent_id in current_node_ids: self.edges.append((parent_id, node_id)) current_node_ids = [node_id] current_shape = new_shape def _parse_step(self, step: Any, step_index: int) -> Optional[Dict[str, Any]]: """Parse a pipeline step into structured info. Args: step: Pipeline step definition step_index: Step index Returns: Dictionary with step info or None if unrecognized """ # Handle None or empty if step is None: return None # Handle string steps (chart commands, etc.) if isinstance(step, str): if 'chart' in step.lower(): return {'type': 'chart', 'label': step} return {'type': 'default', 'label': step} # Handle class (not instance) if isinstance(step, type): class_name = step.__name__ return self._classify_operator(class_name, {}) # Handle instance (has __class__) if hasattr(step, '__class__') and not isinstance(step, dict): class_name = step.__class__.__name__ info = self._classify_operator(class_name, {}) if info['type'] == 'splitter': n_splits = getattr(step, 'n_splits', None) if n_splits: info['n_splits'] = n_splits return info # Handle dict steps if isinstance(step, dict): # Check for known keywords keywords = [ 'preprocessing', 'y_processing', 'feature_augmentation', 'sample_augmentation', 'concat_transform', 'branch', 'merge', 'source_branch', 'merge_sources', 'model', 'split', 'name', 'merge_predictions' ] for keyword in keywords: if keyword in step: return self._parse_keyword_step(keyword, step) # Check if it's a model dict if 'model' in step: return self._parse_keyword_step('model', step) # Generic dict step return {'type': 'default', 'label': str(list(step.keys())[0]) if step else '?'} # Handle list (could be a substep list) if isinstance(step, (list, tuple)): if len(step) == 1: return self._parse_step(step[0], step_index) return {'type': 'chain', 'label': f"[{len(step)} ops]"} return {'type': 'default', 'label': '?'} def _parse_keyword_step(self, keyword: str, step: Dict) -> Dict[str, Any]: """Parse a keyword-based step. Args: keyword: Step keyword step: Step dictionary Returns: Parsed step info """ value = step.get(keyword) if keyword == 'preprocessing': op_name = self._get_operator_name(value) return {'type': 'preprocessing', 'label': op_name, 'keyword': keyword} elif keyword == 'y_processing': op_name = self._get_operator_name(value) return {'type': 'y_processing', 'label': f"y: {op_name}", 'keyword': keyword} elif keyword == 'feature_augmentation': op_count = 1 if isinstance(value, list): op_count = len(value) ops = [self._get_operator_name(v) for v in value[:3]] label = "FA: " + ", ".join(ops) if len(value) > 3: label += f"... (+{len(value)-3})" else: label = f"FA: {self._get_operator_name(value)}" action = step.get('action', 'add') return { 'type': 'feature_augmentation', 'label': label, 'action': action, 'keyword': keyword, 'op_count': op_count } elif keyword == 'sample_augmentation': aug_count = 1 if isinstance(value, dict): transformers = value.get('transformers', []) count = value.get('count', 1) # Handle 'count' being a string '?' or similar try: aug_count = int(count) except (ValueError, TypeError): aug_count = 1 label = f"SA: {len(transformers)} aug ×{count}" else: label = "Sample Aug" return { 'type': 'sample_augmentation', 'label': label, 'keyword': keyword, 'aug_count': aug_count } elif keyword == 'concat_transform': if isinstance(value, list): ops = [self._get_operator_name(v) for v in value] label = "Concat: " + "+".join(ops) elif isinstance(value, dict) and 'operations' in value: ops = [self._get_operator_name(v) for v in value['operations']] label = "Concat: " + "+".join(ops) else: label = "Concat Transform" return {'type': 'concat_transform', 'label': label, 'keyword': keyword} elif keyword == 'branch': branches = {} if isinstance(value, dict): for k, v in value.items(): if not k.startswith('_'): branches[k] = v if isinstance(v, list) else [v] elif isinstance(value, list): for i, v in enumerate(value): branches[f"Branch {i}"] = v if isinstance(v, list) else [v] return {'type': 'branch', 'label': 'Branch', 'branches': branches, 'keyword': keyword} elif keyword == 'merge': merge_type = 'features' if value == 'features' else 'predictions' return {'type': 'merge', 'label': f"Merge ({merge_type})", 'merge_type': merge_type, 'keyword': keyword} elif keyword == 'merge_predictions': return {'type': 'merge', 'label': 'Merge Predictions', 'merge_type': 'predictions', 'keyword': keyword} elif keyword == 'source_branch': branches = {} if isinstance(value, dict): for k, v in value.items(): branches[str(k)] = v if isinstance(v, list) else [v] elif isinstance(value, list): for i, v in enumerate(value): branches[f"Source {i}"] = v if isinstance(v, list) else [v] return {'type': 'source_branch', 'label': 'Source Branch', 'branches': branches, 'keyword': keyword} elif keyword == 'merge_sources': strategy = value if isinstance(value, str) else 'concat' return {'type': 'merge_sources', 'label': f"Merge Sources ({strategy})", 'keyword': keyword} elif keyword == 'model': model_name = step.get('name', self._get_operator_name(value)) return {'type': 'model', 'label': model_name, 'keyword': keyword} elif keyword == 'split': splitter = value splitter_name = self._get_operator_name(splitter) n_splits = getattr(splitter, 'n_splits', None) return {'type': 'splitter', 'label': splitter_name, 'keyword': keyword, 'n_splits': n_splits} elif keyword == 'name': # Named step - look for model if 'model' in step: return {'type': 'model', 'label': step['name'], 'keyword': 'model'} return {'type': 'default', 'label': step['name']} return {'type': 'default', 'label': keyword, 'keyword': keyword} def _classify_operator(self, class_name: str, config: Dict) -> Dict[str, Any]: """Classify an operator by its class name. Args: class_name: Operator class name config: Operator configuration Returns: Step info dictionary """ # Splitters splitter_names = ['KFold', 'StratifiedKFold', 'GroupKFold', 'ShuffleSplit', 'StratifiedShuffleSplit', 'GroupShuffleSplit', 'LeaveOneOut', 'LeaveOneGroupOut', 'TimeSeriesSplit'] if class_name in splitter_names: return {'type': 'splitter', 'label': class_name} # Models model_indicators = ['Regressor', 'Classifier', 'Regression', 'SVC', 'SVR', 'LinearModel', 'Tree', 'Forest', 'Boost', 'Network', 'MLP', 'CNN', 'RNN', 'LSTM', 'Ridge', 'Lasso', 'Elastic', 'PLS', 'PCR', 'KNN', 'Naive', 'Bayes'] for indicator in model_indicators: if indicator in class_name: return {'type': 'model', 'label': class_name} # Scalers if 'Scaler' in class_name or 'Normalizer' in class_name: return {'type': 'preprocessing', 'label': class_name} # NIRS transforms nirs_transforms = ['SNV', 'StandardNormalVariate', 'MSC', 'MultiplicativeScatterCorrection', 'FirstDerivative', 'SecondDerivative', 'SavitzkyGolay', 'Detrend', 'Gaussian', 'SmoothSignal', 'Baseline'] if class_name in nirs_transforms: return {'type': 'preprocessing', 'label': class_name} # Decomposition if class_name in ['PCA', 'TruncatedSVD', 'NMF', 'ICA', 'FactorAnalysis']: return {'type': 'preprocessing', 'label': class_name} # Default return {'type': 'preprocessing', 'label': class_name} def _get_operator_name(self, op: Any) -> str: """Get a human-readable name for an operator. Args: op: Operator instance or class Returns: Operator name string """ if op is None: return "None" if isinstance(op, str): return op if isinstance(op, type): return op.__name__ if hasattr(op, '__class__'): return op.__class__.__name__ return str(op)[:20] def _estimate_shape_after( self, shape_before: Tuple[int, int, int], step_info: Dict[str, Any] ) -> Tuple[int, int, int]: """Estimate the dataset shape after a step. Args: shape_before: Shape before the step (samples, processings, features) step_info: Step information Returns: Estimated shape after the step """ if shape_before is None: return (100, 1, 1000) samples, processings, features = shape_before step_type = step_info.get('type', 'default') if step_type == 'feature_augmentation': # Feature augmentation adds processings action = step_info.get('action', 'add') op_count = step_info.get('op_count', 1) if action == 'extend': # Adds new processings processings += op_count elif action == 'replace': # Replaces processings processings = op_count else: # add # Multiplies processings (each existing processing gets N new versions) processings *= op_count elif step_type == 'sample_augmentation': # Sample augmentation adds samples aug_count = step_info.get('aug_count', 1) # Usually adds N augmented samples per original sample # So total = original + (original * aug_count) samples = samples + (samples * aug_count) elif step_type == 'concat_transform': # Concat reduces features features = 50 # Estimate for PCA/SVD concat elif step_type == 'model': # Model doesn't change shape pass elif step_type == 'splitter': # Splitter creates folds but doesn't change shape pass return (samples, processings, features) def _estimate_merge_shape( self, shape_before: Tuple[int, int, int], step_info: Dict[str, Any] ) -> Tuple[int, int, int]: """Estimate shape after merge. Args: shape_before: Shape before merge step_info: Merge step info Returns: Estimated shape after merge """ samples, processings, features = shape_before merge_type = step_info.get('merge_type', 'predictions') if merge_type == 'features': # Concatenate features from branches features *= 3 # Estimate for 3 branches processings = 1 else: # predictions # Stack predictions as features features = 3 # 1 prediction per branch (estimate 3 branches) processings = 1 return (samples, processings, features) def _build_dag_from_predictions(self) -> None: """Build DAG from predictions object when no pipeline steps provided.""" if not self.predictions: return # Try to get execution info from predictions try: preprocessings = self.predictions.get_unique_values('preprocessings') models = self.predictions.get_unique_values('model_name') branches = self.predictions.get_unique_values('branch_name') except (ValueError, KeyError): return # Create input node input_node = PipelineNode( id="input", step_index=0, label="Dataset", node_type="input", ) self.nodes["input"] = input_node current_id = "input" # Add preprocessing summary if preprocessings: pp_list = [p for p in preprocessings if p] if pp_list: pp_node = PipelineNode( id="preprocessing", step_index=1, label=f"Preprocessing\n({len(pp_list)} views)", node_type="preprocessing", parent_ids=[current_id], ) self.nodes[pp_node.id] = pp_node self.edges.append((current_id, pp_node.id)) current_id = pp_node.id # Add branches if present if branches and len([b for b in branches if b]) > 1: branch_node = PipelineNode( id="branches", step_index=2, label=f"Branches\n({len(branches)})", node_type="branch", parent_ids=[current_id], ) self.nodes[branch_node.id] = branch_node self.edges.append((current_id, branch_node.id)) current_id = branch_node.id # Add models if models: model_list = [m for m in models if m] model_node = PipelineNode( id="models", step_index=3, label=f"Models\n({', '.join(model_list[:3])}{'...' if len(model_list) > 3 else ''})", node_type="model", parent_ids=[current_id], ) self.nodes[model_node.id] = model_node self.edges.append((current_id, model_node.id)) def _estimate_node_width(self, node: PipelineNode) -> float: """Estimate node width based on label length.""" label_lines = [node.label] if self._show_shapes: label_lines.extend(self._format_shape_display(node)) # Add score line estimate if node.metadata.get('best_score') is not None and node.node_type == 'model': label_lines.append("★ 0.00") max_text_len = max([len(l) for l in label_lines]) if label_lines else 0 # Increased factor from 0.16 to 0.22 for better fit calc_width = 2.0 + max_text_len * 0.22 return max(self._node_width, calc_width) def _compute_layout(self) -> Dict[str, Dict[str, Any]]: """Compute node positions using topological sort and layering. Branches maintain their column positions throughout their execution, with nodes stacked vertically in their assigned columns. Returns: Dictionary mapping node IDs to position info """ layout = {} if not self.nodes: return layout # Compute layers using topological sort layers = self._compute_layers() # Calculate dynamic spacing based on max node width max_width = 0 for node in self.nodes.values(): w = self._estimate_node_width(node) if w > max_width: max_width = w # Dynamic spacing with minimum x_spacing = max(2.8, max_width + 0.5) y_spacing = 1.8 # Assign fixed column positions for each branch # Collect all unique branch_ids branch_ids = set() for node in self.nodes.values(): if node.branch_id is not None: branch_ids.add(node.branch_id) # Sort branch IDs and create column mapping sorted_branches = sorted(branch_ids) n_branches = len(sorted_branches) branch_column = {bid: i for i, bid in enumerate(sorted_branches)} for layer_idx, layer_nodes in enumerate(layers): y = -layer_idx * y_spacing # Separate nodes into branched and non-branched branched_nodes = [(nid, self.nodes[nid].branch_id) for nid in layer_nodes if self.nodes[nid].branch_id is not None] unbranched_nodes = [nid for nid in layer_nodes if self.nodes[nid].branch_id is None] # Position branched nodes in their fixed columns if n_branches > 0: branch_x_start = -(n_branches - 1) * x_spacing / 2 for node_id, bid in branched_nodes: col = branch_column[bid] x = branch_x_start + col * x_spacing layout[node_id] = { 'x': x, 'y': y, 'node': self.nodes[node_id], } # Position unbranched nodes centered if unbranched_nodes: n_unbranched = len(unbranched_nodes) x_start = -(n_unbranched - 1) * x_spacing / 2 for i, node_id in enumerate(unbranched_nodes): x = x_start + i * x_spacing layout[node_id] = { 'x': x, 'y': y, 'node': self.nodes[node_id], } return layout def _compute_layers(self) -> List[List[str]]: """Compute node layers using topological ordering. Returns: List of lists, where each inner list contains node IDs for that layer """ # Build adjacency and in-degree in_degree = {node_id: 0 for node_id in self.nodes} adj = defaultdict(list) for from_id, to_id in self.edges: adj[from_id].append(to_id) in_degree[to_id] += 1 # Find roots (nodes with no parents) roots = [node_id for node_id, degree in in_degree.items() if degree == 0] # BFS layering layers = [] current_layer = roots visited = set() while current_layer: layers.append(current_layer) visited.update(current_layer) next_layer = [] for node_id in current_layer: for child_id in adj[node_id]: in_degree[child_id] -= 1 if in_degree[child_id] == 0 and child_id not in visited: next_layer.append(child_id) current_layer = next_layer return layers def _format_shape_display(self, node: PipelineNode) -> List[str]: """Format shape information for display in a node. Shows: - For single source: S×P×F (samples × processings × features) - For multi-source: Source count + total features - Always shows 2D layout shape when available - Model scores when available Args: node: The pipeline node with shape info Returns: List of formatted shape strings """ shape_lines = [] show_features_shape = True # For model nodes, prefer concise 2D shape (stacked features/preds) to avoid misleading 3D shapes if node.node_type == 'model': show_features_shape = False # Format: (samples, [p, f], [p, f]...) if show_features_shape and node.features_shape: parts = [] n_samples = node.features_shape[0][0] parts.append(str(n_samples)) for s, p, f in node.features_shape: parts.append(f"[{p}, {f}]") shape_str = f"({', '.join(parts)})" shape_lines.append(shape_str) # Total 2D: (samples, features) if node.output_layout_shape: s, f = node.output_layout_shape shape_lines.append(f"2D: ({s}, {f})") elif node.input_layout_shape and not shape_lines: # Fallback to input shape if output not available s, f = node.input_layout_shape shape_lines.append(f"2D: ({s}, {f})") elif node.features_shape: # Compute 2D from features_shape as a last resort n_samples = node.features_shape[0][0] total_features = sum(p * f for (_, p, f) in node.features_shape) shape_lines.append(f"2D: ({n_samples}, {total_features})") # Fallback if no trace info but we have estimated shape if not shape_lines and node.shape_after: s, p, f = node.shape_after shape_lines.append(f"({s}, [{p}, {f}])") # Final fallback: show input shape if we have nothing if not shape_lines and node.shape_before: s, p, f = node.shape_before shape_lines.append(f"in: ({s}, [{p}, {f}])") # Score is displayed separately in _draw_nodes for model nodes # Only show score in shape lines for non-model nodes # if node.node_type != 'model' and node.metadata: # best_score = node.metadata.get('best_score') # if best_score is not None: # if isinstance(best_score, float): # shape_lines.append(f"score: {best_score:.4f}") # else: # shape_lines.append(f"score: {best_score}") # Display fold info for splitters if node.node_type == 'splitter': n_splits = node.metadata.get('n_splits') if n_splits: shape_lines.append(f"{n_splits} folds") return shape_lines def _draw_nodes( self, ax: Axes, layout: Dict[str, Dict[str, Any]], show_shapes: bool ) -> None: """Draw nodes on the diagram. Args: ax: Matplotlib axes layout: Node layout show_shapes: Whether to show shape info """ for node_id, pos_info in layout.items(): x, y = pos_info['x'], pos_info['y'] node = pos_info['node'] # Get style fill_color, border_color = self.NODE_STYLES.get(node.node_type, self.NODE_STYLES['default']) # Build label with improved shape display label_lines = [node.label] if show_shapes: shape_lines = self._format_shape_display(node) label_lines.extend(shape_lines) # Add score for model nodes score = node.metadata.get('best_score') if score is not None and node.node_type == 'model': label_lines.append(f"★ {score:.2f}") n_lines = len(label_lines) # Calculate width based on text length max_text_len = max([len(l) for l in label_lines]) if label_lines else 0 # Base width + char width factor calc_width = 2.0 + max_text_len * 0.22 box_width = max(self._node_width, calc_width) # Adjust height for multi-line line_height = 0.35 box_height = self._node_height + (n_lines - 1) * line_height # Store dimensions for edge drawing pos_info['width'] = box_width pos_info['height'] = box_height # Draw shadow (offset) shadow_offset = 0.05 shadow = FancyBboxPatch( (x - box_width / 2 + shadow_offset, y - box_height / 2 - shadow_offset), box_width, box_height, boxstyle="round,pad=0.1,rounding_size=0.2", facecolor='#000000', edgecolor='none', alpha=0.1, zorder=1 ) ax.add_patch(shadow) # Draw node box rect = FancyBboxPatch( (x - box_width / 2, y - box_height / 2), box_width, box_height, boxstyle="round,pad=0.1,rounding_size=0.2", facecolor=fill_color, edgecolor=border_color, linewidth=1.5, alpha=1.0, zorder=2 ) ax.add_patch(rect) # Calculate vertical positions for text lines start_y = y + (n_lines - 1) * line_height / 2 # Draw operator label (first line) ax.text( x, start_y, node.label, ha='center', va='center', fontsize=self._fontsize, color='#263238', # Dark Blue Grey fontweight='bold', zorder=3 ) # Draw additional info (shapes, score) if len(label_lines) > 1: for i, line in enumerate(label_lines[1:], 1): line_y = start_y - i * line_height # Use different style for score line if line.startswith('★'): ax.text( x, line_y, line, ha='center', va='center', fontsize=self._fontsize, color='#D32F2F', # Red fontweight='bold', zorder=3 ) else: ax.text( x, line_y, line, ha='center', va='center', fontsize=self._fontsize - 1.5, color='#546E7A', # Blue Grey fontfamily='monospace', fontweight='normal', zorder=3 ) def _draw_edges( self, ax: Axes, layout: Dict[str, Dict[str, Any]] ) -> None: """Draw edges connecting nodes. Args: ax: Matplotlib axes layout: Node layout """ for from_id, to_id in self.edges: if from_id not in layout or to_id not in layout: continue from_pos = layout[from_id] to_pos = layout[to_id] # Get node dimensions (use defaults if not computed yet) from_height = from_pos.get('height', self._node_height) to_height = to_pos.get('height', self._node_height) # Calculate connection points from_x, from_y = from_pos['x'], from_pos['y'] - from_height / 2 to_x, to_y = to_pos['x'], to_pos['y'] + to_height / 2 # Determine curve based on horizontal offset dx = to_x - from_x rad = 0.0 if abs(dx) > 0.1: # Curve slightly for non-vertical connections rad = 0.1 if dx > 0 else -0.1 # Draw arrow ax.annotate( '', xy=(to_x, to_y), xytext=(from_x, from_y), arrowprops=dict( arrowstyle='-|>', color='#546E7A', # Blue Grey linewidth=1.5, shrinkA=0, shrinkB=0, connectionstyle=f'arc3,rad={rad}', mutation_scale=15, ), zorder=0 ) def _get_bounds( self, layout: Dict[str, Dict[str, Any]] ) -> Tuple[float, float, float, float]: """Get bounding box for the diagram. Args: layout: Node layout Returns: Tuple of (x_min, x_max, y_min, y_max) """ if not layout: return -1, 1, -1, 1 x_coords = [p['x'] for p in layout.values()] y_coords = [p['y'] for p in layout.values()] x_min = min(x_coords) - self._node_width x_max = max(x_coords) + self._node_width y_min = min(y_coords) - self._node_height * 2 y_max = max(y_coords) + self._node_height * 2 return x_min, x_max, y_min, y_max def _add_legend(self, ax: Axes) -> None: """Add a legend showing node type colors. Args: ax: Matplotlib axes """ legend_items = [ ('Input/Output', self.NODE_STYLES['input']), ('Preprocessing', self.NODE_STYLES['preprocessing']), ('Feature Aug', self.NODE_STYLES['feature_augmentation']), ('Sample Aug', self.NODE_STYLES['sample_augmentation']), ('Y Processing', self.NODE_STYLES['y_processing']), ('Splitter', self.NODE_STYLES['splitter']), ('Branch/Merge', self.NODE_STYLES['branch']), ('Model', self.NODE_STYLES['model']), ] patches = [] for label, (fill, border) in legend_items: patch = mpatches.Patch( facecolor=fill, edgecolor=border, label=label, linewidth=1.0 ) patches.append(patch) ax.legend( handles=patches, loc='upper right', fontsize=self._fontsize - 1, framealpha=0.95, edgecolor='#CFD8DC', ncol=2, )
[docs] def plot_pipeline_diagram( pipeline_steps: Optional[List[Any]] = None, predictions: Any = None, show_shapes: bool = True, figsize: Optional[Tuple[float, float]] = None, title: Optional[str] = None, initial_shape: Optional[Tuple[int, int, int]] = None, config: Optional[Dict[str, Any]] = None, execution_trace: Any = None ) -> Figure: """Convenience function to create a pipeline diagram. Args: pipeline_steps: List of pipeline step definitions predictions: Optional Predictions object with execution data show_shapes: Whether to show shape info in nodes figsize: Figure size tuple title: Optional title for the diagram initial_shape: Initial dataset shape (samples, processings, features) config: Additional configuration dict execution_trace: Optional ExecutionTrace object Returns: matplotlib Figure object Example: >>> from nirs4all.visualization.pipeline_diagram import plot_pipeline_diagram >>> fig = plot_pipeline_diagram(pipeline, initial_shape=(189, 1, 2151)) >>> fig.savefig('pipeline_diagram.png') """ cfg = config or {} diagram = PipelineDiagram(pipeline_steps, predictions, execution_trace=execution_trace, config=cfg) if execution_trace: diagram._build_dag_from_trace() return diagram.render( show_shapes=show_shapes, figsize=figsize, title=title, initial_shape=initial_shape, )
# Backward compatibility alias BranchDiagram = PipelineDiagram plot_branch_diagram = plot_pipeline_diagram