"""Online Koopman Latent-Mode PLS (OKLM-PLS) regressor for nirs4all.
A sklearn-compatible implementation of OKLM-PLS that combines Koopman operator
theory with PLS for time-series regression. The model learns latent dynamics
T_{t+1} ≈ F @ T_t while simultaneously fitting Y_t ≈ T_t @ B.
This is useful for spectral data collected over time where temporal coherence
provides additional information for prediction.
Supports both NumPy (CPU) and JAX (GPU/TPU) backends.
References
----------
.. [1] Brunton, S. L., Budisic, M., Kaiser, E., & Kutz, J. N. (2021).
Modern Koopman theory for dynamical systems. arXiv:2102.12086.
.. [2] Williams, M. O., Kevrekidis, I. G., & Rowley, C. W. (2015).
A data-driven approximation of the Koopman operator: Extending
dynamic mode decomposition. Journal of Nonlinear Science, 25, 1307-1346.
Mathematical formulation
------------------------
Let
- X ∈ ℝ^{n×p} be the input matrix of n samples and p features
(e.g. NIRS spectra), ordered in time t = 1,…,n.
- Y ∈ ℝ^{n×q} be the corresponding response matrix.
An optional feature map ψ : ℝ^p → ℝ^d is applied to each row x_t,
giving Z ∈ ℝ^{n×d} with rows
z_t = ψ(x_t), t = 1,…,n.
OKLM-PLS seeks:
- a loading matrix W ∈ ℝ^{d×r} (r latent components),
- a latent dynamic matrix F ∈ ℝ^{r×r},
- a regression matrix B ∈ ℝ^{r×q},
such that the latent scores
T = Z W ∈ ℝ^{n×r}, t_t = row t of T,
(1) capture covariance between X (via Z) and Y in the PLS sense, (2) follow
simple linear dynamics in latent space, and (3) predict Y linearly.
The latent dynamics are modeled as a Koopman-like linear evolution:
t_{t+1} ≈ F t_t, for t = 1,…,n−1.
The regression in latent space is
y_t ≈ Bᵀ t_t, for t = 1,…,n,
or in matrix form
Y ≈ T B.
A simple joint objective is
L(W, F, B)
= λ_dyn ∑_{t=1}^{n−1} ‖ t_{t+1} − F t_t ‖₂²
+ λ_reg ∑_{t=1}^{n} ‖ y_t − Bᵀ t_t ‖₂²,
subject to PLS-style constraints on W and T (e.g. orthonormal scores and
components ordered by decreasing covariance with Y):
(1/n) Tᵀ T = I_r,
components sorted by cov(T_j, Y).
In practice, the objective is optimized by alternating updates:
1) Latent scores
Given W, we form Z and T = Z W.
2) Dynamics update
Given T, we estimate F as the least-squares solution of
F = argmin_F ∑_{t=1}^{n−1} ‖ t_{t+1} − F t_t ‖₂²,
which has the closed form
F = (∑_{t} t_{t+1} t_tᵀ) (∑_{t} t_t t_tᵀ)^{-1}.
3) Regression update
Given T, we estimate B as the least-squares solution of
B = argmin_B ‖ Y − T B ‖_F²,
giving
B = (Tᵀ T)^{-1} Tᵀ Y.
4) PLS-like loading update
W is initialized from a standard PLS on (Z, Y) (or from a random orthonormal
matrix) and is then refined by gradient steps on L(W, F, B); after every
step the columns of W are re-orthonormalized with a QR decomposition.
The fitted model predicts Y for new inputs X* by:
1) applying the same preprocessing and feature map ψ to obtain Z*,
2) computing scores T* = Z* W,
3) returning Ŷ* = T* B (with inverse scaling if standardization is used).
The term “online” refers to the fact that, for streaming data, F and B
can be updated recursively as new scores t_t arrive, while W is kept
fixed or updated more infrequently.
"""
from __future__ import annotations
from functools import partial
from typing import Literal, Union
import numpy as np
from numpy.typing import ArrayLike, NDArray
from sklearn.base import BaseEstimator, RegressorMixin, TransformerMixin
from sklearn.cross_decomposition import PLSRegression
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_is_fitted
def _check_jax_available():
    """Report whether the optional JAX dependency can be imported.

    Returns
    -------
    bool
        True when ``import jax`` succeeds, False when it raises ImportError.
    """
    try:
        import jax  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    else:
        return True
