# pip install numpy pandas scipy scikit-learn matplotlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from nirs4all.core.logging import get_logger
logger = get_logger(__name__)
from matplotlib.patches import FancyBboxPatch
import matplotlib.cm as cm
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from scipy.linalg import subspace_angles
from scipy.spatial import procrustes
from scipy.spatial.distance import cdist
[docs]
class PreprocPCAEvaluator:
def __init__(self, r_components=10, knn=10):
    """Configure the evaluator; fitted state is filled in by ``fit()``.

    Args:
        r_components: Maximum number of principal components per PCA.
        knn: Neighbourhood size used by the trustworthiness metric.
    """
    # Hyper-parameters.
    self.r, self.knn = r_components, knn
    # Fitted state.
    self.df_ = None                # per (dataset, preproc) preservation metrics
    self.cache_ = {}               # {(dataset, preproc): (Z_raw, Z_pre)} for plot_pair
    self.raw_pcas_ = {}            # {dataset: (Z, U, evr)} raw PCA results
    self.cross_dataset_df_ = None  # inter-dataset distance metrics
    self.pp_pcas_ = {}             # {(dataset, preproc): (Z, U, evr)} preprocessed PCA results
# ---------------- utils ----------------
@staticmethod
def _center(X):
    """Return ``X`` minus its column means (column-wise centering)."""
    column_means = X.mean(axis=0, keepdims=True)
    return X - column_means
def _pca(self, X, r):
    """Fit a PCA on column-centered ``X`` and return scores, loadings and EVR.

    Args:
        X: 2-D array (n_samples, n_features).
        r: Requested number of components. It is capped at
           min(n_samples, n_features), the most components PCA can extract.

    Returns:
        (Z, U, evr): scores (n_samples, r'), loadings (n_features, r') and
        the summed explained-variance ratio of the kept components.
    """
    Xc = self._center(X)
    # PCA rejects n_components > min(n_samples, n_features). Capping only by
    # n_features made datasets with fewer samples than r crash.
    r = min(r, Xc.shape[0], Xc.shape[1])
    p = PCA(n_components=r, random_state=0).fit(Xc)
    Z, U = p.transform(Xc), p.components_.T
    evr = float(p.explained_variance_ratio_.sum())
    return Z, U, evr
@staticmethod
def _grassmann(U, V):
    """Grassmann (geodesic) distance between the column spaces of U and V.

    This is the L2 norm of the principal angles (radians) returned by
    ``scipy.linalg.subspace_angles``. U and V must have the same row count.
    """
    principal_angles = subspace_angles(U, V)
    squared_sum = (principal_angles ** 2).sum()
    return float(np.sqrt(squared_sum))
def _cka(self, X, Y):
    """Linear CKA similarity between two sample-aligned representations.

    CKA = ||Xc^T Yc||_F^2 / (||Xc^T Xc||_F * ||Yc^T Yc||_F) on column-centered
    inputs. Returns NaN when the denominator is not positive, i.e. when
    either input has no variance.
    """
    Xc = self._center(X)
    Yc = self._center(Y)
    cross_term = np.linalg.norm(Xc.T @ Yc, 'fro') ** 2
    norm_x = np.linalg.norm(Xc.T @ Xc, 'fro')
    norm_y = np.linalg.norm(Yc.T @ Yc, 'fro')
    denominator = norm_x * norm_y
    if not denominator > 0:
        return np.nan
    return float(cross_term / denominator)
def _rv(self, X, Y):
    """RV coefficient between two sample-aligned representations.

    RV = tr(A B) / sqrt(tr(A A) * tr(B B)) with A = Xc Xc^T and B = Yc Yc^T.
    Because tr(Xc Xc^T Yc Yc^T) = ||Xc^T Yc||_F^2 and tr(A A) = ||Xc^T Xc||_F^2,
    the value is computed from small feature-space products in O(n d^2).
    The old code formed n x n matrices and an O(n^3) product. On centered
    data this is algebraically identical to linear CKA.

    Returns:
        RV in [0, 1], or NaN when either input has no variance.
    """
    Xc, Yc = self._center(X), self._center(Y)
    num = np.linalg.norm(Xc.T @ Yc, 'fro') ** 2
    den = np.linalg.norm(Xc.T @ Xc, 'fro') * np.linalg.norm(Yc.T @ Yc, 'fro')
    return float(num / den) if den > 0 else np.nan
@staticmethod
def _procrustes(Z1, Z2):
    """Procrustes disparity between two equally shaped score matrices.

    Wraps ``scipy.spatial.procrustes``, which standardizes both inputs and
    returns the residual sum of squares after optimal alignment.

    Returns:
        Disparity as a float, or NaN when either matrix collapses to a single
        point (zero norm after centering). scipy raises in that case, which
        used to abort fit(); the sibling CKA/RV metrics return NaN instead.
        Other invalid inputs (shape mismatch, empty) still raise ValueError.
    """
    for Z in (Z1, Z2):
        Za = np.asarray(Z, dtype=float)
        if Za.ndim == 2 and Za.size and not np.linalg.norm(Za - Za.mean(axis=0)) > 0:
            return np.nan
    _, _, d = procrustes(Z1, Z2)
    return float(d)
def _trust(self, Zref, Znew, k):
    """Trustworthiness of embedding ``Znew`` relative to reference ``Zref``.

    T(k) = 1 - 2 / (n k (2n - 3k - 1)) * sum_i sum_{j in U_i} (r(i, j) - k)

    where U_i are i's k nearest neighbours in ``Znew`` that are not among its
    k nearest neighbours in ``Zref``, and r(i, j) is the 1-based rank of j
    among i's neighbours in ``Zref`` (Venna & Kaski; same formula as
    sklearn.manifold.trustworthiness).

    The old version had two bugs. It skipped the true nearest neighbour,
    because sklearn's ``kneighbors()`` already excludes the query point. It
    also doubled the normalizer, so scores could drop to -1.

    Args:
        Zref, Znew: Sample-aligned score matrices (n_samples, n_dims).
        k: Neighbourhood size, clipped via max(2, min(k, n - 2)).

    Returns:
        Trustworthiness in [0, 1] (1.0 = no intruding neighbours), or NaN when
        n is too small for the normalizer to be positive.
    """
    Zref = np.asarray(Zref, dtype=float)
    Znew = np.asarray(Znew, dtype=float)
    n = Zref.shape[0]
    if n < 3:
        return np.nan
    k = max(2, min(k, n - 2))
    norm = n * k * (2 * n - 3 * k - 1)
    if norm <= 0:
        return np.nan

    def _neighbour_order(Z):
        # Row i: every other sample sorted by Euclidean distance to i (self excluded).
        D = cdist(Z, Z)
        np.fill_diagonal(D, np.inf)
        return np.argsort(D, axis=1, kind='stable')[:, :n - 1]

    ref_order = _neighbour_order(Zref)
    new_order = _neighbour_order(Znew)
    rows = np.arange(n)[:, None]
    # ranks[i, j] = 1-based rank of j among i's neighbours in the reference space.
    ranks = np.zeros((n, n), dtype=int)
    ranks[rows, ref_order] = np.arange(1, n)
    # Reference ranks of each sample's k nearest neighbours in the new space.
    # Only neighbours outside the reference k-NN (rank > k) are penalized.
    new_knn_ranks = ranks[rows, new_order[:, :k]]
    penalty = np.clip(new_knn_ranks - k, 0, None).sum()
    return float(1.0 - 2.0 * penalty / norm)
