nirs4all.visualization.analysis package
Submodules
- nirs4all.visualization.analysis.branch module
- nirs4all.visualization.analysis.shap module
  - ShapAnalyzer
    - ShapAnalyzer.explain_model()
    - ShapAnalyzer.get_feature_importance()
    - ShapAnalyzer.load_results()
    - ShapAnalyzer.plot_beeswarm()
    - ShapAnalyzer.plot_beeswarm_binned()
    - ShapAnalyzer.plot_dependence()
    - ShapAnalyzer.plot_force()
    - ShapAnalyzer.plot_spectral_importance()
    - ShapAnalyzer.plot_summary()
    - ShapAnalyzer.plot_waterfall()
    - ShapAnalyzer.plot_waterfall_binned()
    - ShapAnalyzer.save_results()
- nirs4all.visualization.analysis.transfer module
  - PreprocPCAEvaluator
    - PreprocPCAEvaluator.fit()
    - PreprocPCAEvaluator.get_cross_dataset_summary()
    - PreprocPCAEvaluator.get_quality_metric_convergence()
    - PreprocPCAEvaluator.plot_all()
    - PreprocPCAEvaluator.plot_all_datasets_pca()
    - PreprocPCAEvaluator.plot_cross_dataset_distances()
    - PreprocPCAEvaluator.plot_cross_dataset_heatmap()
    - PreprocPCAEvaluator.plot_distance_matrices()
    - PreprocPCAEvaluator.plot_distance_reduction_ranking()
    - PreprocPCAEvaluator.plot_pair()
    - PreprocPCAEvaluator.plot_preservation_summary()
    - PreprocPCAEvaluator.plot_quality_metric_convergence()
Module contents
Analysis utilities for visualization.
- class nirs4all.visualization.analysis.BranchAnalyzer(predictions)[source]
Bases: object
Analyze and compare performance across pipeline branches.
Provides statistical analysis, hypothesis testing, and comparison tools for branched pipeline results.
- predictions
Predictions object containing prediction data.
- compare(branch1: str | int, branch2: str | int, metric: str = 'rmse', partition: str = 'test', test: str = 'ttest') Dict[str, Any][source]
Statistical comparison between two branches.
Performs hypothesis testing to determine if there’s a significant difference between two branches.
- Parameters:
branch1 – First branch name or ID.
branch2 – Second branch name or ID.
metric – Metric to compare (default: ‘rmse’).
partition – Partition for scores (default: ‘test’).
test – Statistical test (‘ttest’, ‘wilcoxon’, ‘mannwhitney’).
- Returns:
Dictionary with:
statistic: Test statistic
p_value: P-value
significant: Boolean at alpha=0.05
branch1_mean: Mean of branch1
branch2_mean: Mean of branch2
effect_size: Cohen’s d effect size
- Return type:
Dict[str, Any]
- Raises:
ImportError – If scipy is not available.
ValueError – If a branch is not found or there is insufficient data.
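The effect_size entry is described as Cohen’s d. As a minimal sketch of that quantity, assuming the common pooled-standard-deviation definition (the page does not say which variant compare() uses), it can be computed from two lists of per-fold scores. The function name and the score values below are hypothetical.

```python
import math
import statistics


def cohens_d(a, b):
    """Cohen's d using a pooled sample standard deviation (assumed variant)."""
    n1, n2 = len(a), len(b)
    if n1 < 2 or n2 < 2:
        raise ValueError("need at least two scores per branch")
    var1 = statistics.variance(a)  # sample variance, n - 1 denominator
    var2 = statistics.variance(b)
    pooled = math.sqrt(((n1 - 1) * var1 + (n2 - 1) * var2) / (n1 + n2 - 2))
    if pooled == 0:
        return float("nan")  # undefined when neither branch has any spread
    return (statistics.mean(a) - statistics.mean(b)) / pooled


# Hypothetical per-fold RMSE scores for two branches.
branch1_scores = [0.50, 0.52, 0.48, 0.51, 0.49]
branch2_scores = [0.60, 0.62, 0.58, 0.61, 0.59]
d = cohens_d(branch1_scores, branch2_scores)
```

A negative d here means branch1 has the lower mean RMSE. The sign convention (branch1 minus branch2) is also an assumption.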
- get_branch_names() List[str][source]
Get list of unique branch names.
- Returns:
List of branch names.
- pairwise_comparison(metric: str = 'rmse', partition: str = 'test', test: str = 'ttest') DataFrame[source]
Compute pairwise statistical comparisons between all branches.
- Parameters:
metric – Metric to compare (default: ‘rmse’).
partition – Partition for scores (default: ‘test’).
test – Statistical test to use.
- Returns:
DataFrame with p-values for all branch pairs.
- Raises:
ImportError – If pandas or scipy is not available.
- rank_branches(metric: str = 'rmse', partition: str = 'test', ascending: bool | None = None) List[Dict[str, Any]][source]
Rank branches by mean performance.
- Parameters:
metric – Metric to rank by (default: ‘rmse’).
partition – Partition for scores (default: ‘test’).
ascending – Sort order. If None, auto-detect based on metric.
- Returns:
List of dicts with branch_name, mean, std, rank.
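The ascending=None behaviour can be pictured as follows. This is a sketch only: the set of "lower is better" metrics and the input layout (a dict of score lists per branch) are assumptions for illustration, not the library's internals.

```python
import statistics

# Assumption: these are error metrics, so a lower mean ranks first.
# Any other metric is treated as a score where higher is better.
LOWER_IS_BETTER = {"rmse", "mse", "mae"}


def rank_branches(scores_by_branch, metric="rmse", ascending=None):
    """Rank branches by mean score and return dicts like the documented ones."""
    if ascending is None:
        ascending = metric.lower() in LOWER_IS_BETTER
    rows = [
        {
            "branch_name": name,
            "mean": statistics.mean(vals),
            "std": statistics.stdev(vals) if len(vals) > 1 else 0.0,
        }
        for name, vals in scores_by_branch.items()
    ]
    rows.sort(key=lambda r: r["mean"], reverse=not ascending)
    for rank, row in enumerate(rows, start=1):
        row["rank"] = rank
    return rows


# Hypothetical scores per branch.
ranking = rank_branches(
    {"snv": [0.52, 0.50], "msc": [0.47, 0.45], "raw": [0.70, 0.74]}
)
r2_ranking = rank_branches({"a": [0.80, 0.90], "b": [0.95, 0.97]}, metric="r2")
```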
- summary(metrics: List[str] | None = None, partition: str = 'test', aggregate: str | None = None) BranchSummary[source]
Generate summary statistics for each branch.
Computes mean, std, min, max for each metric across branches.
- Parameters:
metrics – List of metrics to compute (default: [‘rmse’, ‘r2’]).
partition – Partition to compute metrics from (default: ‘test’).
aggregate – If provided, aggregate predictions by this column before computing statistics.
- Returns:
BranchSummary object with statistics.
- class nirs4all.visualization.analysis.BranchSummary(data: List[Dict[str, Any]], metrics: List[str])[source]
Bases: object
Branch summary statistics container with export capabilities.
Provides DataFrame-like access and export to markdown, LaTeX, and CSV.
- data
List of dictionaries with branch statistics.
- metrics
List of metrics computed.
- columns
Column names in order.
- __getitem__(key: int | str) Dict[str, Any][source]
Get branch by index or name.
- Parameters:
key – Integer index or branch name string.
- Returns:
Dictionary with branch statistics.
- to_csv(path: str, precision: int = 6) None[source]
Export to CSV file.
- Parameters:
path – Output file path.
precision – Decimal places for floating point values.
- to_dataframe() DataFrame[source]
Convert to pandas DataFrame.
- Returns:
pandas DataFrame with branch statistics.
- Raises:
ImportError – If pandas is not installed.
- to_dict() Dict[str, Dict[str, Any]][source]
Convert to dictionary keyed by branch name.
- Returns:
Dictionary mapping branch_name to statistics.
- to_latex(caption: str = 'Branch Performance Comparison', label: str = 'tab:branch_comparison', precision: int = 3, include_std: bool = True, mean_std_combined: bool = True) str[source]
Export as LaTeX table for publications.
- Parameters:
caption – Table caption.
label – LaTeX label for referencing.
precision – Decimal places for floating point values.
include_std – If True, include std columns.
mean_std_combined – If True, format as “mean ± std”.
- Returns:
LaTeX-formatted table string.
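To illustrate what precision and mean_std_combined control, here is a small sketch that formats table rows only. The row keys (`rmse_mean`, `rmse_std`) and the helper name are hypothetical, and the real method also emits the surrounding table environment, caption, and label.

```python
def latex_rows(data, metrics, precision=3, mean_std_combined=True):
    """Format one LaTeX table row per branch (rows only, no table wrapper)."""
    lines = []
    for row in data:
        cells = [row["branch_name"]]
        for m in metrics:
            mean, std = row[f"{m}_mean"], row[f"{m}_std"]
            if mean_std_combined:
                # Single cell in math mode: mean ± std
                cells.append(f"${mean:.{precision}f} \\pm {std:.{precision}f}$")
            else:
                # Separate mean and std cells
                cells.extend([f"{mean:.{precision}f}", f"{std:.{precision}f}"])
        lines.append(" & ".join(cells) + r" \\")
    return lines


# Hypothetical summary row.
summary_data = [{"branch_name": "snv", "rmse_mean": 0.5123, "rmse_std": 0.0211}]
combined = latex_rows(summary_data, ["rmse"])
separate = latex_rows(summary_data, ["rmse"], mean_std_combined=False)
```

Branch names containing characters such as underscores would need escaping before going into a real LaTeX table. The sketch does not handle that.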
- class nirs4all.visualization.analysis.PreprocPCAEvaluator(r_components=10, knn=10)[source]
Bases: object
- fit(raw_data: dict[str, ndarray], pp_data: dict[str, dict[str, ndarray]])[source]
- Parameters:
raw_data – {“dataset”: X_raw_(n,m), …}
pp_data – Either {“pp_name”: {“dataset”: X_pp_(n,p), …}, …} or {“dataset”: {“pp_name”: X_pp_(n,p), …}, …}. The layout is detected automatically and pivoted if needed.
Assumes rows (samples) are aligned within each dataset across raw and preprocessed data.
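The two accepted pp_data layouts differ only in which key comes first. One way such detection and pivoting could work is sketched below. The heuristic (outer keys matching the dataset names in raw_data) is an assumption, not necessarily what fit() does, and plain lists stand in for arrays.

```python
def pivot_if_needed(raw_data, pp_data):
    """Return pp_data in the {pp_name: {dataset: X}} layout.

    Assumed heuristic: if every outer key of pp_data is a dataset name
    found in raw_data, the mapping is dataset-first and gets pivoted.
    """
    if pp_data and all(key in raw_data for key in pp_data):
        pivoted = {}
        for dataset, by_pp in pp_data.items():
            for pp_name, X in by_pp.items():
                pivoted.setdefault(pp_name, {})[dataset] = X
        return pivoted
    return pp_data


# Hypothetical dataset and preprocessing names.
raw = {"lab_a": [[1.0, 2.0]], "lab_b": [[3.0, 4.0]]}
dataset_first = {
    "lab_a": {"snv": [[0.1, 0.2]]},
    "lab_b": {"snv": [[0.3, 0.4]]},
}
pp_first = pivot_if_needed(raw, dataset_first)
```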
- get_cross_dataset_summary(metric='centroid_improvement')[source]
Get a summary of how preprocessing affects inter-dataset distances.
- Parameters:
metric – ‘centroid_improvement’ or ‘spread_improvement’. Higher values mean the preprocessing brought the datasets closer together.
- Returns:
DataFrame sorted by improvement (best preprocessing first)
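As an illustration of what a centroid-based improvement score might measure, the sketch below uses the relative reduction in Euclidean distance between two dataset centroids. This definition is an assumption: the page does not give the formula, and the evaluator presumably works in PCA space rather than on raw rows.

```python
import math


def centroid(rows):
    """Column-wise mean of a list of equal-length rows."""
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


def centroid_distance(X_a, X_b):
    return math.dist(centroid(X_a), centroid(X_b))


def centroid_improvement(raw_a, raw_b, pp_a, pp_b):
    """Relative reduction of the centroid distance (assumed definition)."""
    d_raw = centroid_distance(raw_a, raw_b)
    if d_raw == 0:
        raise ValueError("raw centroids coincide; improvement is undefined")
    d_pp = centroid_distance(pp_a, pp_b)
    return (d_raw - d_pp) / d_raw


# Hypothetical two-feature datasets before and after preprocessing.
raw_a, raw_b = [[0.0, 0.0], [2.0, 0.0]], [[4.0, 0.0], [6.0, 0.0]]
pp_a, pp_b = [[0.0, 0.0], [2.0, 0.0]], [[1.0, 0.0], [3.0, 0.0]]
improvement = centroid_improvement(raw_a, raw_b, pp_a, pp_b)
```

Under this definition a value of 0.75 means the centroids ended up at a quarter of their original distance, and a negative value means the preprocessing pushed the datasets apart.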
- get_quality_metric_convergence()[source]
Analyze how preprocessing affects the similarity of quality metrics across datasets. Lower variance = preprocessing makes datasets more homogeneous in quality.
- Returns:
DataFrame with variance of quality metrics (evr, cka, rv, etc.) across datasets for raw vs preprocessed data. Lower values = better convergence.
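The convergence idea, variance of one quality metric across datasets before and after preprocessing, can be sketched in a few lines. The metric values are hypothetical, and the use of population variance is an assumption.

```python
import statistics


def metric_variance(values_by_dataset):
    """Population variance of one quality metric across datasets."""
    return statistics.pvariance(list(values_by_dataset.values()))


# Hypothetical explained-variance-ratio (evr) values per dataset.
evr_raw = {"lab_a": 0.70, "lab_b": 0.90, "lab_c": 0.80}
evr_pp = {"lab_a": 0.84, "lab_b": 0.86, "lab_c": 0.85}

var_raw = metric_variance(evr_raw)
var_pp = metric_variance(evr_pp)
converged = var_pp < var_raw  # lower variance = more homogeneous datasets
```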
- plot_all_datasets_pca(figsize=(16, 12))[source]
Plot all datasets together in the same PCA space for raw and each preprocessing. Shows how datasets cluster and separate in different preprocessing spaces.
- plot_cross_dataset_distances(figsize=(14, 8))[source]
Plot how preprocessing affects inter-dataset distances. Shows which preprocessing methods bring datasets closer together.
- plot_cross_dataset_heatmap(metric='centroid_improvement', figsize=(12, 10))[source]
Create a heatmap showing pairwise dataset distances for each preprocessing.
- Parameters:
metric – ‘centroid_improvement’, ‘centroid_dist_pp’, ‘spread_improvement’, or ‘spread_dist_pp’
- plot_distance_matrices(metric='centroid', figsize=(18, 12))[source]
Plot distance matrices showing inter-dataset distances for raw and all preprocessings. Shows which preprocessing reduces distances (better for transfer learning).
- Parameters:
metric – ‘centroid’ or ‘spread’ - which distance metric to display
- plot_distance_reduction_ranking(metric='centroid', log_scale=False, figsize=(14, 8))[source]
Bar chart showing which preprocessing methods best reduce inter-dataset distances. Directly answers: “Which preprocessing is best for transfer learning?”
- Parameters:
metric – ‘centroid’ or ‘spread’ - which distance method to use for ranking
log_scale – If True, use log scale for the right plot (absolute distances) to handle extreme values
- plot_pair(dataset: str, preproc: str, figsize=(10, 5))[source]
Enhanced comparison plot for a specific dataset-preprocessing pair.
- plot_preservation_summary(by='preproc', figsize=(14, 8))[source]
Enhanced summary plot with better styling.
- plot_quality_metric_convergence(figsize=(16, 10))[source]
Visualize how preprocessing makes quality metrics more homogeneous across datasets. Shows variance reduction in EVR, CKA, RV, Procrustes, Trustworthiness, Grassmann.
Lower variance after preprocessing = datasets behave more similarly = better for transfer learning.
- class nirs4all.visualization.analysis.ShapAnalyzer[source]
Bases: object
SHAP-based model explainability analyzer for NIRS models.
Provides explanations showing which wavelengths/features are most important for model predictions, with specialized visualizations for spectral data.
- explain_model(model: Any, X: ndarray, y: ndarray | None = None, feature_names: List[str] | None = None, sample_indices: List[int] | None = None, task_type: str = 'regression', n_background: int = 100, explainer_type: str = 'auto', output_dir: str | None = None, visualizations: List[str] | None = None, bin_size=20, bin_stride=10, bin_aggregation='sum', plots_visible=True) Dict[str, Any][source]
Explain model predictions using SHAP values.
- Parameters:
model – Trained model to explain
X – Input features (samples x features)
y – Target values (optional, for reference)
feature_names – Names of features/wavelengths
sample_indices – Specific samples to explain (None = all)
task_type – ‘regression’ or ‘classification’
n_background – Number of background samples for KernelExplainer
explainer_type – ‘auto’, ‘tree’, ‘deep’, ‘kernel’, ‘linear’
output_dir – Directory to save visualizations
visualizations – List of viz types to generate
bin_size – Number of wavelengths per bin. Can be an int (same for all visualizations) or a dict, e.g. {‘spectral’: 20, ‘waterfall’: 30, ‘beeswarm’: 50}.
bin_stride – Step size between bins. Can be an int (same for all visualizations) or a dict, e.g. {‘spectral’: 10, ‘waterfall’: 15, ‘beeswarm’: 25}.
bin_aggregation – Aggregation method. Can be a str applied to all visualizations (‘sum’, ‘sum_abs’, ‘mean’, ‘mean_abs’) or a dict, e.g. {‘spectral’: ‘sum’, ‘waterfall’: ‘mean’, ‘beeswarm’: ‘sum_abs’}.
- Returns:
Dictionary with SHAP results
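The interplay of bin_size, bin_stride, and bin_aggregation can be sketched as a sliding window over one sample's SHAP values. This is an illustration under assumptions, not the analyzer's code: the helper names are hypothetical, and dropping a trailing partial window is a choice made here.

```python
def resolve(setting, viz):
    """Pick the per-visualization value when the setting is given as a dict."""
    return setting[viz] if isinstance(setting, dict) else setting


def bin_shap_values(shap_row, bin_size=20, bin_stride=10, aggregation="sum"):
    """Aggregate SHAP values over sliding windows of bin_size features."""
    aggregators = {
        "sum": lambda w: sum(w),
        "sum_abs": lambda w: sum(abs(v) for v in w),
        "mean": lambda w: sum(w) / len(w),
        "mean_abs": lambda w: sum(abs(v) for v in w) / len(w),
    }
    if aggregation not in aggregators:
        raise ValueError(f"unknown aggregation: {aggregation!r}")
    agg = aggregators[aggregation]
    n = len(shap_row)
    # Window starts; a trailing partial window is dropped (assumption).
    starts = range(0, max(n - bin_size, 0) + 1, bin_stride)
    return [agg(shap_row[s:s + bin_size]) for s in starts]


# Hypothetical SHAP values for one sample with six wavelengths.
shap_row = [1.0, -1.0, 2.0, -2.0, 3.0, -3.0]
binned_sum = bin_shap_values(shap_row, bin_size=4, bin_stride=2)
binned_abs = bin_shap_values(shap_row, bin_size=4, bin_stride=2,
                             aggregation="sum_abs")
waterfall_size = resolve({"spectral": 20, "waterfall": 30}, "waterfall")
```

With a stride smaller than the bin size the windows overlap, so neighbouring bins share wavelengths. The example also shows why ‘sum’ and ‘sum_abs’ can disagree sharply: contributions of opposite sign cancel under ‘sum’.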
- get_feature_importance(top_n: int | None = None) Dict[str, float][source]
Get feature importance ranking based on mean absolute SHAP values.
- Parameters:
top_n – Return only top N features (None = all)
- Returns:
Dictionary mapping feature index to importance score
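Ranking by mean absolute SHAP value amounts to averaging |SHAP| per feature over the samples and sorting. The sketch below uses plain lists for a samples x features matrix. The function name, the feature labels, and the values are hypothetical.

```python
def feature_importance(shap_values, feature_names=None, top_n=None):
    """Mean |SHAP| per feature, sorted from most to least important."""
    n_samples = len(shap_values)
    # zip(*rows) iterates over columns, i.e. one feature at a time.
    mean_abs = [sum(abs(v) for v in col) / n_samples for col in zip(*shap_values)]
    names = feature_names or [str(i) for i in range(len(mean_abs))]
    ranked = sorted(zip(names, mean_abs), key=lambda kv: kv[1], reverse=True)
    return dict(ranked[:top_n])  # top_n=None keeps every feature


# Hypothetical SHAP matrix: 2 samples x 3 wavelengths.
shap_values = [[0.1, -0.4, 0.0], [-0.3, 0.2, 0.1]]
importance = feature_importance(
    shap_values, feature_names=["1100", "1102", "1104"], top_n=2
)
```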
- static load_results(input_path: str) Dict[str, Any][source]
Load SHAP results from disk using the new serializer.
- plot_beeswarm(feature_names: List[str] | None = None, output_path: str | None = None, max_display: int = 20, plots_visible: bool = True)[source]
Create SHAP beeswarm plot.
- plot_beeswarm_binned(output_path: str | None = None, max_display: int = 20, plots_visible: bool = True)[source]
Create SHAP beeswarm plot with binned features.
Bins wavelengths/features according to bin_size and bin_stride parameters, then displays beeswarm plot for aggregated SHAP values.
- plot_dependence(feature_idx: int, feature_names: List[str] | None = None, output_path: str | None = None, interaction_index: int | None = None, plots_visible: bool = True)[source]
Create SHAP dependence plot for a specific feature.
- plot_force(sample_idx: int = 0, feature_names: List[str] | None = None, output_path: str | None = None, plots_visible: bool = True)[source]
Create SHAP force plot for a single sample.
- plot_spectral_importance(feature_names: List[str] | None = None, output_path: str | None = None, figsize: Tuple[int, int] = (16, 10), plots_visible: bool = True)[source]
Create NIRS-specific spectral importance visualization with binned regions.
Shows important spectral regions (not individual wavelengths) by binning wavelengths and aggregating SHAP values. This is more robust and meaningful for NIRS analysis than point-by-point importance.
Uses self.bin_size, self.bin_stride, and self.bin_aggregation configured in explain_model().
- plot_summary(feature_names: List[str] | None = None, output_path: str | None = None, max_display: int = 20, plots_visible: bool = True)[source]
Create SHAP summary plot showing feature importance.
- plot_waterfall(sample_idx: int = 0, feature_names: List[str] | None = None, output_path: str | None = None, max_display: int = 20, plots_visible: bool = True)[source]
Create SHAP waterfall plot for a single sample.
- plot_waterfall_binned(sample_idx: int = 0, output_path: str | None = None, max_display: int = 20, plots_visible: bool = True)[source]
Create SHAP waterfall plot with binned features for a single sample.
Bins wavelengths/features according to bin_size and bin_stride parameters, then displays waterfall plot for aggregated SHAP values.