# =============================================================================
# Featurizers
# =============================================================================
class IdentityFeaturizer(BaseEstimator, TransformerMixin):
    """Identity featurizer: ψ(x) = x.

    This is the default featurizer for OKLMPLS when no nonlinear
    transformation is needed.
    """

    def fit(self, X, y=None):
        """Fit the featurizer (no-op for identity).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : ignored

        Returns
        -------
        self : IdentityFeaturizer
        """
        return self

    def transform(self, X):
        """Return the input unchanged, converted to an ndarray.

        Without this method ``TransformerMixin.fit_transform`` (and any
        caller using the featurizer as a transformer) fails with
        AttributeError.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to featurize.

        Returns
        -------
        Z : ndarray of shape (n_samples, n_features)
            ``np.asarray(X)``; no copy is made when X is already an ndarray.
        """
        return np.asarray(X)
class PolynomialFeaturizer(BaseEstimator, TransformerMixin):
    """Polynomial featurizer for OKLM-PLS.

    Creates polynomial features up to specified degree without interaction
    terms (for efficiency with high-dimensional spectral data).

    Parameters
    ----------
    degree : int, default=2
        Maximum degree of polynomial features.
    include_original : bool, default=True
        Whether to include the original features (degree 1).
    """

    def __init__(self, degree: int = 2, include_original: bool = True):
        self.degree = degree
        self.include_original = include_original

    def fit(self, X, y=None):
        """Fit the featurizer (stateless; nothing is learned).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : ignored

        Returns
        -------
        self : PolynomialFeaturizer
        """
        return self

    def transform(self, X):
        """Stack element-wise powers of X column-wise.

        Without this method ``TransformerMixin.fit_transform`` (and any
        caller using the featurizer as a transformer) fails with
        AttributeError.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to featurize.

        Returns
        -------
        Z : ndarray of shape (n_samples, n_features * n_powers)
            ``[X**k for k in powers]`` concatenated along the feature axis,
            where ``powers`` runs from 1 (or 2 when ``include_original`` is
            False) up to ``degree``. No interaction terms are generated.

        Raises
        ------
        ValueError
            If the configuration selects no power at all (e.g. ``degree=1``
            with ``include_original=False``).
        """
        X = np.asarray(X)
        lowest_power = 1 if self.include_original else 2
        blocks = [X ** power for power in range(lowest_power, self.degree + 1)]
        if not blocks:
            raise ValueError(
                "PolynomialFeaturizer produces no features for "
                f"degree={self.degree} with "
                f"include_original={self.include_original}"
            )
        return np.hstack(blocks)
class RBFFeaturizer(BaseEstimator, TransformerMixin):
    """Random Fourier Features (RBF approximation) featurizer for OKLM-PLS.

    Approximates the RBF kernel using random Fourier features, which is
    useful for adding nonlinearity to the Koopman embedding.

    Parameters
    ----------
    n_components : int, default=100
        Number of random Fourier features.
    gamma : float, optional
        Kernel coefficient. If None, uses 1/n_features.
    random_state : int, optional
        Random seed for reproducibility.

    Attributes
    ----------
    random_weights_ : ndarray of shape (n_features, n_components)
        Random projection directions sampled in ``fit``.
    random_offset_ : ndarray of shape (n_components,)
        Random phases sampled uniformly from [0, 2π) in ``fit``.
    """

    def __init__(
        self,
        n_components: int = 100,
        gamma: float | None = None,
        random_state: int | None = None,
    ):
        self.n_components = n_components
        self.gamma = gamma
        self.random_state = random_state

    def fit(self, X, y=None):
        """Fit the featurizer by sampling random frequencies.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data. Only its number of columns is used.
        y : ignored

        Returns
        -------
        self : RBFFeaturizer
        """
        X = np.asarray(X)
        n_features = X.shape[1]
        gamma = self.gamma if self.gamma is not None else 1.0 / n_features
        rng = np.random.default_rng(self.random_state)
        # Sample from Gaussian with std = sqrt(2 * gamma)
        self.random_weights_ = rng.normal(
            0, np.sqrt(2 * gamma), (n_features, self.n_components)
        )
        self.random_offset_ = rng.uniform(
            0, 2 * np.pi, self.n_components
        )
        return self

    def transform(self, X):
        """Map X to random Fourier features.

        Without this method ``TransformerMixin.fit_transform`` (and any
        caller using the featurizer as a transformer) fails with
        AttributeError, and the quantities sampled in ``fit`` are never used.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to featurize; must have the column count seen in ``fit``.

        Returns
        -------
        Z : ndarray of shape (n_samples, n_components)
            ``sqrt(2 / n_components) * cos(X @ random_weights_ + random_offset_)``,
            the usual random-Fourier-feature map for the RBF kernel.
        """
        check_is_fitted(self, ['random_weights_', 'random_offset_'])
        X = np.asarray(X)
        projection = X @ self.random_weights_ + self.random_offset_
        return np.sqrt(2.0 / self.n_components) * np.cos(projection)