# ---------------- core API ----------------
[docs]
def fit(self, raw_data: dict[str, np.ndarray], pp_data: dict[str, dict[str, np.ndarray]]):
    """
    Compute structure-preservation metrics for every (dataset, preprocessing) pair.

    raw_data: {"dataset": X_raw_(n,m), ...}
    pp_data: Can be either:
        - {"pp_name": {"dataset": X_pp_(n,p), ...}, ...} OR
        - {"dataset": {"pp_name": X_pp_(n,p), ...}, ...}
        (will automatically detect and pivot if needed)
    Assumes rows (samples) are aligned within each dataset across raw and pp.

    Populates ``df_`` (one metrics row per pair), ``cache_`` (r_used-truncated
    raw/preprocessed scores per pair), ``raw_pcas_``, ``pp_pcas_`` and
    ``cross_dataset_df_``. All fitted state is reset first, so repeated calls
    never mix results from different inputs.

    Raises:
        ValueError: if a preprocessed matrix has a different number of rows
            than its raw counterpart.

    Returns:
        self
    """
    # Auto-detect structure and pivot if needed
    pp_data = self._ensure_pp_structure(pp_data, raw_data)
    # Reset ALL fitted state. pp_pcas_ used to survive across calls, so stale
    # (dataset, preproc) entries leaked into cross-dataset analysis and plots.
    self.df_ = None
    self.cross_dataset_df_ = None
    self.cache_.clear()
    self.raw_pcas_.clear()
    self.pp_pcas_.clear()
    rows = []
    # precompute raw PCA per dataset -> (Z, U, evr)
    for dname, Xr in raw_data.items():
        self.raw_pcas_[dname] = self._pca(np.asarray(Xr), self.r)
    # iterate preprocessings
    for pp_name, dmap in pp_data.items():
        for dname, Xp in dmap.items():
            if dname not in self.raw_pcas_:
                continue  # skip if no matching raw dataset
            Zr, Ur, evr_r = self.raw_pcas_[dname]
            Xp = np.asarray(Xp)
            if Xp.shape[0] != Zr.shape[0]:
                raise ValueError(f"n_samples mismatch for dataset '{dname}' in '{pp_name}'")
            Zp, Up, evr_p = self._pca(Xp, min(self.r, Zr.shape[1]))
            # Compare on the number of components both PCAs actually produced.
            r_use = min(Ur.shape[1], Up.shape[1], Zr.shape[1], Zp.shape[1])
            Ur_, Up_ = Ur[:, :r_use], Up[:, :r_use]
            Zr_, Zp_ = Zr[:, :r_use], Zp[:, :r_use]
            # Grassmann distance only makes sense when feature spaces have same dimensionality.
            # If preprocessing changes feature dimension, we skip it (set to NaN).
            grassmann_dist = np.nan
            if Ur_.shape[0] == Up_.shape[0]:  # same number of features
                grassmann_dist = self._grassmann(Ur_, Up_)
            rows.append({
                "dataset": dname,
                "preproc": pp_name,
                "r_used": r_use,
                "evr_raw": evr_r,
                "evr_pre": evr_p,
                "grassmann": grassmann_dist,
                "cka": self._cka(Zr_, Zp_),
                "rv": self._rv(Zr_, Zp_),
                "procrustes": self._procrustes(Zr_, Zp_),
                "trustworthiness": self._trust(Zr_, Zp_, k=self.knn),
            })
            # cache the r_use-truncated scores for plot_pair
            self.cache_[(dname, pp_name)] = (Zr_, Zp_)
            # Store the full preprocessed PCA for cross-dataset analysis
            self.pp_pcas_[(dname, pp_name)] = (Zp, Up, evr_p)
    self.df_ = pd.DataFrame(rows)
    # Compute inter-dataset distances
    self._compute_cross_dataset_distances(raw_data, pp_data)
    return self
def _ensure_pp_structure(self, pp_data, raw_data):
    """
    Ensure pp_data has structure {preproc: {dataset: X}}.

    The first entry decides the layout:
      * its inner keys are raw dataset names -> already {preproc: {dataset: X}};
      * otherwise, if its outer key is a raw dataset name -> it is
        {dataset: {preproc: X}} and gets pivoted;
      * anything else is returned unchanged.
    An empty first inner mapping is handled. Previously ``next(iter({}))``
    raised StopIteration in that case.
    """
    if not pp_data:
        return pp_data
    first_key, first_val = next(iter(pp_data.items()))
    if not isinstance(first_val, dict):
        return pp_data
    if first_val and next(iter(first_val)) in raw_data:
        # Already correct structure {preproc: {dataset: X}}
        return pp_data
    if first_key in raw_data:
        # Structure {dataset: {preproc: X}}, pivot it
        pivoted = {}
        for dataset_name, preproc_map in pp_data.items():
            for preproc_name, X in preproc_map.items():
                pivoted.setdefault(preproc_name, {})[dataset_name] = X
        return pivoted
    return pp_data
