Explainability Examples
This section covers model interpretation and explainability using SHAP (SHapley Additive exPlanations) for NIRS data.
Overview
| Example | Topic | Difficulty | Duration |
|---|---|---|---|
| U01 | SHAP Basics | ★★☆☆☆ | ~5 min |
| U02 | SHAP with sklearn | ★★☆☆☆ | ~4 min |
| U03 | Feature Selection | ★★★☆☆ | ~4 min |
U01: SHAP Basics
Understand model predictions using SHAP analysis.
What You’ll Learn
Why SHAP for model interpretation
Running SHAP analysis with runner.explain()
Spectral, waterfall, and beeswarm visualizations
Binning and aggregation for spectral data
Why SHAP?
SHAP provides model-agnostic explanations:
| What SHAP Tells You | Description |
|---|---|
| Which wavelengths matter | Identify important spectral regions |
| Direction of influence | Positive or negative contribution |
| Magnitude of effect | Quantify importance |
| Per-sample explanations | Understand individual predictions |
SHAP Key Concepts
| Concept | Meaning |
|---|---|
| SHAP value | Contribution of each feature to the prediction |
| Positive value | Increases the prediction |
| Negative value | Decreases the prediction |
| Base value | Average model output |
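These concepts can be checked on a toy model. The sketch below computes exact Shapley values by enumerating every feature coalition, with "missing" features replaced by a single background point. That replacement is a simplifying assumption (real explainers average over background data), and `toy_model` and all numbers are hypothetical. It illustrates that the base value plus the sum of the SHAP values equals the prediction.

```python
from itertools import combinations
from math import factorial

def toy_model(x):
    # Hypothetical model: two linear terms plus one interaction.
    return 2.0 * x[0] - 1.0 * x[1] + 0.5 * x[0] * x[2]

def shapley_values(model, x, background):
    """Exact Shapley values; absent features take their background value."""
    n = len(x)

    def value(coalition):
        # Features in the coalition use the sample's value, others the background.
        mixed = [x[i] if i in coalition else background[i] for i in range(n)]
        return model(mixed)

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            # Shapley weight for coalitions of this size.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                total += weight * (value(s | {i}) - value(s))
        phi.append(total)
    return phi

background = [1.0, 1.0, 1.0]   # stand-in for the "average" input
sample = [3.0, 2.0, 5.0]       # the sample being explained

base_value = toy_model(background)   # model output at the background point
phi = shapley_values(toy_model, sample, background)
prediction = toy_model(sample)

print("base value:", base_value)
print("SHAP values:", [round(p, 6) for p in phi])
print("base + sum(SHAP):", round(base_value + sum(phi), 6), "prediction:", prediction)
```

Enumeration costs grow exponentially with the number of features, which is why practical explainers rely on model-specific shortcuts or sampling.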
Running SHAP Analysis
from nirs4all.pipeline import PipelineRunner, PipelineConfigs
from nirs4all.data import DatasetConfigs
# pipeline_config and dataset_config are assumed to have been built beforehand
# Train a model
runner = PipelineRunner(save_artifacts=True, verbose=0)
predictions, _ = runner.run(pipeline_config, dataset_config)
# Get best model (lowest RMSE)
best = predictions.top(n=1, rank_metric='rmse')[0]
# Run SHAP analysis
shap_params = {
    'n_samples': 100,
    'explainer_type': 'auto',
    'visualizations': ['spectral', 'waterfall']
}
shap_results, output_dir = runner.explain(
    best,
    dataset_config,
    shap_params=shap_params,
    plots_visible=True
)
SHAP Visualizations
Spectral Plot
Shows SHAP values across the spectrum:
shap_params = {
    'visualizations': ['spectral'],
    'bin_size': {'spectral': 10},  # Group 10 wavelengths per bin
}
X-axis: Wavelength/feature index
Y-axis: SHAP value (importance)
Identifies: Key spectral regions for prediction
Waterfall Plot
Explains a single prediction step-by-step:
shap_params = {
    'visualizations': ['waterfall'],
    'bin_size': {'waterfall': 20},  # Coarser for readability
}
Shows cumulative contribution from base value to final prediction
Best for understanding individual samples
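The cumulative path that a waterfall plot draws can be sketched in a few lines. The helper `waterfall_steps` and the binned contributions are hypothetical, and the largest-magnitude-first ordering is an assumption about how such plots are usually arranged, not a statement about the library's plotting code.

```python
def waterfall_steps(base_value, contributions):
    """Return (label, running_total) pairs from base value to prediction."""
    # Apply contributions largest-magnitude first (assumed ordering).
    ordered = sorted(contributions.items(), key=lambda kv: abs(kv[1]), reverse=True)
    running = base_value
    steps = [("base value", running)]
    for label, shap_value in ordered:
        running += shap_value
        steps.append((label, running))
    return steps

# Hypothetical binned SHAP values for one sample.
contributions = {"1400-1420nm": 0.8, "1900-1920nm": -0.3, "2100-2120nm": 0.5}
steps = waterfall_steps(base_value=10.0, contributions=contributions)
for label, total in steps:
    print(f"{label:>12}: {total:.2f}")
```

The last running total is the model's prediction for that sample, because SHAP values sum to the difference between the prediction and the base value.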
Beeswarm Plot
Shows SHAP distribution for all samples:
shap_params = {
    'visualizations': ['beeswarm'],
}
Each dot represents a sample
Color indicates feature value (low/high)
Reveals patterns in feature importance
Binning for Spectral Data
For high-dimensional spectral data, binning groups features:
shap_params = {
    'n_samples': 200,
    'visualizations': ['spectral', 'waterfall', 'beeswarm'],
    # Different bin sizes per visualization
    'bin_size': {
        'spectral': 10,   # Fine-grained overview
        'waterfall': 20,  # Fewer bars for readability
        'beeswarm': 20    # Medium granularity
    },
    # Stride (overlap) control
    'bin_stride': {
        'spectral': 5,    # 50% overlap
        'waterfall': 10,  # 50% overlap
        'beeswarm': 20    # No overlap
    },
    # How to aggregate SHAP values in bins
    'bin_aggregation': {
        'spectral': 'mean',  # Average importance
        'waterfall': 'mean',
        'beeswarm': 'mean'
    }
}
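To make the interaction of bin size, stride, and aggregation concrete, here is a minimal NumPy sketch of overlapping-window binning. It is an illustration of the idea only, not the library's implementation: the function `bin_shap_values` is hypothetical, the `'sum'` option is added purely for illustration, and dropping a trailing partial bin is an assumption.

```python
import numpy as np

def bin_shap_values(shap_matrix, bin_size, bin_stride=None, aggregation="mean"):
    """Aggregate per-feature SHAP values into (possibly overlapping) bins.

    shap_matrix: array of shape (n_samples, n_features).
    Returns an array of shape (n_samples, n_bins) and the bin start indices.
    """
    if bin_stride is None:
        bin_stride = bin_size  # stride equal to bin size means no overlap
    reducers = {"mean": np.mean, "sum": np.sum}
    reducer = reducers[aggregation]
    n_features = shap_matrix.shape[1]
    # Only full bins are kept; a trailing partial bin is dropped (assumption).
    starts = np.arange(0, n_features - bin_size + 1, bin_stride)
    binned = np.stack(
        [reducer(shap_matrix[:, s:s + bin_size], axis=1) for s in starts], axis=1
    )
    return binned, starts

rng = np.random.default_rng(0)
shap_matrix = rng.normal(size=(5, 100))  # synthetic SHAP values
binned, starts = bin_shap_values(shap_matrix, bin_size=10, bin_stride=5)
print(binned.shape)  # (5, 19)
```

With 100 features, a bin size of 10, and a stride of 5, consecutive bins share half their features, which smooths the spectral profile at the cost of correlated neighbouring bars.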
Explainer Types
| Type | Models | Speed |
|---|---|---|
| 'auto' | Auto-detect best explainer | Varies |
|  | Tree-based (RF, GBR) | Fast |
|  | Linear models (PLS, Ridge) | Fast |
|  | Universal (any model) | Slow |
|  | Neural networks | Medium |
U02: SHAP with sklearn Wrapper
Use SHAP with sklearn-wrapped NIRS4ALL models.
What You’ll Learn
Direct SHAP with sklearn wrapper
Custom explainer configuration
Integration with SHAP library
Direct SHAP Usage
import shap
from nirs4all.sklearn import SklearnWrapper
# Wrap trained model
wrapper = SklearnWrapper(prediction_entry=best)
# Create SHAP explainer
explainer = shap.Explainer(wrapper.predict, X_train)
# Calculate SHAP values
shap_values = explainer(X_test[:100])
# Visualizations
shap.plots.beeswarm(shap_values)
shap.plots.waterfall(shap_values[0])
Feature Names for Spectral Data
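import numpy as np
# Create meaningful feature names (assumes an evenly spaced 1000-2500 nm axis)
wavelengths = np.linspace(1000, 2500, X.shape[1])
feature_names = [f"{w:.0f}nm" for w in wavelengths]
# Attach the names to the Explanation; beeswarm takes no feature_names argument
named = shap.Explanation(shap_values.values, base_values=shap_values.base_values,
                         data=shap_values.data, feature_names=feature_names)
shap.plots.beeswarm(named)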
U03: Feature Selection
Use SHAP importance for feature selection.
What You’ll Learn
SHAP-based feature ranking
Selecting top wavelengths
Validation with reduced features
SHAP Feature Importance
import numpy as np
# Get absolute SHAP values
shap_importance = np.abs(shap_values.values).mean(axis=0)
# Rank features
ranking = np.argsort(shap_importance)[::-1]
top_features = ranking[:50] # Top 50 wavelengths
print(f"Top wavelengths: {wavelengths[top_features]}")
Feature Selection Pipeline
import nirs4all
from sklearn.cross_decomposition import PLSRegression
# Select top features based on SHAP
def select_top_features(X, shap_values, n_features=50):
    importance = np.abs(shap_values.values).mean(axis=0)
    top_idx = np.argsort(importance)[::-1][:n_features]
    return X[:, top_idx], top_idx
# Apply to data
X_selected, selected_idx = select_top_features(X, shap_values, n_features=50)
# Train on selected features
result_selected = nirs4all.run(
    pipeline=[PLSRegression(n_components=10)],
    dataset=(X_selected, y, {"train": 160}),
    name="SelectedFeatures"
)
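Selected indices are easier to interpret when consecutive ones are merged into wavelength bands. The helper below is a hypothetical sketch (not part of NIRS4ALL) that assumes the wavelength axis is ordered the same way as the feature columns.

```python
def contiguous_bands(indices, wavelengths):
    """Merge consecutive feature indices into (start_nm, end_nm) bands."""
    bands = []
    run_start = prev = None
    for idx in sorted(int(i) for i in indices):
        if prev is not None and idx == prev + 1:
            prev = idx            # extend the current run
            continue
        if prev is not None:
            bands.append((wavelengths[run_start], wavelengths[prev]))
        run_start = prev = idx    # start a new run
    if prev is not None:
        bands.append((wavelengths[run_start], wavelengths[prev]))
    return bands

# Hypothetical grid: 1000, 1010, ..., 1090 nm
wavelengths = [1000 + 10 * i for i in range(10)]
selected_idx = [7, 2, 3, 4, 8]
print(contiguous_bands(selected_idx, wavelengths))
# [(1020, 1040), (1070, 1080)]
```

A selection that collapses into a few wide bands is usually easier to justify chemically than one scattered across isolated wavelengths.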
Wavelength Region Analysis
Identify important spectral regions:
# Group SHAP values by wavelength regions (feature index ranges)
regions = {
    "1000-1300nm": (0, 100),
    "1300-1600nm": (100, 200),
    "1600-1900nm": (200, 300),
    # ...
}
region_importance = {}
for name, (start, end) in regions.items():
    region_importance[name] = np.abs(shap_values.values[:, start:end]).mean()
# Sort by importance
sorted_regions = sorted(region_importance.items(), key=lambda x: x[1], reverse=True)
for region, importance in sorted_regions:
    print(f"{region}: {importance:.4f}")
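Hard-coded index ranges only hold for one instrument's wavelength grid. As a sketch, the index ranges can instead be derived from wavelength boundaries with `np.searchsorted`. The helper `region_slices` and the 5 nm grid are hypothetical, and the approach assumes the wavelength axis is sorted in ascending order.

```python
import numpy as np

def region_slices(wavelengths, edges):
    """Map wavelength boundaries (nm) to half-open index ranges."""
    wavelengths = np.asarray(wavelengths)
    idx = np.searchsorted(wavelengths, edges)  # assumes ascending wavelengths
    return {
        f"{lo}-{hi}nm": (int(a), int(b))
        for lo, hi, a, b in zip(edges[:-1], edges[1:], idx[:-1], idx[1:])
    }

# Hypothetical axis: 1000, 1005, ..., 2495 nm (300 points)
wavelengths = np.arange(1000, 2500, 5)
regions = region_slices(wavelengths, edges=[1000, 1300, 1600, 1900])
print(regions)
# {'1000-1300nm': (0, 60), '1300-1600nm': (60, 120), '1600-1900nm': (120, 180)}
```

The resulting dictionary has the same shape as a hand-written `regions` mapping, so it can feed the same aggregation loop.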
Explainability Best Practices
1. Use Sufficient Samples
# More samples = more reliable SHAP values
shap_params = {
    'n_samples': 200,  # At least 100-200 samples
}
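One rough way to judge whether the sample count is sufficient is to check how stable the feature ranking is across random subsets. The sketch below uses synthetic SHAP values and a hypothetical helper, `top_k_overlap`, that splits the explained samples in two halves and reports how many top-k features the halves share. It is a heuristic, not a formal test.

```python
import numpy as np

def top_k_overlap(shap_matrix, k=10, seed=0):
    """Split samples in two halves and compare their top-k features.

    Returns the fraction of top-k features shared by both halves (0 to 1).
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(shap_matrix.shape[0])
    half = len(order) // 2
    tops = []
    for rows in (order[:half], order[half:]):
        importance = np.abs(shap_matrix[rows]).mean(axis=0)
        tops.append(set(np.argsort(importance)[::-1][:k].tolist()))
    return len(tops[0] & tops[1]) / k

# Synthetic SHAP values: 5 strong features, 45 near-zero noise features.
rng = np.random.default_rng(1)
shap_matrix = rng.normal(scale=0.01, size=(200, 50))
shap_matrix[:, :5] += rng.normal(scale=1.0, size=(200, 5))
overlap = top_k_overlap(shap_matrix, k=5)
print(f"Top-5 overlap between halves: {overlap:.0%}")
```

A low overlap suggests the ranking is still dominated by sampling noise and that more samples would help.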
2. Validate Explanations
Check if SHAP values make chemical sense:
# If predicting sugar content, expect importance at:
# - ~1400-1500 nm (O-H bonds in sugars)
# - ~2100-2300 nm (C-H bonds)
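This check can be made quantitative by measuring how many of the top-ranked wavelengths fall inside the expected absorption bands. The function below is a hypothetical sketch. The bands come from the comments above, while the example wavelengths are made up.

```python
def fraction_in_expected_bands(top_wavelengths, expected_bands):
    """Share of top-ranked wavelengths that fall inside any expected band."""
    hits = [
        w for w in top_wavelengths
        if any(lo <= w <= hi for lo, hi in expected_bands)
    ]
    return len(hits) / len(top_wavelengths)

# Expected bands in nm; the top wavelengths are illustrative values only.
expected_bands = [(1400, 1500), (2100, 2300)]
top_wavelengths = [1432, 1450, 1710, 2150, 2280]
share = fraction_in_expected_bands(top_wavelengths, expected_bands)
print(f"{share:.0%} of top wavelengths fall in expected bands")
```

A low share does not prove the model is wrong, but it is a prompt to look for artefacts such as baseline effects or instrument noise driving the predictions.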
3. Compare Models
# Compare SHAP patterns across different models
for model_name in ['PLS', 'RF', 'Ridge']:
    model_pred = predictions.filter(model_name=model_name).top(1)[0]
    shap_results, _ = runner.explain(model_pred, dataset_config, shap_params=shap_params)
    # Compare spectral patterns
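Once per-model SHAP matrices are available, their spectral patterns can be compared numerically rather than by eye. The sketch below is one possible approach under stated assumptions: it uses synthetic SHAP matrices for two hypothetical models and compares their normalised mean-absolute-SHAP profiles with cosine similarity. How to extract the SHAP arrays from `shap_results` is not shown, because that structure is not documented here.

```python
import numpy as np

def importance_profile(shap_matrix):
    """Mean absolute SHAP value per feature, scaled to sum to 1."""
    profile = np.abs(shap_matrix).mean(axis=0)
    return profile / profile.sum()

def cosine_similarity(a, b):
    """Cosine of the angle between two importance profiles."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
# Synthetic SHAP matrices that share the same underlying importance pattern.
weights = np.linspace(0.1, 2.0, 50)
shap_pls = rng.normal(size=(100, 50)) * weights
shap_rf = rng.normal(size=(100, 50)) * weights

similarity = cosine_similarity(importance_profile(shap_pls), importance_profile(shap_rf))
print(f"Profile similarity: {similarity:.3f}")
```

Values close to 1 indicate that both models rely on the same spectral regions, while clearly lower values flag regions worth inspecting in the spectral plots.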
4. Document Findings
# Include SHAP insights in model documentation
metadata = {
    "important_regions": {
        "1400-1500nm": "O-H overtones",
        "2100-2300nm": "C-H combinations"
    },
    "shap_validation": "Patterns consistent with known sugar absorption"
}
SHAP Parameter Reference
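| Parameter | Description | Default |
|---|---|---|
| n_samples | Samples to explain | 100 |
| explainer_type |  |  |
| visualizations | List of plots: spectral, waterfall, beeswarm | All |
| bin_size | Features per bin (int or dict per viz) | 10 |
| bin_stride | Step between bins | Same as bin_size |
| bin_aggregation |  |  |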
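The defaults in the table can be mirrored in user code when parameter dictionaries are built programmatically. The helper below is hypothetical and not part of NIRS4ALL. It only encodes the defaults that the table states explicitly (100 samples, bin size 10, stride equal to bin size) and leaves every other key untouched.

```python
def resolve_shap_params(user_params=None):
    """Fill in the defaults stated in the table; 'bin_stride' follows 'bin_size'."""
    params = {"n_samples": 100, "bin_size": 10}
    params.update(user_params or {})
    # Per the table, the stride defaults to the bin size (no overlap).
    params.setdefault("bin_stride", params["bin_size"])
    return params

print(resolve_shap_params({"bin_size": 20}))
# {'n_samples': 100, 'bin_size': 20, 'bin_stride': 20}
```

Resolving the stride after merging the user's values keeps it in sync with a customised bin size unless the user sets both explicitly.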
Running These Examples
cd examples
# Run all explainability examples
./run.sh -n "U0*.py" -c user
# Run with plots
python user/07_explainability/U01_shap_basics.py --plots --show
Requirements
SHAP analysis requires the shap library:
pip install shap
Next Steps
After mastering explainability:
Developer Examples: Advanced pipelines, deep learning
Transfer Learning: Adapt models to new instruments
Custom Controllers: Extend NIRS4ALL functionality