# =============================================================================
# NumPy Backend Implementation
# =============================================================================
def _oklmpls_fit_numpy(
    Z: NDArray[np.floating],
    Y: NDArray[np.floating],
    n_components: int,
    lambda_dyn: float,
    lambda_reg_y: float,
    max_iter: int,
    tol: float,
    W_init: NDArray[np.floating] | None = None,
    B_init: NDArray[np.floating] | None = None,
) -> tuple[
    NDArray[np.floating],  # W (projection weights)
    NDArray[np.floating],  # F (dynamics matrix)
    NDArray[np.floating],  # B (regression coefficients)
    int,  # n_iter (actual iterations)
]:
    """Fit OKLM-PLS model using alternating optimization.

    Each iteration (1) computes scores T = Z @ W, (2) re-estimates F and B
    by ridge-stabilised least squares, (3) takes one gradient step on W for
    the combined objective and re-orthonormalises W with QR, and (4) checks
    the change of the objective against ``tol``.

    Parameters
    ----------
    Z : ndarray of shape (n_samples, n_features_z)
        Featurized and centered X data.
    Y : ndarray of shape (n_samples, n_targets)
        Centered Y data.
    n_components : int
        Number of latent components.
    lambda_dyn : float
        Weight for dynamic consistency loss.
    lambda_reg_y : float
        Weight for regression loss.
    max_iter : int
        Maximum alternating optimization iterations. With ``max_iter <= 0``
        the initial W, F and B are returned untouched and ``n_iter`` is 0.
    tol : float
        Convergence tolerance on the absolute change of the objective.
    W_init : ndarray, optional
        Initial W weights. If None, a random orthonormal matrix drawn with a
        fixed seed (42) is used.
    B_init : ndarray, optional
        Initial B coefficients. If None, zeros are used.

    Returns
    -------
    W : ndarray of shape (n_features_z, n_components)
        Learned projection weights.
    F : ndarray of shape (n_components, n_components)
        Learned dynamics matrix.
    B : ndarray of shape (n_components, n_targets)
        Learned regression coefficients.
    n_iter : int
        Number of iterations actually performed.
    """
    n_samples, d = Z.shape
    n_targets = Y.shape[1]
    r = n_components

    # The dynamics term needs at least one (t, t+1) pair.
    use_dynamics = lambda_dyn > 0 and n_samples > 1
    use_regression = lambda_reg_y > 0
    # Small ridge keeps the normal equations solvable for collinear scores.
    ridge = 1e-8 * np.eye(r)

    # Initialize W and B
    if W_init is not None:
        W = W_init.copy()
    else:
        # Random initialization with orthogonalization
        rng = np.random.default_rng(42)
        W, _ = np.linalg.qr(rng.normal(size=(d, r)))

    if B_init is not None:
        B = B_init.copy()
    else:
        B = np.zeros((r, n_targets), dtype=np.float64)

    # Initialize F as identity
    F = np.eye(r, dtype=np.float64)

    # Lagged views of Z are loop-invariant.
    Z_t = Z[:-1]
    Z_next = Z[1:]

    prev_loss = np.inf
    # Tracked explicitly so that max_iter == 0 does not leave the loop
    # variable unbound at the return statement.
    n_iter = 0

    for iteration in range(max_iter):
        n_iter = iteration + 1

        # Step 1: Compute latent scores T = Z @ W
        T = Z @ W  # (n_samples, r)

        # Step 2: Update dynamics matrix F
        # Row form of the model: T_next ≈ T_t @ F.T, hence
        # F.T = (T_t^T T_t)^{-1} T_t^T T_next
        if use_dynamics:
            T_t = T[:-1]  # (n_samples-1, r)
            T_next = T[1:]  # (n_samples-1, r)
            F = np.linalg.solve(T_t.T @ T_t + ridge, T_t.T @ T_next).T

        # Step 3: Update regression coefficients B
        # B = (T^T T)^{-1} T^T Y
        if use_regression:
            B = np.linalg.solve(T.T @ T + ridge, T.T @ Y)

        # Step 4: Gradient step on W for
        #   lambda_dyn * ||T_next - T_t F^T||^2 + lambda_reg_y * ||Y - T B||^2
        # with F and B held fixed (W is unchanged since step 1, so T is
        # still current).
        # Regression term: d/dW ||Y - Z W B||^2 = -2 Z^T (Y - T B) B^T
        grad_reg = -2 * Z.T @ (Y - T @ B) @ B.T

        grad_dyn = np.zeros_like(W)
        if use_dynamics:
            # With R = Z_next W - Z_t W F^T:
            #   d/dW ||R||^2 = 2 Z_next^T R - 2 Z_t^T R F
            residual_dyn = T_next - T_t @ F.T
            grad_dyn = 2 * (Z_next.T @ residual_dyn - Z_t.T @ residual_dyn @ F)

        # Combined gradient
        grad = lambda_reg_y * grad_reg + lambda_dyn * grad_dyn

        # Decaying step-size schedule (no line search is performed)
        step_size = 0.01 / (1 + 0.1 * iteration)

        # Gradient step followed by QR to keep W's columns orthonormal
        W, _ = np.linalg.qr(W - step_size * grad)

        # Objective at the updated W (F and B from this iteration)
        T = Z @ W
        dyn_loss = 0.0
        if use_dynamics:
            dyn_loss = np.sum((T[1:] - T[:-1] @ F.T) ** 2)
        reg_loss = 0.0
        if use_regression:
            reg_loss = np.sum((Y - T @ B) ** 2)
        loss = lambda_dyn * dyn_loss + lambda_reg_y * reg_loss

        if np.abs(prev_loss - loss) < tol:
            break
        prev_loss = loss

    return W, F, B, n_iter
