Source code for nirs4all.visualization.analysis.branch

"""
Branch Analysis - Statistical analysis and comparison for pipeline branches.

This module provides tools for analyzing and comparing performance
across different pipeline branches.

Features:
- Branch summary statistics (mean, std, min, max)
- Statistical significance testing between branches
- DataFrame and LaTeX export for publications
- Nested branch analysis support

Example:
    >>> from nirs4all.visualization.analysis.branch import BranchAnalyzer
    >>> analyzer = BranchAnalyzer(predictions)
    >>> summary = analyzer.summary(metrics=['rmse', 'r2'])
    >>> print(summary.to_markdown())
"""

from typing import Any, Dict, List, Optional, Tuple, Union
from collections import defaultdict
import numpy as np

try:
    import pandas as pd
    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False

try:
    from scipy import stats
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False


class BranchSummary:
    """Branch summary statistics container with export capabilities.

    Provides DataFrame-like access and export to markdown, LaTeX, and CSV.

    Attributes:
        data: List of dictionaries with branch statistics, one per branch.
        metrics: List of metrics computed.
        columns: Column names in display order.
    """

    # Single-pass translation table. A single pass matters: replacing
    # sequentially would re-escape the braces inserted by
    # ``\textbackslash{}`` and friends.
    _LATEX_ESCAPES = str.maketrans({
        '\\': r'\textbackslash{}',
        '_': r'\_',
        '&': r'\&',
        '%': r'\%',
        '#': r'\#',
        '$': r'\$',
        '{': r'\{',
        '}': r'\}',
        '~': r'\textasciitilde{}',
        '^': r'\textasciicircum{}',
    })

    def __init__(
        self,
        data: List[Dict[str, Any]],
        metrics: List[str]
    ):
        """Initialize BranchSummary.

        Args:
            data: List of dictionaries, one per branch.
            metrics: List of metrics that were computed.
        """
        self.data = data
        self.metrics = metrics
        # Lazily built by to_dataframe() and cached afterwards.
        self._df: Optional['pd.DataFrame'] = None

        # Column order: identification columns first, then the four
        # statistics of each metric, metric by metric.
        base_cols = ['branch_name', 'branch_id', 'count']
        metric_cols = [
            f'{metric}_{stat}'
            for metric in metrics
            for stat in ('mean', 'std', 'min', 'max')
        ]
        self.columns = base_cols + metric_cols

    def to_dataframe(self) -> 'pd.DataFrame':
        """Convert to pandas DataFrame.

        The frame is built once and cached; later calls return the same
        object.

        Returns:
            pandas DataFrame with branch statistics.

        Raises:
            ImportError: If pandas is not installed.
        """
        if not HAS_PANDAS:
            raise ImportError("pandas is required for DataFrame export")

        if self._df is None:
            frame = pd.DataFrame(self.data)
            # Keep only the known columns that are actually present,
            # in the canonical order.
            cols = [c for c in self.columns if c in frame.columns]
            self._df = frame[cols]

        return self._df

    def to_dict(self) -> Dict[str, Dict[str, Any]]:
        """Convert to dictionary keyed by branch name.

        Returns:
            Dictionary mapping branch_name to statistics.
        """
        return {row['branch_name']: row for row in self.data}

    def to_markdown(
        self,
        precision: int = 3,
        include_std: bool = True
    ) -> str:
        """Export as markdown table.

        Args:
            precision: Decimal places for floating point values.
            include_std: If True, include std columns.

        Returns:
            Markdown-formatted table string, or a short message when there
            is no data.
        """
        if not self.data:
            return "No branch data available"

        # Determine columns to include (mean, optionally std, per metric).
        cols = ['branch_name', 'branch_id', 'count']
        for metric in self.metrics:
            cols.append(f'{metric}_mean')
            if include_std:
                cols.append(f'{metric}_std')

        # Only keep columns present in the first row.
        cols = [c for c in cols if c in self.data[0]]

        header = '| ' + ' | '.join(cols) + ' |'
        separator = '|' + '|'.join(['---'] * len(cols)) + '|'

        rows = []
        for row in self.data:
            cells = []
            for col in cols:
                val = row.get(col, '')
                if isinstance(val, (float, np.floating)):
                    cells.append(
                        'N/A' if np.isnan(val) else f'{val:.{precision}f}'
                    )
                else:
                    cells.append(str(val) if val is not None else '')
            rows.append('| ' + ' | '.join(cells) + ' |')

        return '\n'.join([header, separator] + rows)

    def to_latex(
        self,
        caption: str = "Branch Performance Comparison",
        label: str = "tab:branch_comparison",
        precision: int = 3,
        include_std: bool = True,
        mean_std_combined: bool = True
    ) -> str:
        """Export as LaTeX table for publications.

        The table uses ``\\toprule`` / ``\\midrule`` / ``\\bottomrule``.
        ``caption`` and ``label`` are inserted verbatim (not escaped).

        Args:
            caption: Table caption.
            label: LaTeX label for referencing.
            precision: Decimal places for floating point values.
            include_std: If True, include std columns.
            mean_std_combined: If True (and ``include_std`` is True),
                format each metric as a single "mean ± std" cell.

        Returns:
            LaTeX-formatted table string, or a LaTeX comment when there is
            no data.
        """
        if not self.data:
            return "% No branch data available"

        combined = mean_std_combined and include_std

        # Header cells are escaped: metric names commonly contain '_',
        # which is illegal in LaTeX text mode.
        display_cols = ['Branch', 'ID', 'N']
        for metric in self.metrics:
            if combined:
                display_cols.append(self._escape_latex(metric.upper()))
            else:
                display_cols.append(self._escape_latex(f'{metric}_mean'))
                if include_std:
                    display_cols.append(self._escape_latex(f'{metric}_std'))

        n_cols = len(display_cols)
        col_spec = 'l' + 'c' * (n_cols - 1)

        lines = [
            r'\begin{table}[htbp]',
            r'\centering',
            f'\\caption{{{caption}}}',
            f'\\label{{{label}}}',
            f'\\begin{{tabular}}{{{col_spec}}}',
            r'\toprule',
        ]

        lines.append(' & '.join(display_cols) + r' \\')
        lines.append(r'\midrule')

        for row in self.data:
            values = [
                self._escape_latex(str(row.get('branch_name', ''))),
                str(row.get('branch_id', '')),
                str(row.get('count', '')),
            ]

            for metric in self.metrics:
                mean_val = row.get(f'{metric}_mean')
                std_val = row.get(f'{metric}_std')
                mean_missing = self._is_missing(mean_val)
                std_missing = self._is_missing(std_val)

                if combined:
                    if mean_missing:
                        values.append('--')
                    elif std_missing:
                        values.append(f'{mean_val:.{precision}f}')
                    else:
                        values.append(
                            f'${mean_val:.{precision}f} \\pm '
                            f'{std_val:.{precision}f}$'
                        )
                else:
                    # Always emit one cell per header column so that a
                    # missing mean cannot leave the row one cell short.
                    values.append(
                        '--' if mean_missing else f'{mean_val:.{precision}f}'
                    )
                    if include_std:
                        values.append(
                            '--' if std_missing
                            else f'{std_val:.{precision}f}'
                        )

            lines.append(' & '.join(values) + r' \\')

        lines.extend([
            r'\bottomrule',
            r'\end{tabular}',
            r'\end{table}',
        ])

        return '\n'.join(lines)

    def to_csv(
        self,
        path: str,
        precision: int = 6
    ) -> None:
        """Export to CSV file.

        Args:
            path: Output file path.
            precision: Decimal places for floating point values.

        Raises:
            ImportError: If pandas is not installed (raised by
                to_dataframe()).
        """
        df = self.to_dataframe()
        df.to_csv(path, index=False, float_format=f'%.{precision}f')

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True when a statistic is None or NaN."""
        return value is None or bool(np.isnan(value))

    @classmethod
    def _escape_latex(cls, text: str) -> str:
        """Escape special LaTeX characters.

        Handles ``\\ _ & % # $ { } ~ ^`` in a single pass.

        Args:
            text: Input string.

        Returns:
            LaTeX-safe string.
        """
        return text.translate(cls._LATEX_ESCAPES)

    def __repr__(self) -> str:
        """String representation (the markdown table)."""
        return self.to_markdown()

    def __len__(self) -> int:
        """Number of branches."""
        return len(self.data)

    def __getitem__(self, key: Union[int, str]) -> Dict[str, Any]:
        """Get branch by index or name.

        Args:
            key: Integer index or branch name string.

        Returns:
            Dictionary with branch statistics.

        Raises:
            IndexError: If an integer index is out of range.
            KeyError: If no branch has the given name.
        """
        if isinstance(key, int):
            return self.data[key]

        for row in self.data:
            if row.get('branch_name') == key:
                return row
        raise KeyError(f"Branch '{key}' not found")
class BranchAnalyzer:
    """Analyze and compare performance across pipeline branches.

    Provides statistical analysis, hypothesis testing, and comparison tools
    for branched pipeline results.

    Attributes:
        predictions: Predictions object containing prediction data. This
            class uses its ``num_predictions``, ``filter_predictions``,
            ``top`` and ``get_unique_values`` members.
    """

    def __init__(self, predictions):
        """Initialize BranchAnalyzer.

        Args:
            predictions: Predictions object with branch metadata.
        """
        self.predictions = predictions

    def summary(
        self,
        metrics: Optional[List[str]] = None,
        partition: str = 'test',
        aggregate: Optional[str] = None
    ) -> 'BranchSummary':
        """Generate summary statistics for each branch.

        Computes mean, std, min, max for each metric across branches. The
        std is NumPy's default (population, ``ddof=0``).

        Args:
            metrics: List of metrics to compute (default: ['rmse', 'r2']).
            partition: Partition to compute metrics from (default: 'test').
            aggregate: If provided, aggregate predictions by this column
                before computing statistics.

        Returns:
            BranchSummary object with statistics. Statistics are NaN for a
            metric with no usable scores in a branch.
        """
        if metrics is None:
            metrics = ['rmse', 'r2']

        all_preds = self._get_predictions(aggregate=aggregate)
        if not all_preds:
            return BranchSummary([], metrics)

        # Group predictions by branch name.
        branch_groups: Dict[Any, List[Dict[str, Any]]] = defaultdict(list)
        for pred in all_preds:
            branch_groups[pred.get('branch_name', 'no_branch')].append(pred)

        # Sort by name; the key tolerates a None name mixed with strings,
        # which a plain sort would reject with TypeError.
        ordered = sorted(
            branch_groups.items(),
            key=lambda item: (item[0] is None, str(item[0]))
        )

        summary_data = []
        for branch_name, preds in ordered:
            row: Dict[str, Any] = {
                'branch_name': branch_name,
                'branch_id': preds[0].get('branch_id'),
                'count': len(preds),
            }

            for metric in metrics:
                scores = self._collect_scores(preds, metric, partition)
                if scores:
                    row[f'{metric}_mean'] = float(np.mean(scores))
                    row[f'{metric}_std'] = float(np.std(scores))
                    row[f'{metric}_min'] = float(np.min(scores))
                    row[f'{metric}_max'] = float(np.max(scores))
                else:
                    for stat in ('mean', 'std', 'min', 'max'):
                        row[f'{metric}_{stat}'] = np.nan

            summary_data.append(row)

        return BranchSummary(summary_data, metrics)

    def compare(
        self,
        branch1: Union[str, int],
        branch2: Union[str, int],
        metric: str = 'rmse',
        partition: str = 'test',
        test: str = 'ttest'
    ) -> Dict[str, Any]:
        """Statistical comparison between two branches.

        Performs hypothesis testing to determine if there's a significant
        difference between two branches.

        Args:
            branch1: First branch name or ID.
            branch2: Second branch name or ID.
            metric: Metric to compare (default: 'rmse').
            partition: Partition for scores (default: 'test').
            test: Statistical test ('ttest', 'wilcoxon', 'mannwhitney').

        Returns:
            Dictionary with:
                - statistic: Test statistic
                - p_value: P-value
                - significant: Boolean at alpha=0.05
                - branch1_mean / branch2_mean: Mean of each branch
                - branch1_std / branch2_std: Std (ddof=0) of each branch
                - effect_size: Cohen's d effect size
                - test: Name of the test used
                - n1 / n2: Number of scores in each branch

        Raises:
            ImportError: If scipy is not available.
            ValueError: If branches not found, insufficient data, or the
                test name is unknown.
        """
        if not HAS_SCIPY:
            raise ImportError("scipy is required for statistical testing")

        preds1 = self._get_branch_predictions(branch1)
        preds2 = self._get_branch_predictions(branch2)

        if not preds1 or not preds2:
            raise ValueError("One or both branches have no predictions")

        scores1 = self._collect_scores(preds1, metric, partition)
        scores2 = self._collect_scores(preds2, metric, partition)

        if len(scores1) < 2 or len(scores2) < 2:
            raise ValueError("Insufficient data for statistical testing")

        if test == 'ttest':
            result = stats.ttest_ind(scores1, scores2)
        elif test == 'wilcoxon':
            # Wilcoxon requires paired samples: truncate to common length.
            min_len = min(len(scores1), len(scores2))
            result = stats.wilcoxon(scores1[:min_len], scores2[:min_len])
        elif test == 'mannwhitney':
            result = stats.mannwhitneyu(
                scores1, scores2, alternative='two-sided'
            )
        else:
            raise ValueError(f"Unknown test: {test}")

        statistic = result.statistic  # type: ignore[union-attr]
        p_value = result.pvalue  # type: ignore[union-attr]

        return {
            'statistic': float(statistic),
            'p_value': float(p_value),
            'significant': float(p_value) < 0.05,
            'branch1_mean': float(np.mean(scores1)),
            'branch2_mean': float(np.mean(scores2)),
            'branch1_std': float(np.std(scores1)),
            'branch2_std': float(np.std(scores2)),
            'effect_size': self._cohens_d(scores1, scores2),
            'test': test,
            'n1': len(scores1),
            'n2': len(scores2),
        }

    @staticmethod
    def _cohens_d(scores1: List[float], scores2: List[float]) -> float:
        """Cohen's d between two samples using the pooled sample std.

        Each sample must contain at least two values. Sample variances
        (``ddof=1``) are used so that the ``(n - 1)`` weights recover the
        sums of squares. Returns 0.0 when the pooled std is not positive.
        """
        n1, n2 = len(scores1), len(scores2)
        ss1 = (n1 - 1) * np.var(scores1, ddof=1)
        ss2 = (n2 - 1) * np.var(scores2, ddof=1)
        pooled_std = np.sqrt((ss1 + ss2) / (n1 + n2 - 2))
        if not pooled_std > 0:
            return 0.0
        return float((np.mean(scores1) - np.mean(scores2)) / pooled_std)

    def rank_branches(
        self,
        metric: str = 'rmse',
        partition: str = 'test',
        ascending: Optional[bool] = None
    ) -> List[Dict[str, Any]]:
        """Rank branches by mean performance.

        Branches whose mean is missing (None or NaN) are ranked last.

        Args:
            metric: Metric to rank by (default: 'rmse').
            partition: Partition for scores (default: 'test').
            ascending: Sort order. If None, auto-detect based on metric:
                descending for the known higher-is-better metrics,
                ascending otherwise.

        Returns:
            List of summary rows (branch_name, branch_id, count and the
            ``{metric}_mean/std/min/max`` statistics) with an added
            1-based ``rank`` key, best first.
        """
        summary = self.summary(metrics=[metric], partition=partition)

        if ascending is None:
            higher_better_metrics = {
                'accuracy', 'balanced_accuracy', 'r2', 'f1',
                'precision', 'recall', 'specificity', 'roc_auc'
            }
            ascending = metric.lower() not in higher_better_metrics

        return self._rank_rows(summary.data, metric, ascending)

    @staticmethod
    def _rank_rows(
        rows: List[Dict[str, Any]],
        metric: str,
        ascending: bool
    ) -> List[Dict[str, Any]]:
        """Order rows by ``{metric}_mean`` and stamp a 1-based ``rank``.

        Rows with a missing mean are kept out of the sort (NaN has no
        consistent ordering) and appended at the end.
        """
        mean_key = f'{metric}_mean'

        def has_mean(row: Dict[str, Any]) -> bool:
            value = row.get(mean_key)
            return value is not None and not np.isnan(value)

        scored = [row for row in rows if has_mean(row)]
        unscored = [row for row in rows if not has_mean(row)]
        scored.sort(key=lambda row: row[mean_key], reverse=not ascending)

        ranked = scored + unscored
        for position, row in enumerate(ranked, start=1):
            row['rank'] = position
        return ranked

    def pairwise_comparison(
        self,
        metric: str = 'rmse',
        partition: str = 'test',
        test: str = 'ttest'
    ) -> 'pd.DataFrame':
        """Compute pairwise statistical comparisons between all branches.

        Args:
            metric: Metric to compare (default: 'rmse').
            partition: Partition for scores (default: 'test').
            test: Statistical test to use.

        Returns:
            Symmetric DataFrame of p-values indexed by branch name. The
            diagonal, and any pair whose comparison raised ValueError
            (e.g. insufficient data), is left at 1.0.

        Raises:
            ImportError: If pandas or scipy not available.
        """
        if not HAS_PANDAS:
            raise ImportError("pandas is required for pairwise comparison")

        branches = self.get_branch_names()
        n = len(branches)

        p_values = np.ones((n, n))

        # Only the upper triangle is computed; results are mirrored.
        for i in range(n):
            for j in range(i + 1, n):
                try:
                    result = self.compare(
                        branches[i], branches[j], metric, partition, test
                    )
                except ValueError:
                    continue
                p_values[i, j] = result['p_value']
                p_values[j, i] = result['p_value']

        return pd.DataFrame(p_values, index=branches, columns=branches)

    def get_branch_names(self) -> List[str]:
        """Get list of unique branch names.

        Returns:
            List of branch names (None entries removed); empty if the
            lookup raises ValueError or KeyError.
        """
        try:
            names = self.predictions.get_unique_values('branch_name')
            return [n for n in names if n is not None]
        except (ValueError, KeyError):
            return []

    def get_branch_ids(self) -> List[int]:
        """Get list of unique branch IDs.

        Returns:
            Sorted list of branch IDs (None entries removed); empty if the
            lookup or integer conversion raises ValueError or KeyError.
        """
        try:
            ids = self.predictions.get_unique_values('branch_id')
            return sorted(int(x) for x in ids if x is not None)
        except (ValueError, KeyError):
            return []

    def _get_predictions(
        self,
        aggregate: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Get all predictions, optionally aggregated.

        Args:
            aggregate: If provided, aggregate by this column.

        Returns:
            List of prediction dictionaries.
        """
        n = self.predictions.num_predictions
        if n == 0:
            return []

        if not aggregate:
            return self.predictions.filter_predictions()

        # top() is used only for its aggregation; asking for n results
        # with an arbitrary rank metric returns every prediction.
        return list(self.predictions.top(
            n=n,
            rank_metric='rmse',
            aggregate=aggregate,
            aggregate_partitions=True
        ))

    def _get_branch_predictions(
        self,
        branch: Union[str, int]
    ) -> List[Dict[str, Any]]:
        """Get predictions for a specific branch.

        Args:
            branch: Branch name (str) or ID (int, including NumPy
                integers, which are converted to a plain int).

        Returns:
            List of prediction dictionaries.
        """
        if isinstance(branch, np.integer):
            branch = int(branch)
        if isinstance(branch, int):
            return self.predictions.filter_predictions(branch_id=branch)
        return self.predictions.filter_predictions(branch_name=branch)

    @staticmethod
    def _stored_score(
        pred: Dict[str, Any],
        metric: str,
        partition: str
    ) -> Any:
        """Look up an already-computed score in a prediction dict.

        Tries, in order: ``pred['partitions'][partition][metric]``,
        ``pred['{partition}_score']`` (only when ``pred['metric']`` matches
        ``metric`` case-insensitively), and
        ``pred['scores'][partition][metric]``. Returns None if none apply.
        """
        score = None

        # Try 1: partitions dict (from top method)
        partitions = pred.get('partitions', {})
        if partitions and isinstance(partitions, dict):
            partition_data = partitions.get(partition, {})
            if isinstance(partition_data, dict):
                score = partition_data.get(metric)

        # Try 2: {partition}_score field (from filter_predictions)
        if score is None:
            score_field = pred.get(f'{partition}_score')
            if score_field is not None:
                # The stored metric may be None; treat that as no match
                # instead of failing on .lower().
                pred_metric = str(pred.get('metric') or '').lower()
                if pred_metric == metric.lower():
                    score = score_field

        # Try 3: scores dict
        if score is None:
            scores_dict = pred.get('scores', {})
            if isinstance(scores_dict, dict):
                partition_scores = scores_dict.get(partition, {})
                if isinstance(partition_scores, dict):
                    score = partition_scores.get(metric)

        return score

    def _collect_scores(
        self,
        predictions: List[Dict[str, Any]],
        metric: str,
        partition: str
    ) -> List[float]:
        """Collect scores from predictions.

        Stored scores are preferred (see ``_stored_score``). When none is
        found and the prediction carries ``y_true``/``y_pred`` for the
        requested partition, the score is computed with the project's
        metrics evaluator; predictions for which that computation fails
        are skipped.

        Args:
            predictions: List of prediction dictionaries.
            metric: Metric to extract.
            partition: Partition to get scores from.

        Returns:
            List of finite-or-infinite float scores; non-numeric and NaN
            values are dropped.
        """
        scores = []

        for pred in predictions:
            score = self._stored_score(pred, metric, partition)

            # Try 4: compute from y_true/y_pred
            if score is None:
                y_true = pred.get('y_true')
                y_pred = pred.get('y_pred')
                has_data = y_true is not None and y_pred is not None
                # Only use the raw data if its partition matches.
                if has_data and pred.get('partition', '') == partition:
                    # Imported here so the evaluator is only loaded when
                    # a score actually has to be computed.
                    from nirs4all.core import metrics as evaluator
                    try:
                        score = evaluator.eval(y_true, y_pred, metric)
                    except Exception:
                        # Best effort: skip predictions that cannot be
                        # evaluated for this metric.
                        continue

            if isinstance(score, (int, float, np.floating, np.integer)):
                score_float = float(score)
                if not np.isnan(score_float):
                    scores.append(score_float)

        return scores