# ---------------- cross-dataset analysis ----------------
def _compute_cross_dataset_distances(self, raw_data, pp_data):
    """
    Compute pairwise distances between datasets, for raw and preprocessed data.

    This helps assess whether a preprocessing brings datasets (e.g. from
    different machines) closer together. For every dataset pair, both
    matrices are projected onto ONE PCA basis fitted on their stacked rows
    (``_joint_pca_scores``). The previous version compared each dataset's own
    PCA scores. Those are centered by construction, so the centroid distance
    was always ~0, and they are expressed in unrelated bases.

    Columns of ``cross_dataset_df_`` (one row per dataset pair and preproc):
        centroid_dist_*: L2 distance between the dataset means in the shared
            PCA space (NaN when the feature counts differ).
        spread_dist_*: ``_compute_spread_distance`` in the shared space.
        *_improvement: (raw - pp) / raw; positive = datasets got closer.
        subspace_angle_*: Grassmann distance between the per-dataset PCA
            loading subspaces (NaN when the feature counts differ).

    NOTE(review): raw and preprocessed distances live in differently scaled
    spaces, so the improvement ratios also reflect any global rescaling
    applied by the preprocessing.
    """
    dataset_names = list(raw_data.keys())
    if len(dataset_names) < 2:
        self.cross_dataset_df_ = pd.DataFrame()
        return

    def _shared_space_distances(X1, X2):
        # (centroid, spread) distance after a joint PCA; NaN if no shared space.
        scores = self._joint_pca_scores(X1, X2)
        if scores is None:
            return np.nan, np.nan
        Z1, Z2 = scores
        centroid = float(np.linalg.norm(Z1.mean(axis=0) - Z2.mean(axis=0)))
        return centroid, self._compute_spread_distance(Z1, Z2)

    def _subspace_distance(U1, U2):
        # Loadings live in feature space, so they are comparable only with equal feature counts.
        r_sub = min(U1.shape[1], U2.shape[1], U1.shape[0], U2.shape[0])
        if r_sub > 0 and U1.shape[0] == U2.shape[0]:
            return self._grassmann(U1[:, :r_sub], U2[:, :r_sub])
        return np.nan

    rows = []
    for i, ds1 in enumerate(dataset_names):
        for ds2 in dataset_names[i + 1:]:
            if ds1 not in self.raw_pcas_ or ds2 not in self.raw_pcas_:
                continue
            centroid_dist_raw, spread_dist_raw = _shared_space_distances(raw_data[ds1], raw_data[ds2])
            subspace_angle_raw = _subspace_distance(self.raw_pcas_[ds1][1], self.raw_pcas_[ds2][1])
            for pp_name, dmap in pp_data.items():
                key1, key2 = (ds1, pp_name), (ds2, pp_name)
                if (ds1 not in dmap or ds2 not in dmap
                        or key1 not in self.pp_pcas_ or key2 not in self.pp_pcas_):
                    continue
                centroid_dist_pp, spread_dist_pp = _shared_space_distances(dmap[ds1], dmap[ds2])
                subspace_angle_pp = _subspace_distance(self.pp_pcas_[key1][1], self.pp_pcas_[key2][1])
                # Relative change (positive = datasets got closer)
                centroid_improvement = (centroid_dist_raw - centroid_dist_pp) / (centroid_dist_raw + 1e-10)
                spread_improvement = (spread_dist_raw - spread_dist_pp) / (spread_dist_raw + 1e-10)
                rows.append({
                    'dataset_1': ds1,
                    'dataset_2': ds2,
                    'preproc': pp_name,
                    'centroid_dist_raw': centroid_dist_raw,
                    'centroid_dist_pp': centroid_dist_pp,
                    'centroid_improvement': centroid_improvement,
                    'spread_dist_raw': spread_dist_raw,
                    'spread_dist_pp': spread_dist_pp,
                    'spread_improvement': spread_improvement,
                    'subspace_angle_raw': subspace_angle_raw,
                    'subspace_angle_pp': subspace_angle_pp,
                })
    self.cross_dataset_df_ = pd.DataFrame(rows)

def _joint_pca_scores(self, X1, X2):
    """Project X1 and X2 onto the top PCs of their stacked rows.

    One PCA (at most ``self.r`` components, jointly centered by sklearn) is
    fitted on ``vstack([X1, X2])``, so both datasets share one basis and one
    origin.

    Returns:
        (Z1, Z2) score matrices, or None when the inputs are not 2-D with
        equal feature counts (no shared space exists).
    """
    X1 = np.asarray(X1, dtype=float)
    X2 = np.asarray(X2, dtype=float)
    if X1.ndim != 2 or X2.ndim != 2 or X1.shape[1] != X2.shape[1]:
        return None
    stacked = np.vstack([X1, X2])
    n_components = min(self.r, stacked.shape[0], stacked.shape[1])
    Z = PCA(n_components=n_components, random_state=0).fit_transform(stacked)
    return Z[:X1.shape[0]], Z[X1.shape[0]:]
def _compute_spread_distance(self, Z1, Z2, random_state=0):
    """
    Compute a distance between the spreads of two datasets in a common space.

    The result is the sum of:
      * the Frobenius norm of the difference of the two covariance matrices
        (how different the shapes are), and
      * the symmetric mean nearest-neighbour (Chamfer) distance between
        random subsamples of at most 100 points per dataset (how far each
        dataset's samples lie from the other's).

    Args:
        Z1, Z2: Score matrices (n1, r) and (n2, r) in the same coordinate space.
        random_state: Seed for the subsampling. The previous version used the
            unseeded global RNG, so results changed between runs.

    Returns:
        Non-negative float, or NaN when a covariance is undefined (n < 2).
    """
    Z1 = np.asarray(Z1, dtype=float)
    Z2 = np.asarray(Z2, dtype=float)
    # np.cov yields a 0-d array for a single column; atleast_2d keeps the
    # Frobenius norm valid in that case.
    cov1 = np.atleast_2d(np.cov(Z1, rowvar=False))
    cov2 = np.atleast_2d(np.cov(Z2, rowvar=False))
    cov_dist = np.linalg.norm(cov1 - cov2, 'fro')
    # Subsample to bound the pairwise-distance cost.
    rng = np.random.default_rng(random_state)
    Z1_sample = Z1[rng.choice(Z1.shape[0], min(100, Z1.shape[0]), replace=False)]
    Z2_sample = Z2[rng.choice(Z2.shape[0], min(100, Z2.shape[0]), replace=False)]
    dists = cdist(Z1_sample, Z2_sample, metric='euclidean')
    # Average, in both directions, of each sample's distance to its nearest
    # neighbour in the other dataset. The old code took an element-wise
    # minimum of two unrelated per-index arrays.
    chamfer = 0.5 * (dists.min(axis=1).mean() + dists.min(axis=0).mean())
    return float(cov_dist + chamfer)
[docs]
def get_cross_dataset_summary(self, metric='centroid_improvement'):
    """
    Summarize how each preprocessing changes inter-dataset distances.

    Args:
        metric: 'centroid_improvement' or 'spread_improvement'. The summary is
            sorted by that metric's mean, highest (datasets brought closest)
            first. Any other value leaves the rows in groupby order.

    Returns:
        DataFrame with one row per preprocessing: mean/std of both
        improvement ratios plus the mean preprocessed distances.

    Raises:
        ValueError: if fit() has not produced a non-empty cross-dataset table.
    """
    cross = self.cross_dataset_df_
    if cross is None or cross.empty:
        raise ValueError("Run fit() first with multiple datasets.")
    aggregations = {
        'centroid_improvement': ['mean', 'std'],
        'spread_improvement': ['mean', 'std'],
        'centroid_dist_pp': 'mean',
        'spread_dist_pp': 'mean',
    }
    summary = cross.groupby('preproc').agg(aggregations).reset_index()
    # Flatten the MultiIndex columns produced by the mixed aggregation spec.
    summary.columns = [
        'preproc',
        'centroid_improv_mean', 'centroid_improv_std',
        'spread_improv_mean', 'spread_improv_std',
        'centroid_dist_pp', 'spread_dist_pp',
    ]
    sort_column = metric.replace('improvement', 'improv_mean')
    if sort_column not in summary.columns:
        return summary
    return summary.sort_values(sort_column, ascending=False)