def _oklmpls_predict_numpy(
    Z: NDArray[np.floating],
    W: NDArray[np.floating],
    B: NDArray[np.floating],
) -> NDArray[np.floating]:
    """Evaluate the fitted OKLM-PLS regression on featurized data.

    Parameters
    ----------
    Z : ndarray of shape (n_samples, n_features_z)
        Featurized and centered X data.
    W : ndarray of shape (n_features_z, n_components)
        Projection weights.
    B : ndarray of shape (n_components, n_targets)
        Regression coefficients.

    Returns
    -------
    Y_pred : ndarray of shape (n_samples, n_targets)
        Predictions ``(Z @ W) @ B`` in the (possibly standardized) target
        space the model was fitted in.
    """
    latent_scores = Z @ W
    predictions = latent_scores @ B
    return predictions
# =============================================================================
# JAX Backend Implementation
# =============================================================================
def _get_jax_oklmpls_functions():
    """Lazily import JAX and build the OKLM-PLS kernels.

    JAX is imported inside the function so the module can be used without
    it. Calling this function enables 64-bit mode globally in JAX
    (``jax_enable_x64``) as a side effect.

    Returns
    -------
    dict
        Maps the names ``'compute_scores'``, ``'update_dynamics'``,
        ``'update_regression'``, ``'compute_loss'``, ``'compute_gradient'``
        and ``'predict'`` to the corresponding callables.
    """
    import jax
    import jax.numpy as jnp

    jax.config.update("jax_enable_x64", True)

    @jax.jit
    def compute_scores_jax(Z, W):
        """Compute latent scores T = Z @ W."""
        return Z @ W

    # Not jitted: ``r`` sizes jnp.eye and would have to be a static argument.
    def update_dynamics_jax(T, r):
        """Least-squares update of F for the row model T_next ≈ T_t @ F.T."""
        T_t = T[:-1]
        T_next = T[1:]
        TtT = T_t.T @ T_t + 1e-8 * jnp.eye(r)
        TtN = T_t.T @ T_next
        F_T = jnp.linalg.solve(TtT, TtN)
        return F_T.T

    def update_regression_jax(T, Y, r):
        """Least-squares update of B for Y ≈ T @ B."""
        TtT = T.T @ T + 1e-8 * jnp.eye(r)
        TtY = T.T @ Y
        return jnp.linalg.solve(TtT, TtY)

    @jax.jit
    def compute_loss_jax(Z, Y, W, F, B, lambda_dyn, lambda_reg_y):
        """Compute total loss (weighted dynamics + regression terms)."""
        T = Z @ W
        # Dynamic loss
        dyn_loss = jnp.sum((T[1:] - T[:-1] @ F.T) ** 2)
        # Regression loss
        reg_loss = jnp.sum((Y - T @ B) ** 2)
        return lambda_dyn * dyn_loss + lambda_reg_y * reg_loss

    @jax.jit
    def compute_gradient_jax(Z, Y, W, F, B, lambda_dyn, lambda_reg_y):
        """Compute gradient of ``compute_loss_jax`` w.r.t. W (F, B fixed)."""
        T = Z @ W
        # Regression term: d/dW ||Y - Z W B||^2 = -2 Z^T (Y - T B) B^T
        grad_reg = -2 * Z.T @ (Y - T @ B) @ B.T
        # Dynamics term, with R = Z_next W - Z_t W F^T:
        #   d/dW ||R||^2 = 2 Z_next^T R - 2 Z_t^T R F
        residual_dyn = T[1:] - T[:-1] @ F.T
        grad_dyn = 2 * (Z[1:].T @ residual_dyn - Z[:-1].T @ residual_dyn @ F)
        return lambda_reg_y * grad_reg + lambda_dyn * grad_dyn

    @jax.jit
    def predict_jax(Z, W, B):
        """Predict using fitted model."""
        return (Z @ W) @ B

    return {
        'compute_scores': compute_scores_jax,
        'update_dynamics': update_dynamics_jax,
        'update_regression': update_regression_jax,
        'compute_loss': compute_loss_jax,
        'compute_gradient': compute_gradient_jax,
        'predict': predict_jax,
    }
