Source code for nirs4all.visualization.reports

"""
Tab Report Manager - Simplified tab report generation with formatting and saving

This module provides a clean interface for generating standardized tab-based CSV reports
using pre-calculated metrics and statistics from the evaluator module.
"""

from typing import Dict, Any, Optional, Tuple, Union
import numpy as np
import csv
import os
import io

# Import evaluator functions
import nirs4all.core.metrics as evaluator
from nirs4all.core.task_type import TaskType
from nirs4all.core.task_detection import detect_task_type
from nirs4all.core.logging import get_logger

# Module-level logger, obtained from the project's logging helper and keyed by this module's name.
logger = get_logger(__name__)


class TabReportManager:
    """Generate standardized tab-based CSV reports with pre-calculated data.

    Every method is static; the class only acts as a namespace. The public
    entry point is ``generate_best_score_tab_report``; the remaining methods
    are private helpers that compute per-partition metrics and render them
    either as a pipe-delimited text table or as CSV content.
    """

    # Row order of the report: cross-validation first, then train, then test.
    _PARTITION_ORDER = ('val', 'train', 'test')

    # Descriptive-statistics columns of the regression report. They are left
    # blank on aggregated rows.
    _REGRESSION_STAT_KEYS = ('mean', 'median', 'min', 'max', 'sd', 'cv')

    # Regression metric columns rendered with three decimals, in column order.
    _REGRESSION_METRIC_KEYS = ('r2', 'rmse', 'mse', 'sep', 'mae')

    # Classification metric columns, in column order (AUC is appended separately
    # for binary reports).
    _CLASSIFICATION_KEYS = (
        'accuracy', 'balanced_accuracy', 'precision', 'balanced_precision',
        'recall', 'balanced_recall', 'f1', 'specificity',
    )

    @staticmethod
    def generate_best_score_tab_report(
        best_by_partition: Dict[str, Dict[str, Any]],
        aggregate: Optional[Union[str, bool]] = None,
        aggregate_method: Optional[str] = None,
        aggregate_exclude_outliers: bool = False
    ) -> Tuple[str, Optional[str]]:
        """
        Generate best score tab report from partition data.

        Args:
            best_by_partition: Dict mapping partition names ('train', 'val', 'test')
                to prediction entries. Entries that are None and partitions with
                any other name are ignored.
            aggregate: Sample aggregation setting for computing additional aggregated metrics.
                - None (default): No aggregation, only raw scores displayed
                - True: Aggregate by y_true values (group by target)
                - str: Aggregate by specified metadata column (e.g., 'sample_id', 'ID')
                When set, both raw and aggregated scores are included in the output.
                Aggregated rows are marked with an asterisk (*).
            aggregate_method: Aggregation method for combining predictions.
                - None (default): Use 'mean' for regression, 'vote' for classification
                - 'mean': Average predictions within each group
                - 'median': Median prediction within each group
                - 'vote': Majority voting (for classification)
            aggregate_exclude_outliers: If True, exclude outliers using T² statistic
                before aggregation (default: False).

        Returns:
            Tuple of (formatted_string, csv_string_content). When no usable
            prediction entry exists, the tuple is
            ("No prediction data available", None).
            If aggregate is set, both raw and aggregated scores are included.
        """
        if not best_by_partition:
            return "No prediction data available", None

        # The task type is read from the first non-None prediction entry.
        first_entry = next((v for v in best_by_partition.values() if v is not None), None)
        if first_entry is None:
            return "No prediction data available", None

        task_type = TabReportManager._get_task_type_from_entry(first_entry)

        # A stored None is treated like a missing value so that the later
        # ``n_features > 0`` comparisons cannot raise.
        n_features = first_entry.get('n_features', 0) or 0

        # Normalize aggregate parameter: True -> 'y', str -> str, None/False -> None
        effective_aggregate: Optional[str] = None
        if aggregate is True:
            effective_aggregate = 'y'
        elif isinstance(aggregate, str):
            effective_aggregate = aggregate

        partitions_data: Dict[str, Dict[str, Any]] = {}
        aggregated_partitions_data: Dict[str, Dict[str, Any]] = {}

        for partition_name, entry in best_by_partition.items():
            if entry is None or partition_name not in TabReportManager._PARTITION_ORDER:
                continue

            y_true = np.array(entry['y_true'])
            y_pred = np.array(entry['y_pred'])

            # Raw (non-aggregated) metrics
            partitions_data[partition_name] = TabReportManager._calculate_partition_data(
                y_true, y_pred, task_type
            )

            if not effective_aggregate:
                continue

            # Aggregated metrics, only recorded when the aggregation succeeded.
            agg_result = TabReportManager._aggregate_predictions(
                y_true=y_true,
                y_pred=y_pred,
                aggregate=effective_aggregate,
                metadata=entry.get('metadata') or {},
                partition_name=partition_name,
                method=aggregate_method,
                exclude_outliers=aggregate_exclude_outliers
            )
            if agg_result is not None:
                agg_y_true, agg_y_pred = agg_result
                aggregated_partitions_data[partition_name] = TabReportManager._calculate_partition_data(
                    agg_y_true, agg_y_pred, task_type
                )

        # Generate formatted string (matching PredictionHelpers format)
        formatted_string = TabReportManager._format_as_table_string(
            partitions_data, n_features, task_type,
            aggregated_partitions_data=aggregated_partitions_data,
            aggregate_column=effective_aggregate
        )

        # Generate CSV string content
        csv_string = TabReportManager._format_as_csv_string(
            partitions_data, n_features, task_type,
            aggregated_partitions_data=aggregated_partitions_data,
            aggregate_column=effective_aggregate
        )

        return formatted_string, csv_string

    @staticmethod
    def _aggregate_predictions(
        y_true: np.ndarray,
        y_pred: np.ndarray,
        aggregate: str,
        metadata: Dict[str, Any],
        partition_name: str = "",
        method: Optional[str] = None,
        exclude_outliers: bool = False
    ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
        """
        Aggregate predictions by a group column.

        This is a best-effort step: every failure is logged at debug level and
        reported to the caller as None instead of being raised.

        Args:
            y_true: True values array
            y_pred: Predicted values array
            aggregate: Group column name or 'y' to group by y_true
            metadata: Metadata dictionary containing group column
            partition_name: Partition name for error messages
            method: Aggregation method ('mean', 'median', 'vote'), forwarded
                unchanged to ``Predictions.aggregate``.
            exclude_outliers: If True, exclude outliers using T² statistic before aggregation.

        Returns:
            Tuple of (aggregated_y_true, aggregated_y_pred) or None if aggregation fails
        """
        # Imported at call time rather than at module level
        # (NOTE(review): presumably to avoid a circular import - confirm).
        from nirs4all.data.predictions import Predictions

        # Determine group IDs
        if aggregate == 'y':
            group_ids = y_true
        else:
            if aggregate not in metadata:
                logger.debug(
                    f"Aggregation column '{aggregate}' not found in metadata for partition '{partition_name}'. "
                    f"Available columns: {list(metadata.keys())}. Skipping aggregation."
                )
                return None
            group_ids = np.asarray(metadata[aggregate])
            if len(group_ids) != len(y_pred):
                logger.debug(
                    f"Aggregation column '{aggregate}' length ({len(group_ids)}) doesn't match "
                    f"predictions length ({len(y_pred)}) for partition '{partition_name}'. Skipping aggregation."
                )
                return None

        try:
            result = Predictions.aggregate(
                y_pred=y_pred,
                group_ids=group_ids,
                y_true=y_true,
                method=method,
                exclude_outliers=exclude_outliers
            )
        except Exception as e:
            # Deliberately broad: a failed aggregation must not break the raw report.
            logger.debug(f"Aggregation failed for partition '{partition_name}': {e}")
            return None

        agg_y_true = result.get('y_true')
        agg_y_pred = result.get('y_pred')
        if agg_y_true is None or agg_y_pred is None:
            return None
        return agg_y_true, agg_y_pred

    @staticmethod
    def _get_task_type_from_entry(entry: Dict[str, Any]) -> TaskType:
        """
        Get task type from a prediction entry's metadata.

        Prioritizes stored task_type from metadata, falls back to detection only
        if metadata is missing or cannot be converted (for backward compatibility
        with old predictions).

        Args:
            entry: Prediction entry dictionary

        Returns:
            TaskType: The task type for this prediction. An entry without a
            usable task_type and without y_true values yields
            TaskType.REGRESSION.
        """
        stored = entry.get('task_type')
        if stored:
            if isinstance(stored, TaskType):
                return stored
            if isinstance(stored, str):
                try:
                    return TaskType(stored)
                except (ValueError, KeyError):
                    pass  # Unknown label: fall through to detection

        # Fallback: detect from y_true (for backward compatibility)
        logger.warning("task_type not found in prediction metadata, detecting from data")
        y_true = np.array(entry.get('y_true', []))
        if len(y_true) == 0:
            return TaskType.REGRESSION
        return detect_task_type(y_true)

    @staticmethod
    def _is_binary_report(partitions_data: Dict[str, Dict[str, Any]]) -> bool:
        """Return True when the 'val' or 'test' partition carries a 'roc_auc' metric.

        Only these two partitions are inspected, so a report holding nothing but
        a 'train' partition never gets an AUC column.
        """
        return (
            'roc_auc' in partitions_data.get('val', {})
            or 'roc_auc' in partitions_data.get('test', {})
        )

    @staticmethod
    def _display_name(partition_name: str) -> str:
        """Return the row label used for a partition."""
        # NOTE(review): "Cros Val" looks like a typo for "Cross Val". It is kept
        # byte-for-byte because consumers of the report may match on this label.
        return "Cros Val" if partition_name == 'val' else partition_name.capitalize()

    @staticmethod
    def _format_metric(
        data: Dict[str, Any],
        key: str,
        precision: int = 3,
        blank_if_inf: bool = False
    ) -> str:
        """Format ``data[key]`` as a fixed-point string, or '' when it is missing.

        A value counts as missing only when it is None or an empty string, so a
        legitimate score of 0 is rendered (e.g. '0.000') instead of being blanked.

        Args:
            data: Metrics dictionary for one partition
            key: Metric name to look up
            precision: Number of decimals
            blank_if_inf: If True, positive infinity is rendered as ''

        Returns:
            The formatted value or an empty string
        """
        value = data.get(key)
        if value is None or (isinstance(value, str) and not value):
            return ''
        if blank_if_inf and value == float('inf'):
            return ''
        return f"{value:.{precision}f}"

    @staticmethod
    def _metric_cells(
        data: Dict[str, Any],
        task_type: TaskType,
        is_binary: bool,
        is_aggregated: bool
    ) -> list:
        """Build the formatted metric cells that follow the three leading columns.

        The same cells are used by the text table and by the CSV output.

        Args:
            data: Metrics dictionary for this partition
            task_type: Task type (regression or classification)
            is_binary: Whether an AUC cell is appended (classification only)
            is_aggregated: Whether the regression descriptive statistics are blanked

        Returns:
            List of formatted cell strings
        """
        fmt = TabReportManager._format_metric

        if task_type == TaskType.REGRESSION:
            stat_keys = TabReportManager._REGRESSION_STAT_KEYS
            if is_aggregated:
                # Descriptive stats (Mean, Median, Min, Max, SD, CV) don't make
                # sense after averaging predictions, so they stay blank.
                cells = [''] * len(stat_keys)
            else:
                cells = [fmt(data, key) for key in stat_keys]
            cells.extend(fmt(data, key) for key in TabReportManager._REGRESSION_METRIC_KEYS)
            cells.append(fmt(data, 'rpd', precision=2, blank_if_inf=True))
            cells.append(fmt(data, 'bias'))
            cells.append(fmt(data, 'consistency', precision=1))
            return cells

        # Classification
        cells = [fmt(data, key) for key in TabReportManager._CLASSIFICATION_KEYS]
        if is_binary:
            cells.append(fmt(data, 'roc_auc'))
        return cells

    @staticmethod
    def _format_as_table_string(
        partitions_data: Dict[str, Dict[str, Any]],
        n_features: int,
        task_type: TaskType,
        aggregated_partitions_data: Optional[Dict[str, Dict[str, Any]]] = None,
        aggregate_column: Optional[str] = None
    ) -> str:
        """Format the report data as a table string (matching PredictionHelpers format).

        Args:
            partitions_data: Dict of partition name to metrics data
            n_features: Number of features
            task_type: Task type (regression or classification)
            aggregated_partitions_data: Optional dict of partition name to aggregated metrics
            aggregate_column: Name of column used for aggregation (for footer note)

        Returns:
            Formatted table string, or "No partition data available" when
            ``partitions_data`` is empty
        """
        if not partitions_data:
            return "No partition data available"

        is_binary = TabReportManager._is_binary_report(partitions_data)

        # Prepare headers based on task type
        if task_type == TaskType.REGRESSION:
            headers = ['', 'Nsample', 'Nfeature', 'Mean', 'Median', 'Min', 'Max', 'SD', 'CV',
                       'R2', 'RMSE', 'MSE', 'SEP', 'MAE', 'RPD', 'Bias', 'Consistency']
        else:  # Classification
            headers = ['', 'Nsample', 'Nfeatures', 'Accuracy', 'Bal. Acc', 'Precision', 'Bal. Prec',
                       'Recall', 'Bal. Rec', 'F1-score', 'Specificity']
            if is_binary:
                headers.append('AUC')

        aggregated = aggregated_partitions_data or {}
        has_aggregated = len(aggregated) > 0

        # One raw row per partition, directly followed by its aggregated row (if any).
        rows = []
        for partition_name in TabReportManager._PARTITION_ORDER:
            if partition_name not in partitions_data:
                continue

            display_name = TabReportManager._display_name(partition_name)
            rows.append(TabReportManager._build_table_row(
                display_name, partitions_data[partition_name], n_features, task_type, is_binary
            ))

            if partition_name in aggregated:
                rows.append(TabReportManager._build_table_row(
                    f"{display_name}*", aggregated[partition_name], n_features, task_type,
                    is_binary, is_aggregated=True
                ))

        # Column widths: widest cell of the column, with a minimum of 6 characters.
        all_rows = [headers] + rows
        col_widths = [
            max(max(len(str(table_row[col_idx])) for table_row in all_rows), 6)
            for col_idx in range(len(headers))
        ]

        def render(cells: list) -> str:
            """Render one row as left-aligned, pipe-delimited cells."""
            return '|' + '|'.join(
                f" {str(cell):<{width}} " for cell, width in zip(cells, col_widths)
            ) + '|'

        separator = '|' + '|'.join('-' * (width + 2) for width in col_widths) + '|'

        lines = [separator, render(headers), separator]
        lines.extend(render(table_row) for table_row in rows)
        lines.append(separator)

        # Add footer note if aggregated
        if has_aggregated and aggregate_column:
            agg_label = "y (target values)" if aggregate_column == 'y' else aggregate_column
            lines.append(f"* Aggregated by {agg_label}")

        return '\n'.join(lines)

    @staticmethod
    def _build_table_row(
        display_name: str,
        data: Dict[str, Any],
        n_features: int,
        task_type: TaskType,
        is_binary: bool = False,
        is_aggregated: bool = False
    ) -> list:
        """Build a single table row for either raw or aggregated data.

        Args:
            display_name: Row label (e.g., 'Train', 'Test', 'Train*')
            data: Metrics dictionary for this partition
            n_features: Number of features (rendered blank unless positive)
            task_type: Task type (regression or classification)
            is_binary: Whether this is binary classification
            is_aggregated: Whether this is an aggregated row (stats columns blank)

        Returns:
            List of formatted cell values (all strings)
        """
        leading = [
            display_name,
            str(data.get('nsample', '')),
            str(n_features) if n_features and n_features > 0 else '',
        ]
        return leading + TabReportManager._metric_cells(data, task_type, is_binary, is_aggregated)

    @staticmethod
    def _format_as_csv_string(
        partitions_data: Dict[str, Dict[str, Any]],
        n_features: int,
        task_type: TaskType,
        aggregated_partitions_data: Optional[Dict[str, Dict[str, Any]]] = None,
        aggregate_column: Optional[str] = None
    ) -> str:
        """Generate CSV string content.

        Args:
            partitions_data: Dict of partition name to metrics data
            n_features: Number of features
            task_type: Task type (regression or classification)
            aggregated_partitions_data: Optional dict of partition name to aggregated metrics
            aggregate_column: Name of column used for aggregation (for data annotation)

        Returns:
            CSV formatted string (header row first, written with ``csv.writer`` defaults)
        """
        is_binary = TabReportManager._is_binary_report(partitions_data)

        # Prepare headers based on task type
        if task_type == TaskType.REGRESSION:
            headers = ['', 'Nsample', 'Nfeature', 'Mean', 'Median', 'Min', 'Max', 'SD', 'CV',
                       'R2', 'RMSE', 'MSE', 'SEP', 'MAE', 'RPD', 'Bias', 'Consistency (%)', 'Aggregated']
        else:  # Classification
            headers = ['', 'Nsample', 'Nfeatures', 'Accuracy', 'Bal. Acc', 'Precision', 'Bal. Prec',
                       'Recall', 'Bal. Rec', 'F1-score', 'Specificity']
            if is_binary:
                headers.append('AUC')
            headers.append('Aggregated')

        aggregated = aggregated_partitions_data or {}

        rows = [headers]
        for partition_name in TabReportManager._PARTITION_ORDER:
            if partition_name not in partitions_data:
                continue

            display_name = TabReportManager._display_name(partition_name)
            rows.append(TabReportManager._build_csv_row(
                display_name, partitions_data[partition_name], n_features, task_type,
                is_binary, is_aggregated=False
            ))

            if partition_name in aggregated:
                rows.append(TabReportManager._build_csv_row(
                    f"{display_name}*", aggregated[partition_name], n_features, task_type,
                    is_binary, is_aggregated=True, aggregate_column=aggregate_column
                ))

        # Generate CSV content
        with io.StringIO() as output:
            csv.writer(output).writerows(rows)
            return output.getvalue()

    @staticmethod
    def _build_csv_row(
        display_name: str,
        data: Dict[str, Any],
        n_features: int,
        task_type: TaskType,
        is_binary: bool = False,
        is_aggregated: bool = False,
        aggregate_column: Optional[str] = None
    ) -> list:
        """Build a single CSV row for either raw or aggregated data.

        Args:
            display_name: Row label (e.g., 'Train', 'Test', 'Train*')
            data: Metrics dictionary for this partition
            n_features: Number of features (rendered blank unless positive)
            task_type: Task type (regression or classification)
            is_binary: Whether this is binary classification
            is_aggregated: Whether this is an aggregated row
            aggregate_column: Name of column used for aggregation

        Returns:
            List of cell values for CSV row. The last cell holds the aggregation
            column name on aggregated rows and is empty otherwise.
        """
        aggregated_label = aggregate_column if is_aggregated and aggregate_column else ''
        leading = [
            display_name,
            data.get('nsample', ''),
            n_features if n_features and n_features > 0 else '',
        ]
        metric_cells = TabReportManager._metric_cells(data, task_type, is_binary, is_aggregated)
        return leading + metric_cells + [aggregated_label]

    @staticmethod
    def _calculate_partition_data(
        y_true: np.ndarray,
        y_pred: np.ndarray,
        task_type: Union[str, TaskType]
    ) -> Dict[str, Any]:
        """Calculate metrics and statistics for a single partition.

        Args:
            y_true: True values array
            y_pred: Predicted values array
            task_type: Task type, either as a TaskType member or as its string
                label ('regression', 'binary_classification', anything else is
                treated as multiclass classification)

        Returns:
            Dict combining the descriptive statistics of ``y_true`` with the
            evaluated metrics. For regression a 'consistency' entry (percentage
            of residuals within one 'sd' of the statistics) is added.
        """
        # Accept both an enum member and a plain string label.
        task_name = str(getattr(task_type, 'value', task_type)).lower()

        # Get descriptive statistics for y_true
        stats = evaluator.get_stats(y_true)

        # Get metrics based on task type
        if task_name == 'regression':
            metric_names = ['mse', 'rmse', 'mae', 'r2', 'bias', 'sep', 'rpd']
        else:
            metric_names = list(TabReportManager._CLASSIFICATION_KEYS)
            if task_name == 'binary_classification':
                metric_names.append('roc_auc')

        metrics_list = evaluator.eval_list(y_true, y_pred, metric_names)

        # Combine stats and metrics into a single dict
        partition_data: Dict[str, Any] = {}
        if stats:
            partition_data.update(stats)

        # Metrics are only merged when one value was returned per requested name.
        if metrics_list and len(metrics_list) == len(metric_names):
            partition_data.update(zip(metric_names, metrics_list))

        # Add additional regression-specific calculations
        if task_name == 'regression':
            true_arr = np.asarray(y_true)
            pred_arr = np.asarray(y_pred)
            if true_arr.shape != pred_arr.shape and true_arr.size == pred_arr.size:
                # Align e.g. (n,) with (n, 1); otherwise the subtraction would
                # broadcast to an (n, n) matrix and inflate the percentage.
                true_arr, pred_arr = true_arr.ravel(), pred_arr.ravel()

            # Calculate consistency (percentage within 1 SD)
            residuals = pred_arr - true_arr
            acceptable_range = stats.get('sd', 1.0) if stats else 1.0
            within_range = np.abs(residuals) <= acceptable_range
            partition_data['consistency'] = (
                float(np.sum(within_range) / len(residuals) * 100) if len(residuals) > 0 else 0.0
            )

        return partition_data