[docs]
def get_quality_metric_convergence(self):
    """
    Measure how much each preprocessing homogenizes quality metrics across datasets.

    For every metric, the variance across datasets is computed twice:
      * "raw": for 'evr_pre', the per-dataset raw EVR ('evr_raw'). The other
        metrics only exist relative to a preprocessing, so the per-dataset
        mean over all preprocessings is used as a proxy.
      * "pp": the metric's values for a single preprocessing.

    ``convergence = (var_raw - var_pp) / (var_raw + 1e-10)``; positive means
    the preprocessing makes datasets more alike. Distance metrics (grassmann,
    procrustes) are no longer sign-flipped. Variance is invariant to
    negation, so the former "inversion" had no effect.

    Returns:
        DataFrame with one row per preprocessing and, per metric,
        ``{metric}_var_raw``, ``{metric}_var_pp`` and ``{metric}_convergence``.

    Raises:
        ValueError: if fit() has not been run.
    """
    if self.df_ is None or self.df_.empty:
        raise ValueError("Run fit() first.")
    df = self.df_
    quality_metrics = ['evr_pre', 'cka', 'rv', 'procrustes', 'trustworthiness', 'grassmann']
    datasets = df['dataset'].unique()
    raw_variance = {}
    for metric in quality_metrics:
        if metric == 'evr_pre':
            # evr_raw is identical across a dataset's rows; take the first.
            raw_vals = [df.loc[df['dataset'] == ds, 'evr_raw'].iloc[0] for ds in datasets]
        else:
            raw_vals = [df.loc[df['dataset'] == ds, metric].mean() for ds in datasets]
        raw_variance[metric] = float(np.nanvar(raw_vals))
    results = []
    for preproc in df['preproc'].unique():
        df_pp = df[df['preproc'] == preproc]
        row = {'preproc': preproc}
        for metric in quality_metrics:
            var_raw = raw_variance[metric]
            var_pp = float(np.nanvar(df_pp[metric].values))
            row[f'{metric}_var_raw'] = var_raw
            row[f'{metric}_var_pp'] = var_pp
            # Convergence = relative reduction in variance (positive = better)
            row[f'{metric}_convergence'] = (var_raw - var_pp) / (var_raw + 1e-10)
        results.append(row)
    return pd.DataFrame(results)