# Module-level cache for the JAX kernels; filled on first use.
_JAX_OKLMPLS_FUNCS = None


def _get_cached_jax_oklmpls():
    """Return the JAX OKLM-PLS function table, building it at most once.

    The table is stored in the module-level ``_JAX_OKLMPLS_FUNCS`` so the
    JAX import and kernel construction happen only on the first call.
    """
    global _JAX_OKLMPLS_FUNCS
    if _JAX_OKLMPLS_FUNCS is not None:
        return _JAX_OKLMPLS_FUNCS
    _JAX_OKLMPLS_FUNCS = _get_jax_oklmpls_functions()
    return _JAX_OKLMPLS_FUNCS
def _oklmpls_fit_jax(
    Z: NDArray[np.floating],
    Y: NDArray[np.floating],
    n_components: int,
    lambda_dyn: float,
    lambda_reg_y: float,
    max_iter: int,
    tol: float,
    W_init: NDArray[np.floating] | None = None,
    B_init: NDArray[np.floating] | None = None,
) -> tuple[
    NDArray[np.floating],
    NDArray[np.floating],
    NDArray[np.floating],
    int,
]:
    """Fit OKLM-PLS model using JAX backend.

    Runs the same alternating scheme as the NumPy backend: update F and B
    by least squares, take one gradient step on W, re-orthonormalise W with
    QR, and stop when the objective changes by less than ``tol``.

    Parameters
    ----------
    Z : ndarray of shape (n_samples, n_features_z)
        Featurized X data.
    Y : ndarray of shape (n_samples, n_targets)
        Target data.
    n_components : int
        Number of latent components.
    lambda_dyn : float
        Weight for dynamic consistency loss.
    lambda_reg_y : float
        Weight for regression loss.
    max_iter : int
        Maximum number of iterations. With ``max_iter <= 0`` the initial
        W, F and B are returned and ``n_iter`` is 0.
    tol : float
        Convergence tolerance on the absolute change of the objective.
    W_init : ndarray, optional
        Initial W. If None, a random orthonormal matrix drawn with a fixed
        seed (42) is used.
    B_init : ndarray, optional
        Initial B. If None, zeros are used.

    Returns
    -------
    W, F, B : ndarray
        Learned projection, dynamics and regression matrices as NumPy arrays.
    n_iter : int
        Number of iterations actually performed.
    """
    import jax.numpy as jnp

    jax_funcs = _get_cached_jax_oklmpls()
    n_samples, d = Z.shape
    n_targets = Y.shape[1]
    r = n_components

    # Convert to JAX arrays
    Z_jax = jnp.asarray(Z)
    Y_jax = jnp.asarray(Y)

    # Initialize W and B
    if W_init is not None:
        W = jnp.asarray(W_init)
    else:
        rng = np.random.default_rng(42)
        W_np, _ = np.linalg.qr(rng.normal(size=(d, r)))
        W = jnp.asarray(W_np)

    if B_init is not None:
        B = jnp.asarray(B_init)
    else:
        B = jnp.zeros((r, n_targets), dtype=jnp.float64)

    F = jnp.eye(r, dtype=jnp.float64)
    prev_loss = np.inf
    # Tracked explicitly so that max_iter == 0 does not leave the loop
    # variable unbound at the return statement.
    n_iter = 0

    for iteration in range(max_iter):
        n_iter = iteration + 1

        # Update F
        T = jax_funcs['compute_scores'](Z_jax, W)
        if lambda_dyn > 0 and n_samples > 1:
            F = jax_funcs['update_dynamics'](T, r)

        # Update B
        if lambda_reg_y > 0:
            B = jax_funcs['update_regression'](T, Y_jax, r)

        # Update W with a decaying step size
        grad = jax_funcs['compute_gradient'](
            Z_jax, Y_jax, W, F, B, lambda_dyn, lambda_reg_y
        )
        step_size = 0.01 / (1 + 0.1 * iteration)
        W_new = W - step_size * grad

        # Orthogonalize on the host with NumPy's QR
        W_np, _ = np.linalg.qr(np.asarray(W_new))
        W = jnp.asarray(W_np)

        # Check convergence
        loss = float(jax_funcs['compute_loss'](
            Z_jax, Y_jax, W, F, B, lambda_dyn, lambda_reg_y
        ))
        if np.abs(prev_loss - loss) < tol:
            break
        prev_loss = loss

    return np.asarray(W), np.asarray(F), np.asarray(B), n_iter