# ---------------- plots ----------------
[docs]
def plot_all_datasets_pca(self, figsize=(16, 12)):
    """
    Plot all datasets together in the same PCA space for raw and each preprocessing.
    Shows how datasets cluster and separate in different preprocessing spaces.

    Draws one subplot for raw data plus one per preprocessing. Each scatters
    the first two PCA scores of every dataset, colored per dataset. The
    legend appears on the raw subplot only. When a PCA has only one
    component, PC2 is drawn as 0 instead of raising IndexError.

    NOTE(review): each dataset's scores come from its OWN PCA fitted on
    centered data, so every cloud is centered at the origin in its own
    basis. Cloud positions are therefore not a true between-dataset
    comparison; a shared PCA would be needed for that.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if fit() has not been run.
    """
    if self.df_ is None or self.df_.empty:
        raise ValueError("Run fit() first.")
    datasets = list(self.raw_pcas_.keys())
    preprocs = sorted(self.df_['preproc'].unique())
    # More contrastive color palette for datasets
    contrastive_colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00',
                          '#ffff33', '#a65628', '#f781bf', '#999999', '#66c2a5']
    dataset_colors = {ds: contrastive_colors[i % len(contrastive_colors)]
                      for i, ds in enumerate(datasets)}
    scatter_kw = dict(alpha=0.5, s=25, edgecolors='white', linewidth=0.3)

    def _pc12(Z):
        # First two score columns; PC2 is zero when only one component exists.
        y = Z[:, 1] if Z.shape[1] > 1 else np.zeros(Z.shape[0])
        return Z[:, 0], y

    def _label_axes(ax, title, title_size):
        ax.set_xlabel('Principal Component 1 (PC1)', fontsize=10, fontweight='bold')
        ax.set_ylabel('Principal Component 2 (PC2)', fontsize=10, fontweight='bold')
        ax.set_title(title, fontsize=title_size, fontweight='bold', pad=10)

    # Determine grid layout: raw + all preprocessings
    n_plots = 1 + len(preprocs)
    n_cols = min(4, n_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
    axes = np.atleast_1d(axes).ravel()

    # Plot 1: Raw data - all datasets together
    ax = axes[0]
    for dname in datasets:
        Zr, _, _ = self.raw_pcas_[dname]
        x, y = _pc12(Zr)
        ax.scatter(x, y, c=dataset_colors[dname], label=dname, **scatter_kw)
    _label_axes(ax, 'RAW DATA\nAll Datasets', 11)
    # Legend in the top-left corner, slightly inset
    ax.legend(loc='upper left', fontsize=8, framealpha=0.9, bbox_to_anchor=(0.01, 0.99))
    ax.grid(alpha=0.3, linestyle='--')
    ax.set_facecolor('#f8f9fa')

    # One subplot per preprocessing (no legend)
    for ax, pp_name in zip(axes[1:], preprocs):
        for dname in datasets:
            if (dname, pp_name) in self.pp_pcas_:
                Zp, _, _ = self.pp_pcas_[(dname, pp_name)]
                x, y = _pc12(Zp)
                ax.scatter(x, y, c=dataset_colors[dname], label=dname, **scatter_kw)
        pp_display = pp_name.split('|')[-1].replace('MinMax>', '').replace('>', '→')
        # Truncate at the same length that is checked (was: check 30, slice 22).
        if len(pp_display) > 30:
            pp_display = pp_display[:30] + '...'
        _label_axes(ax, f'{pp_display}', 9)
        ax.grid(alpha=0.3, linestyle='--')
        ax.set_facecolor('#f8f9fa')

    # Hide unused subplots
    for extra in axes[n_plots:]:
        extra.axis('off')
    plt.suptitle('Dataset Clustering in Different Preprocessing Spaces\n(Closer clusters = better for transfer learning)',
                 fontsize=13, fontweight='bold', y=0.995)
    plt.tight_layout()
    return fig
[docs]
def plot_distance_matrices(self, metric='centroid', figsize=(18, 12)):
    """
    Plot distance matrices showing inter-dataset distances for raw and all preprocessings.
    Shows which preprocessing reduces distances (better for transfer learning).

    Every panel uses the raw panel's color scale, so panels compare directly.
    That scale comes from nanmin/nanmax. A single NaN distance (e.g. an
    undefined spread) no longer turns the scale into NaN and blanks every
    preprocessing panel.

    Args:
        metric: 'centroid' or 'spread' - which distance metric to display

    Returns:
        The matplotlib Figure, or None when no cross-dataset table exists.
    """
    if self.cross_dataset_df_ is None or self.cross_dataset_df_.empty:
        logger.warning("No cross-dataset analysis available. Need multiple datasets.")
        return None
    cross = self.cross_dataset_df_
    datasets = sorted(set(cross['dataset_1']).union(set(cross['dataset_2'])))
    preprocs = sorted(cross['preproc'].unique())
    # Select metric columns
    metric_col_raw = f'{metric}_dist_raw'
    metric_col_pp = f'{metric}_dist_pp'
    # Define metric name with computation method
    if metric == 'centroid':
        metric_display = 'Centroid Distance (Euclidean L2-norm)'
    elif metric == 'spread':
        metric_display = 'Spread Distance (Frobenius + Sample-wise)'
    else:
        metric_display = metric.capitalize() + ' Distance'
    n_datasets = len(datasets)
    position = {ds: i for i, ds in enumerate(datasets)}

    def _pair_matrix(frame, column):
        # Symmetric dataset x dataset matrix with zeros on the diagonal.
        matrix = np.zeros((n_datasets, n_datasets))
        for ds1, ds2, val in zip(frame['dataset_1'], frame['dataset_2'], frame[column]):
            i, j = position[ds1], position[ds2]
            matrix[i, j] = val
            matrix[j, i] = val
        return matrix

    def _draw(ax, matrix, title, title_size, text_size, vmin, vmax):
        # YlOrRd runs light to dark, so annotations switch to white on dark cells.
        im = ax.imshow(matrix, cmap='YlOrRd', aspect='auto', vmin=vmin, vmax=vmax)
        ax.set_title(title, fontsize=title_size, fontweight='bold')
        ax.set_xticks(range(n_datasets))
        ax.set_yticks(range(n_datasets))
        ax.set_xticklabels(datasets, rotation=45, ha='right', fontsize=8)
        ax.set_yticklabels(datasets, fontsize=8)
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        for i in range(n_datasets):
            for j in range(n_datasets):
                if i == j:
                    continue
                val = matrix[i, j]
                text_color = 'white' if val > 0.8 * vmax else 'black'
                # Use scientific notation if value is very small
                val_str = f'{val:.2e}' if val < 0.001 else f'{val:.3f}'
                ax.text(j, i, val_str,
                        ha="center", va="center", color=text_color, fontsize=text_size, fontweight='bold')

    # Create matrices for raw and each preprocessing
    n_plots = 1 + len(preprocs)
    n_cols = min(4, n_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
    axes = np.atleast_1d(axes).ravel()

    # Plot 1: Raw distances. Its NaN-aware range is the shared color scale.
    raw_matrix = _pair_matrix(cross, metric_col_raw)
    vmin, vmax = np.nanmin(raw_matrix), np.nanmax(raw_matrix)
    _draw(axes[0], raw_matrix, f'RAW DATA\n{metric_display}', 10, 7, vmin, vmax)

    # Plot preprocessed distances on the same scale
    for ax, pp_name in zip(axes[1:], preprocs):
        pp_matrix = _pair_matrix(cross[cross['preproc'] == pp_name], metric_col_pp)
        pp_display = pp_name.split('|')[-1].replace('MinMax>', '').replace('>', '→')
        if len(pp_display) > 22:
            pp_display = pp_display[:22] + '...'
        _draw(ax, pp_matrix, f'{pp_display}', 9, 6, vmin, vmax)

    # Hide unused subplots
    for extra in axes[n_plots:]:
        extra.axis('off')
    plt.suptitle(f'Inter-Dataset {metric_display} Matrices\n(Lower values = better transfer learning potential)',
                 fontsize=13, fontweight='bold', y=0.995)
    plt.tight_layout()
    return fig
[docs]
def plot_distance_reduction_ranking(self, metric='centroid', log_scale=False, figsize=(14, 8)):
    """
    Bar chart showing which preprocessing methods best reduce inter-dataset distances.
    Directly answers: "Which preprocessing is best for transfer learning?"

    The left panel shows the percentage reduction of the average
    inter-dataset distance per preprocessing (positive = datasets closer).
    The right panel shows average raw vs. preprocessed distance. Both
    averages for a preprocessing are taken over the SAME dataset pairs; the
    old code compared the raw mean over all pairs with per-preprocessing
    means. A zero or NaN raw average yields a NaN reduction, which gets no
    value label, instead of inf.

    Args:
        metric: 'centroid' or 'spread' - which distance method to use for ranking
        log_scale: If True, use a log scale for the right plot (absolute
            distances) and a symlog scale for the left plot, to handle extreme values

    Returns:
        The matplotlib Figure, or None when no cross-dataset table exists.
    """
    if self.cross_dataset_df_ is None or self.cross_dataset_df_.empty:
        logger.warning("No cross-dataset analysis available. Need multiple datasets.")
        return None
    # Select metric columns and define display name with computation method
    metric_col_raw = f'{metric}_dist_raw'
    metric_col_pp = f'{metric}_dist_pp'
    if metric == 'centroid':
        metric_display = 'Centroid Distance'
        metric_method = 'Method: Euclidean L2-norm between PCA centroids'
    elif metric == 'spread':
        metric_display = 'Spread Distance'
        metric_method = 'Method: Frobenius norm (covariance) + sample-wise Euclidean'
    else:
        metric_display = metric.capitalize() + ' Distance'
        metric_method = ''
    # Compute average distance reduction for each preprocessing
    results = []
    for pp_name, df_pp in self.cross_dataset_df_.groupby('preproc', sort=False):
        avg_raw_dist = df_pp[metric_col_raw].mean()
        avg_pp_dist = df_pp[metric_col_pp].mean()
        if avg_raw_dist > 0:
            reduction = ((avg_raw_dist - avg_pp_dist) / avg_raw_dist) * 100  # Percentage
        else:
            reduction = np.nan
        results.append({
            'preproc': pp_name,
            'avg_distance': avg_pp_dist,
            'reduction_pct': reduction,
            'raw_distance': avg_raw_dist,
        })
    df_results = pd.DataFrame(results).sort_values('reduction_pct', ascending=False)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Plot 1: Distance reduction percentage
    labels = []
    for pp in df_results['preproc'].values:
        formatted = pp.split('|')[-1].replace('MinMax>', '').replace('>', '→')
        if len(formatted) > 30:
            formatted = formatted[:30] + '...'
        labels.append(formatted)
    reductions = df_results['reduction_pct'].values
    y_pos = np.arange(len(labels))
    colors = ['green' if x > 0 else 'red' for x in reductions]
    bars = ax1.barh(y_pos, reductions, color=colors, alpha=0.7, edgecolor='black', linewidth=0.8)
    ax1.set_yticks(y_pos)
    ax1.set_yticklabels(labels, fontsize=8)
    ax1.set_xlabel(f'{metric_display} Reduction (%)', fontsize=11, fontweight='bold')
    ax1.set_title(f'Transfer Learning Potential\n{metric_display} | {metric_method}\n(Higher = Better)',
                  fontsize=11, fontweight='bold', pad=15)
    ax1.axvline(0, color='black', linewidth=1.5, linestyle='--')
    ax1.grid(axis='x', alpha=0.3)
    ax1.set_facecolor('#f8f9fa')
    if log_scale:
        # symlog handles both positive and negative values
        ax1.set_xscale('symlog', linthresh=1.0)
        ax1.set_xlabel(f'{metric_display} Reduction (%, symlog scale)', fontsize=11, fontweight='bold')
    # Value labels (skipped for undefined reductions: no finite position)
    for bar, val in zip(bars, reductions):
        if not np.isfinite(val):
            continue
        label_x = val + (2 if val > 0 else -2)
        ha = 'left' if val > 0 else 'right'
        ax1.text(label_x, bar.get_y() + bar.get_height() / 2, f'{val:.1f}%',
                 ha=ha, va='center', fontsize=8, fontweight='bold')

    # Plot 2: Absolute distances (raw average per preprocessing, same pairs)
    x = np.arange(len(df_results))
    width = 0.35
    ax2.bar(x - width/2, df_results['raw_distance'].values, width,
            label='Raw', alpha=0.8, color='steelblue', edgecolor='black', linewidth=0.8)
    ax2.bar(x + width/2, df_results['avg_distance'].values, width,
            label='Preprocessed', alpha=0.8, color='coral', edgecolor='black', linewidth=0.8)
    ax2.set_xticks(x)
    ax2.set_xticklabels(labels, rotation=45, ha='right', fontsize=7)
    ax2.set_ylabel(f'Average {metric_display}', fontsize=11, fontweight='bold')
    ax2.set_title(f'Absolute {metric_display} Comparison\n{metric_method}',
                  fontsize=11, fontweight='bold', pad=15)
    ax2.legend(fontsize=10, framealpha=0.9)
    ax2.grid(axis='y', alpha=0.3)
    ax2.set_facecolor('#f8f9fa')
    if log_scale:
        ax2.set_yscale('log')
        ax2.set_ylabel(f'Average {metric_display} (log scale)', fontsize=11, fontweight='bold')
    plt.suptitle('Preprocessing Ranking for Transfer Learning\n' +
                 f'{metric_display} - Best preprocessing reduces distance between datasets',
                 fontsize=14, fontweight='bold', y=0.98)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    return fig
[docs]
def plot_quality_metric_convergence(self, figsize=(16, 10)):
    """
    Plot, per quality metric, how much each preprocessing reduces cross-dataset variance.

    Builds a 2x3 grid (EVR, CKA, RV, Procrustes, Trustworthiness, Grassmann)
    of horizontal bars taken from ``get_quality_metric_convergence()``.
    Preprocessings are ordered by their mean convergence across the six
    metrics. Green bars are positive (datasets more similar), red are not,
    and values with |v| > 0.1 are annotated.

    Returns:
        The matplotlib Figure, or None (with a warning) when fewer than two
        datasets were fitted.
    """
    df = self.df_
    if df is None or df.empty or len(df['dataset'].unique()) < 2:
        logger.warning("Need at least 2 datasets for quality metric convergence analysis.")
        return None
    convergence_df = self.get_quality_metric_convergence()
    # (column prefix, panel name); '*' marks distance metrics.
    panels = [
        ('evr_pre', 'EVR'),
        ('cka', 'CKA'),
        ('rv', 'RV'),
        ('procrustes', 'Procrustes*'),
        ('trustworthiness', 'Trustworthiness'),
        ('grassmann', 'Grassmann*'),
    ]
    score_columns = [f'{name}_convergence' for name, _ in panels]
    convergence_df['avg_convergence'] = convergence_df[score_columns].mean(axis=1)
    convergence_df = convergence_df.sort_values('avg_convergence', ascending=False)

    def _short(name):
        text = name.split('|')[-1].replace('MinMax>', '').replace('>', '→')
        return text[:30] + '...' if len(text) > 30 else text

    labels = [_short(pp) for pp in convergence_df['preproc'].values]
    y_pos = np.arange(len(labels))
    fig, axes = plt.subplots(2, 3, figsize=figsize)
    for ax, (name, display_name) in zip(axes.ravel(), panels):
        values = convergence_df[f'{name}_convergence'].values
        bars = ax.barh(y_pos, values, color=['green' if v > 0 else 'red' for v in values],
                       alpha=0.7, edgecolor='black', linewidth=0.5)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(labels, fontsize=7)
        ax.set_xlabel('Variance Reduction', fontsize=9, fontweight='bold')
        if name in ('procrustes', 'grassmann'):
            subtitle = '(inverted*, + = datasets more similar)'
        else:
            subtitle = '(+ = datasets more similar)'
        ax.set_title(f'{display_name} Convergence\n{subtitle}',
                     fontsize=10, fontweight='bold', pad=10)
        ax.axvline(0, color='black', linewidth=1.5, linestyle='--')
        ax.grid(axis='x', alpha=0.3)
        ax.set_facecolor('#f8f9fa')
        for bar, val in zip(bars, values):
            # `not ... > 0.1` also skips NaN values.
            if not abs(val) > 0.1:
                continue
            positive = val > 0
            ax.text(val + (0.02 if positive else -0.02), bar.get_y() + bar.get_height()/2, f'{val:.2f}',
                    ha='left' if positive else 'right', va='center', fontsize=6, fontweight='bold')
    plt.suptitle('Quality Metric Convergence Across Datasets\n' +
                 'How preprocessing makes datasets more homogeneous in quality characteristics\n' +
                 '(* = inverted for consistency; Positive = variance reduction = better for transfer learning)',
                 fontsize=13, fontweight='bold', y=0.995)
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    return fig
[docs]
def plot_pair(self, dataset: str, preproc: str, figsize=(10, 5)):
    """
    Side-by-side scatter of raw vs. preprocessed PCA scores for one pair.

    The first two score columns of both matrices pass through
    ``scipy.spatial.procrustes``, which standardizes both and rotates the
    second onto the first. The pair's metrics from ``df_`` are printed
    below the panels.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if (dataset, preproc) was not fitted.
    """
    key = (dataset, preproc)
    if key not in self.cache_:
        raise ValueError(f"No data for ({dataset}, {preproc}). Run fit() first.")
    raw_scores, pre_scores = self.cache_[key]
    raw_aligned, pre_aligned, _ = procrustes(raw_scores[:, :2], pre_scores[:, :2])
    df = self.df_
    row = df[(df['dataset'] == dataset) & (df['preproc'] == preproc)].iloc[0]
    fig, (ax_raw, ax_pre) = plt.subplots(1, 2, figsize=figsize)

    def _panel(ax, points, color, title, **title_kw):
        ax.scatter(points[:, 0], points[:, 1], s=50, alpha=0.6, c=color, edgecolors='black', linewidth=0.5)
        ax.set_xlabel('PC1', fontweight='bold')
        ax.set_ylabel('PC2', fontweight='bold')
        ax.set_title(title, fontweight='bold', **title_kw)
        ax.grid(alpha=0.3, linestyle='--')
        ax.set_facecolor('#f8f9fa')
        ax.set_aspect('equal', 'box')

    _panel(ax_raw, raw_aligned, 'steelblue', f'{dataset} - Raw PCA\nEVR: {row["evr_raw"]:.4f}')
    # Readable preprocessing name: last pipeline stage, arrows between steps.
    pp_display = preproc.split('|')[-1].replace('MinMax>', '').replace('>', ' → ')
    _panel(ax_pre, pre_aligned, 'coral', f'{pp_display}\nEVR: {row["evr_pre"]:.4f}', fontsize=10)

    metrics_text = (f'CKA: {row["cka"]:.4f}\n'
                    f'RV: {row["rv"]:.4f}\n'
                    f'Procrustes: {row["procrustes"]:.4f}\n'
                    f'Trust: {row["trustworthiness"]:.4f}')
    fig.text(0.5, 0.02, metrics_text, ha='center', fontsize=10,
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
    plt.suptitle(f'PCA Comparison: {dataset} / {preproc}',
                 fontsize=13, fontweight='bold', y=0.98)
    plt.tight_layout(rect=[0, 0.08, 1, 0.96])
    return fig
[docs]
def plot_preservation_summary(self, by="preproc", figsize=(14, 8)):
    """
    Grouped bar chart of mean preservation metrics per group (default: preprocessing).

    Metrics are averaged per ``by`` group. Grassmann and Procrustes are
    negated so that higher is better. Each metric is then min-max
    normalized across groups, ignoring NaN (e.g. Grassmann for
    dimension-changing preprocessings). A metric with a constant value maps
    to 0.5. Bars above 0.05 get a value label.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if fit() has not been run.
    """
    if self.df_ is None or self.df_.empty:
        raise ValueError("Run fit() first.")
    score_cols = ['evr_pre', 'grassmann', 'cka', 'rv', 'procrustes', 'trustworthiness']
    agg = self.df_.groupby(by).agg({col: 'mean' for col in score_cols}).reset_index()
    # Distances: smaller is better, so flip the sign before normalizing.
    for distance_col in ('grassmann', 'procrustes'):
        agg[distance_col] = -agg[distance_col]
    for col in score_cols:
        values = agg[col].values
        valid = ~np.isnan(values)
        if not valid.any():
            continue
        low, high = values[valid].min(), values[valid].max()
        span = high - low
        if span > 1e-12:
            agg[col] = np.where(valid, (values - low) / span, np.nan)
        else:
            agg[col] = np.where(valid, 0.5, np.nan)

    # (column, legend label, color); six bars centered around each tick.
    bar_specs = [
        ('evr_pre', 'EVR', '#2ecc71'),
        ('cka', 'CKA', '#3498db'),
        ('rv', 'RV', '#9b59b6'),
        ('trustworthiness', 'Trust', '#e74c3c'),
        ('grassmann', 'Grassmann*', '#f39c12'),
        ('procrustes', 'Procrustes*', '#34495e'),
    ]
    x = np.arange(len(agg[by]))
    w = 0.13
    fig, ax = plt.subplots(figsize=figsize)
    for slot, (col, legend_label, color) in enumerate(bar_specs):
        heights = agg[col].values
        bars = ax.bar(x + (slot - 2.5) * w, heights, w, label=legend_label, color=color,
                      alpha=0.8, edgecolor='black', linewidth=0.5)
        for bar, val in zip(bars, heights):
            if np.isnan(val) or not val > 0.05:
                continue
            ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,
                    f'{val:.2f}', ha='center', va='bottom', fontsize=7, rotation=0)

    ax.set_xticks(x)
    tick_labels = []
    for group_name in agg[by].values:
        short = group_name.split('|')[-1].replace('MinMax>', '').replace('>', '→')
        tick_labels.append(short[:25] + '...' if len(short) > 25 else short)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right', fontsize=8)
    ax.set_ylim(0, 1.15)
    ax.set_ylabel('Normalized Score (0-1)', fontsize=11, fontweight='bold')
    ax.set_xlabel(f'{by.capitalize()}', fontsize=11, fontweight='bold')
    ax.legend(loc='upper left', fontsize=10, framealpha=0.95, ncol=3)
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.set_facecolor('#f8f9fa')
    ax.set_title(f'Preprocessing Structure Preservation by {by.capitalize()}\n' +
                 '(* inverted: higher is better)',
                 fontsize=13, fontweight='bold', pad=15)
    plt.tight_layout()
    return fig
[docs]
def plot_cross_dataset_distances(self, figsize=(14, 8)):
    """
    Plot how preprocessing affects inter-dataset distances.

    Shows which preprocessing methods bring datasets closer together, as a
    2x2 grid:

    * top row: mean centroid / spread improvement per preprocessing
      (horizontal bars with std error bars; green when > 0, red otherwise);
    * bottom row: mean absolute centroid / spread distance per
      preprocessing, raw vs. preprocessed side by side.

    Args:
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        The matplotlib Figure, or ``None`` (after logging a warning) when
        ``self.cross_dataset_df_`` is missing or empty.
    """
    if self.cross_dataset_df_ is None or self.cross_dataset_df_.empty:
        logger.warning("No cross-dataset analysis available. Need multiple datasets.")
        return None

    summary = self.get_cross_dataset_summary()
    preprocs = summary['preproc'].values
    x = np.arange(len(summary))

    # Shortened, readable preprocessing names shared by all four panels.
    labels = []
    for pp in preprocs:
        formatted = pp.split('|')[-1].replace('MinMax>', '').replace('>', '→')
        if len(formatted) > 30:
            formatted = formatted[:30] + '...'
        labels.append(formatted)

    fig, axes = plt.subplots(2, 2, figsize=figsize)

    # --- Panels 1-2: signed improvement per preprocessing ---------------
    improvement_panels = (
        (axes[0, 0], 'centroid_improv_mean', 'centroid_improv_std',
         'Centroid Improvement',
         'Dataset Centroid Distance Change\n(+: datasets closer, -: datasets farther)'),
        (axes[0, 1], 'spread_improv_mean', 'spread_improv_std',
         'Spread Improvement',
         'Dataset Distribution Overlap Change\n(+: distributions closer, -: distributions farther)'),
    )
    for ax, mean_col, std_col, xlabel, title in improvement_panels:
        means = summary[mean_col]
        colors = ['green' if v > 0 else 'red' for v in means]
        ax.barh(x, means, xerr=summary[std_col],
                color=colors, alpha=0.7, edgecolor='black', linewidth=0.5)
        ax.set_yticks(x)
        ax.set_yticklabels(labels, fontsize=8)
        ax.set_xlabel(xlabel, fontweight='bold')
        ax.set_title(title, fontsize=10, fontweight='bold')
        ax.axvline(0, color='black', linewidth=1, linestyle='--')
        ax.grid(axis='x', alpha=0.3)
        ax.set_facecolor('#f8f9fa')

    # --- Panels 3-4: absolute distances, raw vs. preprocessed -----------
    # groupby() returns its keys sorted alphabetically, which need not match
    # the row order of `summary`. Reindex onto summary's preproc order so
    # each raw bar sits next to its own preprocessed bar and tick label
    # (a preproc absent from the groupby result becomes NaN, i.e. no bar).
    per_preproc = self.cross_dataset_df_.groupby('preproc')
    width = 0.35
    distance_panels = (
        (axes[1, 0], 'centroid_dist_raw', 'centroid_dist_pp', 'Absolute Centroid Distances'),
        (axes[1, 1], 'spread_dist_raw', 'spread_dist_pp', 'Absolute Spread Distances'),
    )
    for ax, raw_col, pp_col, title in distance_panels:
        raw_vals = per_preproc[raw_col].mean().reindex(preprocs).values
        ax.bar(x - width / 2, raw_vals, width, label='Raw',
               alpha=0.8, color='steelblue', edgecolor='black', linewidth=0.5)
        ax.bar(x + width / 2, summary[pp_col].values, width, label='Preprocessed',
               alpha=0.8, color='coral', edgecolor='black', linewidth=0.5)
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=7)
        ax.set_ylabel('Distance', fontweight='bold')
        ax.set_title(title, fontsize=10, fontweight='bold')
        ax.legend(loc='best', fontsize=9)
        ax.grid(axis='y', alpha=0.3)
        ax.set_facecolor('#f8f9fa')

    fig.suptitle('Cross-Dataset Distance Analysis\n' +
                 'Evaluating Preprocessing Impact on Multi-Machine Compatibility',
                 fontsize=13, fontweight='bold', y=0.995)
    fig.tight_layout()
    return fig
[docs]
def plot_cross_dataset_heatmap(self, metric='centroid_improvement', figsize=(12, 10)):
    """
    Create a heatmap showing pairwise dataset distances for each preprocessing.

    One subplot is drawn per preprocessing (at most 3 per row). Cell (i, j)
    holds ``metric`` for the dataset pair (i, j), mirrored so the matrix is
    symmetric. Cells without a value (the diagonal and any pair absent from
    ``self.cross_dataset_df_``) are NaN and therefore left blank, so they
    are not confused with a measured value of 0.

    Args:
        metric: Column of ``self.cross_dataset_df_`` to display, e.g.
            'centroid_improvement', 'centroid_dist_pp', 'spread_improvement'
            or 'spread_dist_pp'. Metric names containing 'improvement' use
            the diverging 'RdYlGn' colormap fixed to [-1, 1]; all others use
            'YlOrRd' autoscaled to the data.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        The matplotlib Figure, or ``None`` (after logging a warning) when
        ``self.cross_dataset_df_`` is missing or empty.

    Raises:
        KeyError: If ``metric`` is not a column of ``self.cross_dataset_df_``.
    """
    df = self.cross_dataset_df_
    if df is None or df.empty:
        logger.warning("No cross-dataset analysis available. Need multiple datasets.")
        return None
    if metric not in df.columns:
        raise KeyError(f"Unknown metric {metric!r}; available columns: {list(df.columns)}")

    # Unique datasets / preprocessings, with an O(1) name -> row/col lookup.
    datasets = sorted(set(df['dataset_1']).union(df['dataset_2']))
    ds_index = {name: k for k, name in enumerate(datasets)}
    preprocs = sorted(df['preproc'].unique())
    n_preprocs = len(preprocs)
    n_datasets = len(datasets)

    # Colormap depends only on the metric, so pick it once.
    if 'improvement' in metric:
        cmap, vmin, vmax = 'RdYlGn', -1, 1  # Red (worse) to Green (better)
    else:
        cmap, vmin, vmax = 'YlOrRd', None, None  # Yellow (close) to Red (far)

    # Subplot grid; squeeze=False always returns a 2-D array of Axes.
    n_cols = min(3, n_preprocs)
    n_rows = (n_preprocs + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    for ax, pp_name in zip(axes, preprocs):
        # NaN = "no value"; imshow masks NaN cells and autoscaling ignores them.
        dist_matrix = np.full((n_datasets, n_datasets), np.nan)
        df_pp = df[df['preproc'] == pp_name]
        for d1, d2, val in zip(df_pp['dataset_1'], df_pp['dataset_2'], df_pp[metric]):
            i, j = ds_index[d1], ds_index[d2]
            dist_matrix[i, j] = val
            dist_matrix[j, i] = val  # Symmetric

        im = ax.imshow(dist_matrix, cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)

        # Shortened preprocessing name as the subplot title.
        pp_display = pp_name.split('|')[-1].replace('MinMax>', '').replace('>', '→')
        if len(pp_display) > 40:
            pp_display = pp_display[:40] + '...'
        ax.set_title(pp_display, fontsize=8, fontweight='bold')
        ax.set_xticks(range(n_datasets))
        ax.set_yticks(range(n_datasets))
        ax.set_xticklabels(datasets, rotation=45, ha='right', fontsize=7)
        ax.set_yticklabels(datasets, fontsize=7)

        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        # Annotate every measured off-diagonal value, including genuine zeros.
        for i in range(n_datasets):
            for j in range(n_datasets):
                if i != j and np.isfinite(dist_matrix[i, j]):
                    ax.text(j, i, f'{dist_matrix[i, j]:.2f}',
                            ha="center", va="center", color="black", fontsize=6)

    # Hide unused subplots in the last row.
    for ax in axes[n_preprocs:]:
        ax.axis('off')

    metric_display = metric.replace('_', ' ').title()
    fig.suptitle(f'Cross-Dataset {metric_display} Heatmaps\n' +
                 'Comparing Dataset Compatibility Across Preprocessing Methods',
                 fontsize=13, fontweight='bold', y=0.995)
    fig.tight_layout()
    return fig
[docs]
def plot_all(self, show=True):
    """
    Entry point intended to generate all visualization plots.

    NOTE(review): every plot call in this method is currently disabled, so
    no figures are produced or collected. The calls that were previously
    present (commented out) were ``self.plot_summary(by="preproc")``,
    ``self.plot_pca_scatter()`` and
    ``self.plot_distance_network(metric='cka')``.

    Args:
        show: When truthy, call ``plt.show()`` before returning.

    Returns:
        list: A new list of generated figures (currently always empty).
    """
    generated = []
    if not show:
        return generated
    plt.show()
    return generated