# =============================================================================
# OKLMPLS Estimator Class
# =============================================================================
class OKLMPLS(BaseEstimator, RegressorMixin):
    """Online Koopman Latent-Mode Partial Least Squares (OKLM-PLS).

    OKLM-PLS combines Koopman operator theory with PLS for time-series
    regression. It learns latent scores T = ψ(X) @ W and simultaneously:

    - Enforces dynamic coherence: T_{t+1} ≈ F @ T_t
    - Learns regression: Y_t ≈ T_t @ B

    This is useful for spectral data collected over time where temporal
    coherence provides additional predictive information.

    Parameters
    ----------
    n_components : int, default=5
        Number of latent components.
    featurizer : TransformerMixin, optional
        Feature map ψ: X -> Z. If None, identity is used.
        Options include PolynomialFeaturizer and RBFFeaturizer.
    lambda_dyn : float, default=1.0
        Weight for dynamic consistency loss ||T_{t+1} - F @ T_t||².
        Higher values enforce stronger temporal coherence.
    lambda_reg_y : float, default=1.0
        Weight for regression loss ||Y - T @ B||².
    max_iter : int, default=50
        Maximum alternating optimization iterations.
    tol : float, default=1e-4
        Convergence tolerance on the objective function.
    warm_start_pls : bool, default=True
        If True, initialize W/B from a standard PLSRegression fit.
    standardize : bool, default=True
        Whether to standardize X and Y before fitting.
    backend : str, default='numpy'
        Computational backend:

        - 'numpy': NumPy backend (CPU only).
        - 'jax': JAX backend (supports GPU/TPU).
    random_state : int, optional
        Seed for the random orthonormal initialization of W. Only used when
        ``warm_start_pls=False``; if None, the backend's fixed default seed
        is used.

    Attributes
    ----------
    n_features_in_ : int
        Number of features in input X.
    n_components_ : int
        Actual number of components.
    W_ : ndarray of shape (n_features_z, n_components_)
        Projection weights (in featurized space).
    F_ : ndarray of shape (n_components_, n_components_)
        Dynamics matrix for latent scores.
    B_ : ndarray of shape (n_components_, n_targets)
        Regression coefficients.
    n_iter_ : int
        Number of iterations until convergence.
    x_scaler_, y_scaler_ : StandardScaler or None
        Fitted scalers (None when ``standardize=False``).
    featurizer_ : transformer
        The fitted feature map used by ``fit``.

    Examples
    --------
    >>> from nirs4all.operators.models.sklearn.oklmpls import OKLMPLS
    >>> import numpy as np
    >>> # Generate time-series data
    >>> np.random.seed(42)
    >>> X = np.random.randn(100, 50)
    >>> y = X[:, :5].sum(axis=1) + 0.1 * np.random.randn(100)
    >>> # Fit OKLM-PLS
    >>> model = OKLMPLS(n_components=10, lambda_dyn=1.0, lambda_reg_y=1.0)
    >>> model.fit(X, y)
    OKLMPLS(...)
    >>> predictions = model.predict(X)
    >>> # Use with polynomial featurizer for nonlinearity
    >>> from nirs4all.operators.models.sklearn.oklmpls import PolynomialFeaturizer
    >>> model_poly = OKLMPLS(n_components=10, featurizer=PolynomialFeaturizer(degree=2))
    >>> model_poly.fit(X, y)
    OKLMPLS(...)

    Notes
    -----
    OKLM-PLS is designed for temporally-ordered data where samples are
    sequential in time. The dynamics constraint helps capture temporal
    patterns and can improve prediction when the underlying process
    has smooth temporal evolution.

    For non-temporal data, set lambda_dyn=0 to disable the dynamics
    constraint, leaving only the latent regression term (with optional
    featurization).

    See Also
    --------
    SIMPLS : Standard PLS without dynamics.
    RecursivePLS : Online PLS with forgetting factor.
    """

    # Explicitly declare estimator type for sklearn compatibility
    # (e.g., StackingRegressor).
    _estimator_type = "regressor"

    # NOTE: get_params / set_params are inherited from BaseEstimator on
    # purpose. The inherited versions honour ``deep`` and route nested keys
    # such as ``featurizer__degree`` to the featurizer instead of creating a
    # stray attribute on this object.

    def __init__(
        self,
        n_components: int = 5,
        featurizer: TransformerMixin | None = None,
        lambda_dyn: float = 1.0,
        lambda_reg_y: float = 1.0,
        max_iter: int = 50,
        tol: float = 1e-4,
        warm_start_pls: bool = True,
        standardize: bool = True,
        backend: str = 'numpy',
        random_state: int | None = None,
    ):
        """Initialize OKLM-PLS regressor (parameters are stored verbatim)."""
        self.n_components = n_components
        self.featurizer = featurizer
        self.lambda_dyn = lambda_dyn
        self.lambda_reg_y = lambda_reg_y
        self.max_iter = max_iter
        self.tol = tol
        self.warm_start_pls = warm_start_pls
        self.standardize = standardize
        self.backend = backend
        self.random_state = random_state

    def _init_state(
        self,
        X: NDArray[np.floating],
        Y: NDArray[np.floating],
    ) -> tuple[NDArray[np.floating], NDArray[np.floating], NDArray[np.floating]]:
        """Fit the scalers and the featurizer and apply them to the data.

        Sets ``x_scaler_``, ``y_scaler_`` and ``featurizer_``.

        Returns
        -------
        Z : ndarray
            Featurized and processed X.
        X_proc : ndarray
            Processed X (before featurization).
        Y_proc : ndarray
            Processed Y.
        """
        if self.standardize:
            self.x_scaler_ = StandardScaler(with_mean=True, with_std=True)
            self.y_scaler_ = StandardScaler(with_mean=True, with_std=True)
            X_proc = self.x_scaler_.fit_transform(X.copy())
            Y_proc = self.y_scaler_.fit_transform(Y.copy())
        else:
            self.x_scaler_ = None
            self.y_scaler_ = None
            X_proc = X.copy()
            Y_proc = Y.copy()

        if self.featurizer is None:
            self.featurizer_ = IdentityFeaturizer()
        else:
            self.featurizer_ = self.featurizer
        Z = self.featurizer_.fit_transform(X_proc)
        return Z, X_proc, Y_proc

    def _warm_start(
        self,
        Z: NDArray[np.floating],
        Y_proc: NDArray[np.floating],
    ) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
        """Initialize W and B from PLSRegression.

        Returns
        -------
        W_init : ndarray of shape (n_features_z, n_components)
            Initial projection weights (the PLS ``x_weights_``).
        B_init : ndarray of shape (n_components, n_targets)
            Ridge-stabilised least-squares fit of Y on T = Z @ W_init.
        """
        r = min(self.n_components, Z.shape[0] - 1, Z.shape[1])
        pls = PLSRegression(n_components=r)
        pls.fit(Z, Y_proc)
        W_init = pls.x_weights_  # (d, r)
        # B must match the scores the model actually uses, T = Z @ W, not
        # the PLS-internal x_scores_ (which come from x_rotations_ applied
        # to PLS's own centered/scaled copy of Z).
        T = Z @ W_init
        TtT = T.T @ T + 1e-8 * np.eye(r)
        B_init = np.linalg.solve(TtT, T.T @ Y_proc)
        return W_init, B_init

    def _featurize(self, X: ArrayLike) -> NDArray[np.floating]:
        """Apply the fitted scaler and featurizer to new data.

        Raises
        ------
        ValueError
            If X is not 2-dimensional or its number of columns differs from
            ``n_features_in_``.
        """
        X = np.asarray(X, dtype=np.float64)
        if X.ndim != 2 or X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X must have shape (n_samples, {self.n_features_in_}), "
                f"got {X.shape}"
            )
        # Work on a copy so a featurizer can never mutate caller data.
        X_proc = X.copy()
        if self.x_scaler_ is not None:
            X_proc = self.x_scaler_.transform(X_proc)
        return self.featurizer_.transform(X_proc)

    def fit(
        self,
        X: ArrayLike,
        y: ArrayLike,
    ) -> "OKLMPLS":
        """Fit the OKLM-PLS model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data. Samples should be temporally ordered.
        y : array-like of shape (n_samples,) or (n_samples, n_targets)
            Target values.

        Returns
        -------
        self : OKLMPLS
            Fitted estimator.

        Raises
        ------
        ValueError
            If backend is invalid, if X is not 2-dimensional, if X and y
            disagree on the number of samples, or if the data admit no
            latent component.
        ImportError
            If JAX backend requested but not installed.
        """
        if self.backend not in ('numpy', 'jax'):
            raise ValueError(
                f"backend must be 'numpy' or 'jax', got '{self.backend}'"
            )
        if self.backend == 'jax' and not _check_jax_available():
            raise ImportError(
                "JAX is required for OKLMPLS with backend='jax'. "
                "Install with: pip install jax"
            )

        X = np.asarray(X, dtype=np.float64)
        y = np.asarray(y, dtype=np.float64)
        if X.ndim != 2:
            raise ValueError(
                f"X must be 2-dimensional (n_samples, n_features), "
                f"got shape {X.shape}"
            )

        # Remember the target layout so predictions can mirror it.
        self._y_1d = y.ndim == 1
        if self._y_1d:
            y = y.reshape(-1, 1)
        if y.shape[0] != X.shape[0]:
            raise ValueError(
                f"X and y have inconsistent numbers of samples: "
                f"{X.shape[0]} != {y.shape[0]}"
            )

        n_samples, n_features = X.shape
        self.n_features_in_ = n_features

        # Initialize preprocessing
        Z, _, Y_proc = self._init_state(X, y)

        # Limit components
        n_features_z = Z.shape[1]
        self.n_components_ = min(self.n_components, n_samples - 1, n_features_z)
        if self.n_components_ < 1:
            raise ValueError(
                "OKLMPLS needs at least one latent component; got "
                f"n_components={self.n_components} with {n_samples} "
                f"sample(s) and {n_features_z} featurized feature(s)"
            )

        # Initialization of W (and B)
        W_init = None
        B_init = None
        if self.warm_start_pls:
            W_init, B_init = self._warm_start(Z, Y_proc)
        elif self.random_state is not None:
            # Honour the user's seed; with random_state=None the backend
            # falls back to its own fixed-seed random initialization.
            rng = np.random.default_rng(self.random_state)
            W_init, _ = np.linalg.qr(
                rng.normal(size=(n_features_z, self.n_components_))
            )

        # Fit using appropriate backend
        fit_backend = _oklmpls_fit_jax if self.backend == 'jax' else _oklmpls_fit_numpy
        W, F, B, n_iter = fit_backend(
            Z, Y_proc, self.n_components_,
            self.lambda_dyn, self.lambda_reg_y,
            self.max_iter, self.tol,
            W_init, B_init,
        )

        self.W_ = W
        self.F_ = F
        self.B_ = B
        self.n_iter_ = n_iter
        return self

    def transform(self, X: ArrayLike) -> NDArray[np.floating]:
        """Project samples onto the latent space.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to project.

        Returns
        -------
        T : ndarray of shape (n_samples, n_components_)
            Latent scores ``ψ(X) @ W_`` (after the fitted preprocessing).
        """
        check_is_fitted(self, ['W_', 'featurizer_'])
        return self._featurize(X) @ self.W_

    def predict(self, X: ArrayLike) -> NDArray[np.floating]:
        """Predict using the OKLM-PLS model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,) or (n_samples, n_targets)
            Predicted values, in the original target scale. 1-D when the
            model was fitted on a 1-D ``y``.
        """
        check_is_fitted(self, ['W_', 'B_', 'featurizer_'])
        Z = self._featurize(X)

        if self.backend == 'jax':
            import jax.numpy as jnp
            jax_funcs = _get_cached_jax_oklmpls()
            Y_proc = np.asarray(jax_funcs['predict'](
                jnp.asarray(Z), jnp.asarray(self.W_), jnp.asarray(self.B_)
            ))
        else:
            Y_proc = _oklmpls_predict_numpy(Z, self.W_, self.B_)

        # Undo target standardization
        if self.y_scaler_ is not None:
            Y = self.y_scaler_.inverse_transform(Y_proc)
        else:
            Y = Y_proc

        if self._y_1d:
            Y = Y.ravel()
        return Y

    def predict_dynamic(
        self,
        X: ArrayLike,
        n_steps: int = 1,
    ) -> NDArray[np.floating]:
        """Predict using dynamics model for future timesteps.

        Given the last sample's latent scores, predict future values
        using the learned dynamics T_{t+1} = F @ T_t.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Current data. Uses last sample for propagation.
        n_steps : int, default=1
            Number of future timesteps to predict.

        Returns
        -------
        y_future : ndarray of shape (n_steps, n_targets)
            Predicted future values; squeezed to (n_steps,) when the model
            was fitted on a 1-D ``y``.

        Raises
        ------
        ValueError
            If X contains no sample to propagate from.
        """
        check_is_fitted(self, ['W_', 'F_', 'B_'])
        T = self.transform(X)
        if T.shape[0] == 0:
            raise ValueError("X must contain at least one sample")
        T_current = T[-1:]  # Last timestep, kept 2-D

        predictions = []
        for _ in range(n_steps):
            # Row form of t_{t+1} = F t_t
            T_current = T_current @ self.F_.T
            y_next = T_current @ self.B_
            if self.y_scaler_ is not None:
                y_next = self.y_scaler_.inverse_transform(y_next)
            predictions.append(y_next.ravel())

        # Explicit reshape keeps the (n_steps, n_targets) layout even when
        # n_steps == 0.
        result = np.asarray(predictions, dtype=np.float64).reshape(
            len(predictions), self.B_.shape[1]
        )
        # If univariate y and output is (n_steps, 1), squeeze to (n_steps,)
        if self._y_1d and result.shape[1] == 1:
            result = result.ravel()
        return result

    def __repr__(self) -> str:
        """Return string representation."""
        return (
            f"OKLMPLS(n_components={self.n_components}, "
            f"lambda_dyn={self.lambda_dyn}, "
            f"lambda_reg_y={self.lambda_reg_y}, "
            f"max_iter={self.max_iter}, "
            f"backend='{self.backend}')"
        )