From a57c6f19a217b0a74507273395f35ff986d51efa Mon Sep 17 00:00:00 2001 From: hvoss Date: Thu, 4 Dec 2025 16:05:12 +0100 Subject: [PATCH 01/18] added torch solver variant --- examples/benchmarks/lorenz_benchmark.py | 172 +++++++ examples/benchmarks/spring_benchmark.py | 185 +++++++ pyproject.toml | 1 + pysindy/optimizers/__init__.py | 6 + pysindy/optimizers/torch_solver.py | 493 +++++++++++++++++++ test/test_optimizers/test_torch_optimizer.py | 52 ++ 6 files changed, 909 insertions(+) create mode 100644 examples/benchmarks/lorenz_benchmark.py create mode 100644 examples/benchmarks/spring_benchmark.py create mode 100644 pysindy/optimizers/torch_solver.py create mode 100644 test/test_optimizers/test_torch_optimizer.py diff --git a/examples/benchmarks/lorenz_benchmark.py b/examples/benchmarks/lorenz_benchmark.py new file mode 100644 index 00000000..0bd66898 --- /dev/null +++ b/examples/benchmarks/lorenz_benchmark.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +""" +Benchmark PySINDy optimizers on the Lorenz system. +Runs a quick RK4 simulation, constructs a SINDy model with a PolynomialLibrary, +then evaluates multiple optimizers (including TorchOptimizer if torch is available) +for runtime, model score, and sparsity (complexity). +""" +import time +import warnings +from typing import List, Tuple + +import numpy as np + +from pysindy import SINDy +from pysindy.feature_library import PolynomialLibrary +from pysindy.optimizers import ( + STLSQ, + SR3, + FROLS, + SSR, + WrappedOptimizer, +) + +# Optional optimizers guarded in __init__ (may be None if dependency missing) +try: + from pysindy.optimizers import TorchOptimizer # type: ignore +except Exception: + TorchOptimizer = None # type: ignore + +try: + from pysindy.optimizers import SBR # may require cvxpy or extra deps +except Exception: + SBR = None # type: ignore + + +def lorenz(t, x, sigma: float = 10.0, beta: float = 8.0 / 3.0, rho: float = 28.0): + u, v, w = x + up = -sigma * (u - v) + vp = rho * u - v - u * w + wp = -beta * w + u * v + return np.array([up, vp, wp], dtype=float) + + +def rk4_lorenz(t0: float = 0.0, t1: float = 10.0, dt: float = 0.01) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + t = np.arange(t0, t1 + 1e-12, dt) + X = np.zeros((t.size, 3), dtype=float) + X[0, :] = np.array([1.0, 1.0, 1.0], dtype=float) + # RK4 integrator + for i in range(1, t.size): + ti = t[i - 1] + xi = X[i - 1, :] + k1 = lorenz(ti, xi) + k2 = lorenz(ti + dt / 2.0, xi + dt * k1 / 2.0) + k3 = lorenz(ti + dt / 2.0, xi + dt * k2 / 2.0) + k4 = lorenz(ti + dt, xi + dt * k3) + X[i, :] = xi + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4) + + # Drop a short burn-in to avoid transient initialization effects + burn = int(1.0 / dt) # drop first 1s + t = t[burn:] + X = X[burn:, :] + + # Exact RHS at sample points using analytic lorenz + dX_dt = np.vstack([lorenz(tt, xx) for tt, xx in zip(t, X)]) + return t, X, dX_dt + + +def build_library() -> PolynomialLibrary: + # Standard SINDy polynomial library for Lorenz (degree 2 is typical) + return PolynomialLibrary(degree=2, include_interaction=True) + + +def run_optimizer(name: str, optimizer, X: np.ndarray, dt: float, library: PolynomialLibrary, dX_dt: np.ndarray): + model = SINDy(optimizer=optimizer, feature_library=library) + t0 = time.perf_counter() + model.fit(X, t=dt) + fit_time = time.perf_counter() - t0 + score = model.score(X, t=dt) + # Predict derivatives from learned model, compare to analytic RHS + dX_pred = model.predict(X) + mse = float(np.mean((dX_pred - dX_dt) ** 2)) + complexity = optimizer.complexity if 
hasattr(optimizer, "complexity") else None + # Gather equations as strings + try: + equations = model.equations() + except Exception: + equations = None + return { + "name": name, + "fit_time_s": fit_time, + "score": float(score), + "mse": mse, + "complexity": int(complexity) if complexity is not None else None, + "equations": equations, + } + + +def system_lorenz_example() -> None: + """Simulate the Lorenz attractor and test PySINDy + our solvers on it. + + Uses a simple RK4 integrator (no extra deps). To keep runtime reasonable, + solver options are reduced compared to defaults but can be adjusted. + """ + warnings.filterwarnings("ignore") + t, X, dX_dt = rk4_lorenz() + dt = t[1] - t[0] + library = build_library() + + optimizers = [] # type: List[Tuple[str, object]] + + # STLSQ: classic baseline + optimizers.append(("STLSQ", STLSQ(threshold=0.1, alpha=0.05, max_iter=20))) + + # SR3: relaxed regularized regression, L0 prox behaves like hard thresholding + optimizers.append(("SR3-L0", SR3(reg_weight_lam=0.1, regularizer="L0", relax_coeff_nu=1.0, max_iter=50))) + + # FROLS & SSR: forward regression variants + optimizers.append(("FROLS", FROLS())) + optimizers.append(("SSR", SSR())) + + # SBR if available + if SBR is not None: + try: + optimizers.append(("SBR", SBR())) + except Exception: + pass + + # Torch-based optimizer if available + if TorchOptimizer is not None: + try: + optimizers.append(( + "TorchOptimizer", + TorchOptimizer( + seed=0, + ), + )) + except Exception: + pass + + results = [] + for name, opt in optimizers: + try: + res = run_optimizer(name, opt, X, dt, library, dX_dt) + except Exception as e: + res = {"name": name, "error": str(e)} + results.append(res) + + # Pretty print summary + header = f"{'Optimizer':<15} {'Score':>10} {'MSE':>12} {'Time (s)':>12} {'Complexity':>12}" + print(header) + print("-" * len(header)) + for r in results: + if "error" in r: + print(f"{r['name']:<15} ERROR: {r['error']}") + else: + print(f"{r['name']:<15} {r['score']:>10.4f} {r['mse']:>12.4e} {r['fit_time_s']:>12.4f} {str(r['complexity']):>12}") + # Print discovered system equations + for r in results: + if "error" in r: + print(f"{r['name']:<15} ERROR: {r['error']}") + else: + print(f"{r['name']:<15} {r['score']:>10.4f} {r['mse']:>12.4e} {r['fit_time_s']:>12.4f} {str(r['complexity']):>12}") + eqs = r.get("equations") + if eqs: + for eq in eqs: + print(f" {eq}") + else: + print(" (equations unavailable)") + + +if __name__ == "__main__": + system_lorenz_example() diff --git a/examples/benchmarks/spring_benchmark.py b/examples/benchmarks/spring_benchmark.py new file mode 100644 index 00000000..2ded9f00 --- /dev/null +++ b/examples/benchmarks/spring_benchmark.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +""" +Benchmark PySINDy optimizers on a nonlinear spring system. +System: + xdot = v + vdot = -k * x - c * v + F * sin(x**2) +Simulates with RK4, evaluates multiple optimizers, prints metrics and equations, +and saves a plot comparing true vs predicted test trajectories. 
+""" +import traceback +from pathlib import Path +import time +import warnings +from typing import List, Tuple + +import numpy as np +import matplotlib.pyplot as plt + +from pysindy import SINDy +from pysindy.feature_library import PolynomialLibrary +from pysindy.optimizers import ( + STLSQ, + SR3, + FROLS, + SSR, +) +# Optional optimizers +try: + from pysindy.optimizers import TorchOptimizer # type: ignore +except Exception: + TorchOptimizer = None # type: ignore + +try: + from pysindy.optimizers import SBR +except Exception: + SBR = None # type: ignore + + +def myspring(t, x, k=-4.518, c=0.372, F0=9.123): + """ + Example nonlinear dynamical system. + xdot = v + vdot = - k x - v c + F sin(x**2) + """ + return np.array([x[1], k * x[0] - c * x[1] + F0 * np.sin(x[0] ** 2)]) + + +def rk4_system(f, x0: np.ndarray, t0: float, t1: float, dt: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + t = np.arange(t0, t1 + 1e-12, dt) + X = np.zeros((t.size, x0.size), dtype=float) + X[0, :] = np.array(x0, dtype=float) + for i in range(1, t.size): + ti = t[i - 1] + xi = X[i - 1, :] + k1 = f(ti, xi) + k2 = f(ti + dt / 2.0, xi + dt * k1 / 2.0) + k3 = f(ti + dt / 2.0, xi + dt * k2 / 2.0) + k4 = f(ti + dt, xi + dt * k3) + X[i, :] = xi + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4) + dX_dt = np.vstack([f(tt, xx) for tt, xx in zip(t, X)]) + return t, X, dX_dt + + +def split(arr: np.ndarray, ratio: float) -> Tuple[np.ndarray, np.ndarray]: + n = arr.shape[0] + n_tr = int(np.floor(ratio * n)) + return arr[:n_tr], arr[n_tr:] + + +def build_library() -> PolynomialLibrary: + # Include sin via generalized or custom library? Approximate with polynomials up to degree 5 + return PolynomialLibrary(degree=5, include_interaction=True) + + +def run_optimizer(name: str, optimizer, X_tr: np.ndarray, dt: float, library: PolynomialLibrary, + X_te: np.ndarray, dX_dt_te: np.ndarray): + model = SINDy(optimizer=optimizer, feature_library=library) + t0 = time.perf_counter() + model.fit(X_tr, t=dt) + fit_time = time.perf_counter() - t0 + score = model.score(X_te, t=dt) + dX_pred_te = model.predict(X_te) + mse = float(np.mean((dX_pred_te - dX_dt_te) ** 2)) + complexity = getattr(optimizer, 'complexity', None) + try: + equations = model.equations() + except Exception: + equations = None + return { + "name": name, + "fit_time_s": fit_time, + "score": float(score), + "mse": mse, + "complexity": int(complexity) if complexity is not None else None, + "equations": equations, + "model": model, + } + + +def nonlinear_spring_example() -> None: + warnings.filterwarnings("ignore") + # Simulate + t0, t1, dt = 0.0, 10.0, 0.01 + x0 = np.array([0.4, 1.6], dtype=float) + t, X, dX_dt = rk4_system(myspring, x0, t0, t1, dt) + # Train/test split + ratio = 0.67 + X_tr, X_te = split(X, ratio) + dX_dt_tr, dX_dt_te = split(dX_dt, ratio) + # Build library + library = build_library() + + optimizers: List[Tuple[str, object]] = [] + optimizers.append(("STLSQ", STLSQ(threshold=0.1, alpha=0.01, max_iter=30))) + optimizers.append(("SR3-L0", SR3(reg_weight_lam=0.05, regularizer="L0", relax_coeff_nu=1.0, max_iter=100))) + optimizers.append(("FROLS", FROLS())) + optimizers.append(("SSR", SSR())) + if SBR is not None: + try: + optimizers.append(("SBR", SBR())) + except Exception: + pass + if TorchOptimizer is not None: + try: + optimizers.append(( + "TorchOptimizer", + TorchOptimizer( + seed=0, + ), + )) + except Exception: + pass + + dt_scalar = dt + results = [] + for name, opt in optimizers: + try: + res = run_optimizer(name, opt, X_tr, dt_scalar, library, X_te, 
dX_dt_te) + except Exception as e: + res = {"name": name, "error": str(e)} + results.append(res) + + # Print summary + header = f"{'Optimizer':<15} {'Score':>10} {'MSE':>12} {'Time (s)':>12} {'Complexity':>12}" + print(header) + print("-" * len(header)) + for r in results: + if "error" in r: + print(f"{r['name']:<15} ERROR: {r['error']}") + else: + print(f"{r['name']:<15} {r['score']:>10.4f} {r['mse']:>12.4e} {r['fit_time_s']:>12.4f} {str(r['complexity']):>12}") + eqs = r.get("equations") + if eqs: + for eq in eqs: + print(f" {eq}") + else: + print(" (equations unavailable)") + + # Plot true vs predicted for the best (by MSE) model + good = [r for r in results if "error" not in r] + if good: + best = min(good, key=lambda r: r["mse"]) # type: ignore + model = best["model"] + X_te_pred = model.simulate(X_te[0], t[len(X_tr):]) + # Build figure data + fig, axs = plt.subplots(2, 1, figsize=(8, 6), sharex=True) + tt = t[len(X_tr):] + axs[0].plot(tt, X_te[:, 0], label="x true") + axs[0].plot(tt, X_te_pred[:, 0], label="x pred", linestyle="--") + axs[0].set_ylabel("x") + axs[0].legend() + axs[1].plot(tt, X_te[:, 1], label="v true") + axs[1].plot(tt, X_te_pred[:, 1], label="v pred", linestyle="--") + axs[1].set_ylabel("v") + axs[1].set_xlabel("t") + axs[1].legend() + fig.suptitle(f"Nonlinear spring - best: {best['name']} (MSE={best['mse']:.3e}, Score={best['score']:.3f})") + out_dir = Path(__file__).parents[0].joinpath('figures') + out_dir.mkdir(exist_ok=True) + fig.savefig(out_dir.joinpath('nonlinear_spring.svg'), dpi=300) + plt.show() + + +if __name__ == "__main__": + nonlinear_spring_example() diff --git a/pyproject.toml b/pyproject.toml index 19a7fdbf..a018af80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "derivative>=0.6.2", "scipy", "typing_extensions", + "torch", ] [project.optional-dependencies] diff --git a/pysindy/optimizers/__init__.py b/pysindy/optimizers/__init__.py index d4346b4f..85313f34 100644 --- a/pysindy/optimizers/__init__.py +++ b/pysindy/optimizers/__init__.py @@ -30,6 +30,11 @@ from .sbr import SBR except (ImportError, NameError): pass +try: + from .torch_solver import TorchOptimizer +except Exception: + # Optional dependency torch may be missing or cause import-time errors + TorchOptimizer = None __all__ = [ @@ -46,4 +51,5 @@ "SINDyPI", "MIOSR", "SBR", + "TorchOptimizer", ] diff --git a/pysindy/optimizers/torch_solver.py b/pysindy/optimizers/torch_solver.py new file mode 100644 index 00000000..ea91c24f --- /dev/null +++ b/pysindy/optimizers/torch_solver.py @@ -0,0 +1,493 @@ +""" +PyTorch-based SINDy optimizer using proximal gradient + iterative hard-thresholding. + +This module provides a high-performance optimizer implemented with PyTorch to solve +Sparse Identification of Nonlinear Dynamics (SINDy) regression problems. It +minimizes a smooth data-fit term and applies proximal/thresholding operations to +promote sparsity in the discovered dynamical system. + +Key features +- Batched multi-target regression on CPU or GPU (if available). +- Proximal L1 shrinkage and hard thresholding to encourage sparse models. +- Optional optimizers: SGD, Adam, AdamW, and custom CAdamW. +- Best-solution tracking across iterations and early-stopping support. +- Compatible with PySINDy BaseOptimizer interface and ensembling. + +Optional dependencies +- PyTorch is optional at import time; a RuntimeError will be raised during fit if + torch is not available. Code paths and annotations avoid import-time failures. 
+- CAdamW is a custom optimizer shipped with the project and used when requested. + +Usage +----- +Example: Fit a SINDy model with the Torch optimizer. + + >>> import numpy as np + >>> from pysindy import SINDy + >>> from pysindy.feature_library import PolynomialLibrary + >>> from pysindy.optimizers.torch_solver import TorchOptimizer + >>> rng = np.random.default_rng(0) + >>> X = rng.standard_normal((500, 3)) + >>> Y = X @ np.array([[1.0, 0.0, -0.5], [0.0, 2.0, 0.0], [0.0, 0.0, 0.0]]).T + >>> opt = TorchOptimizer(threshold=0.05, alpha_l1=1e-3, step_size=1e-2, max_iter=500, early_stopping_patience=50) + >>> lib = PolynomialLibrary(degree=2) + >>> model = SINDy(optimizer=opt, feature_library=lib) + >>> model.fit(X, t=0.01) + >>> model.equations() # doctest: +ELLIPSIS + [...] + +Notes +----- +- Thresholding and proximal operations operate on coefficient magnitudes. A small + numerical threshold (1e-14) is used to derive support masks for `ind_`. +- When `sparse_ind` is provided, thresholding affects only the specified columns. +- Early stopping halts iterations when the objective fails to improve by at least + `min_delta` for `early_stopping_patience` consecutive steps. +- The optimizer tracks and restores the best solution observed across iterations. +""" + +import warnings +from typing import Optional, TYPE_CHECKING + +import numpy as np + +from .base import BaseOptimizer + +try: + import torch # type: ignore +except Exception: # pragma: no cover - optional dependency + torch = None # type: ignore + +if TYPE_CHECKING: # only for type checkers + import torch as torch_types # noqa: F401 + + +import math +from typing import Callable, Iterable, Tuple + +import torch +from torch import nn +from torch.optim import Optimizer + +# Taken from https://github.com/kyleliang919/C-Optim/blob/main/c_adamw.py +class CAdamW(Optimizer): + """ + Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay + Regularization](https://arxiv.org/abs/1711.05101). + + Parameters: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*, defaults to 0.001): + The learning rate to use. + betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`): + Adam's betas parameters (b1, b2). + eps (`float`, *optional*, defaults to 1e-06): + Adam's epsilon for numerical stability. + weight_decay (`float`, *optional*, defaults to 0.0): + Decoupled weight decay to apply. + correct_bias (`bool`, *optional*, defaults to `True`): + Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`). + no_deprecation_warning (`bool`, *optional*, defaults to `False`): + A flag used to disable the deprecation warning (set to `True` to disable the warning). 
+ """ + + def __init__( + self, + params: Iterable[nn.parameter.Parameter], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + correct_bias: bool = True, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0") + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)" + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)" + ) + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") + defaults = { + "lr": lr, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "correct_bias": correct_bias, + } + super().__init__(params, defaults) + self.init_lr = lr + + @torch.no_grad() + def step(self, closure: Callable = None): + """ + Performs a single optimization step. + + Arguments: + closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for i, p in enumerate(group["params"]): + if p.grad is None: + continue + + grad = p.grad + state = self.state[p] + + if "step" not in state: + state["step"] = 0 + + # State initialization + if "exp_avg" not in state: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(grad) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # apply weight decay + if group["weight_decay"] > 0.0: + p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + if group["correct_bias"]: # No bias correction for Bert + bias_correction1 = 1.0 - beta1 ** state["step"] + bias_correction2 = 1.0 - beta2 ** state["step"] + step_size = ( + step_size * math.sqrt(bias_correction2) / bias_correction1 + ) + + # compute norm gradient + mask = (exp_avg * grad > 0).to(grad.dtype) + # mask = mask * (mask.numel() / (mask.sum() + 1)) ## original implementation, leaving it here for record + mask.div_( + mask.mean().clamp_(min=1e-3) + ) # https://huggingface.co/rwightman/timm-optim-caution found this implementation is more favourable in many cases + norm_grad = (exp_avg * mask) / denom + p.add_(norm_grad, alpha=-step_size) + return loss + + +def _soft_threshold(t, lam: float): + """Soft-thresholding proximal operator. + + Applies element-wise soft-shrinkage: sign(t) * max(|t| - lam, 0). + + Parameters + - t: torch.Tensor with coefficients. + - lam: float threshold (lambda) controlling the shrinkage amount. + + Returns + - torch.Tensor of same shape as `t` after soft-thresholding. + """ + if lam <= 0: + return t + return torch.sign(t) * torch.clamp(torch.abs(t) - lam, min=0.0) + + +def _hard_threshold(t, thr: float): + """Hard-thresholding operator. + + Zeros out elements whose absolute value is below the given threshold. + + Parameters + - t: torch.Tensor with coefficients. + - thr: float magnitude threshold. 
+ + Returns + - torch.Tensor of same shape as `t` with small entries set to zero. + """ + if thr <= 0: + return t + return t * (torch.abs(t) >= thr) + + +class TorchOptimizer(BaseOptimizer): + """Torch-powered optimizer for sparse SINDy regression. + + This optimizer minimizes the objective + + J(W) = (1/N) * ||Y - X W^T||_F^2 + alpha_l1 * ||W||_1 + + using gradient-based updates (SGD/Adam/AdamW/CAdamW), followed by proximal + soft-thresholding (L1) and hard-thresholding to encourage sparsity. It supports + multi-target regression (rows correspond to targets) and optional GPU acceleration. + + Parameters + ---------- + threshold : float, default 1e-1 + Minimum magnitude for a coefficient. Values with |coef| < threshold + are set to zero after each iteration (hard-threshold). + alpha_l1 : float, default 0.0 + L1 penalty weight. If > 0, soft-thresholding is applied after the + gradient step to shrink coefficients. + step_size : float, default 1e-1 + Learning rate for the chosen Torch optimizer. + max_iter : int, default 1000 + Maximum number of iterations. + optimizer : {"sgd", "adam", "adamw", "cadamw"}, default "adam" + Which optimizer to use for the smooth part of the objective. + normalize_columns : bool, default False + See BaseOptimizer; if True, columns of X are normalized before fitting. + copy_X : bool, default True + See BaseOptimizer; controls whether X is copied or may be overwritten. + initial_guess : np.ndarray or None, default None + Warm-start coefficients; shape (n_targets, n_features). + verbose : bool, default False + If True, prints periodic progress including loss and sparsity. + device : {"cpu", "cuda"} or None, default None + Torch device to use. If None, uses CPU; if "cuda" is requested but not + available, falls back to CPU with a warning. + seed : int or None, default None + Random seed for reproducibility of Torch and NumPy. + sparse_ind : list[int] or None, default None + If provided, thresholding only applies to these feature indices. Other + indices remain unaffected by hard-thresholding. + unbias : bool, default True + See BaseOptimizer; when True, performs an unbiased refit on the selected + support after optimization. + early_stopping_patience : int, default 0 + If > 0, stop early when the objective has not improved by at least + `min_delta` for this many consecutive iterations. + min_delta : float, default 0.0 + Minimum improvement to reset patience; small positive values help + prevent stopping on floating-point noise. + + Attributes + ---------- + coef_ : np.ndarray, shape (n_targets, n_features) + Final coefficients (best observed if early stopping is enabled). + ind_ : np.ndarray[bool], shape (n_targets, n_features) + Support mask where True indicates non-zero (above tiny threshold). + history_ : list[np.ndarray] + Sequence of coefficient matrices recorded at each iteration. + intercept_ : float or np.ndarray + Intercept term (handled by BaseOptimizer; not fit here). + + Notes + ----- + - Best-solution tracking: the optimizer records the coefficient state with + the lowest objective value and restores it at the end. + - Early-stopping: controlled by `early_stopping_patience` and `min_delta`. + - The objective combines mean-squared error and optional L1 penalty. + - For reproducible results, set `seed` and avoid non-deterministic CUDA ops. 
+ """ + + def __init__( + self, + threshold: float = 1e-1, + alpha_l1: float = 0.0, + step_size: float = 1e-1, + max_iter: int = 1000, + optimizer: str = "adam", + normalize_columns: bool = False, + copy_X: bool = True, + initial_guess: Optional[np.ndarray] = None, + verbose: bool = False, + device: Optional[str] = None, + seed: Optional[int] = None, + sparse_ind: Optional[list[int]] = None, + unbias: bool = True, + early_stopping_patience: int = 0, + min_delta: float = 0.0, + ): + super().__init__( + max_iter=max_iter, + normalize_columns=normalize_columns, + initial_guess=initial_guess, + copy_X=copy_X, + unbias=unbias, + ) + if threshold < 0: + raise ValueError("threshold cannot be negative") + if alpha_l1 < 0: + raise ValueError("alpha_l1 cannot be negative") + if step_size <= 0: + raise ValueError("step_size must be positive") + if optimizer not in ("sgd", "adam", "adamw", "cadamw"): + raise ValueError("optimizer must be 'sgd', 'adam', 'adamw', or 'cadamw'") + if early_stopping_patience < 0: + raise ValueError("early_stopping_patience cannot be negative") + if min_delta < 0: + raise ValueError("min_delta cannot be negative") + self.threshold = float(threshold) + self.alpha_l1 = float(alpha_l1) + self.step_size = float(step_size) + self.verbose = bool(verbose) + self.torch_device = device or "cpu" + self.seed = seed + self.opt_name = optimizer + self.sparse_ind = sparse_ind + self.early_stopping_patience = int(early_stopping_patience) + self.min_delta = float(min_delta) + if torch is None: + # Delay hard failure to fit-time to allow import of module without torch + warnings.warn( + "PyTorch is not installed; TorchSINDyOptimizer will not run until torch is available.") + + def _reduce(self, x: np.ndarray, y: np.ndarray) -> None: + """Core optimization loop. + + This method performs up to `max_iter` iterations of gradient-based updates + on the smooth objective, followed by proximal soft-thresholding and hard + thresholding to enforce sparsity. It maintains a history of coefficients, + tracks the best objective value, and optionally stops early. + + Parameters + ---------- + x : np.ndarray, shape (n_samples, n_features) + Feature matrix (SINDy library output). Normalization may be applied + by the BaseOptimizer depending on constructor options. + y : np.ndarray, shape (n_samples, n_targets) + Target matrix (derivatives). + + Side effects + ------------ + - Updates `self.coef_` with the best observed coefficients. + - Updates `self.ind_` as a boolean support mask. + - Appends snapshots to `self.history_` each iteration. + + Raises + ------ + ImportError + If PyTorch is not installed at run time. + """ + if torch is None: + raise ImportError("PyTorch is required for TorchSINDyOptimizer. 
Please install torch.") + # Select device + if self.torch_device == "cuda" and not torch.cuda.is_available(): + warnings.warn("CUDA not available; falling back to CPU.") + device = torch.device("cpu") + else: + device = torch.device(self.torch_device) + # Seed control + if self.seed is not None: + torch.manual_seed(self.seed) + np.random.seed(self.seed) + # Data to torch + dtype = torch.float64 # match numpy default precision + X = torch.as_tensor(x, dtype=dtype, device=device) + Y = torch.as_tensor(y, dtype=dtype, device=device) + n_samples, n_features = X.shape + n_targets = Y.shape[1] + + # Parameter tensor W: shape (n_targets, n_features) + if self.coef_ is None: + W = torch.zeros((n_targets, n_features), dtype=dtype, device=device) + else: + W = torch.as_tensor(self.coef_, dtype=dtype, device=device) + W.requires_grad_(True) + + # Optimizer for smooth loss + if self.opt_name == "adam": + opt = torch.optim.Adam([W], lr=self.step_size) + elif self.opt_name == "adamw": + opt = torch.optim.AdamW([W], lr=self.step_size) + elif self.opt_name == "cadamw": + opt = CAdamW([W], lr=self.step_size) + else: + opt = torch.optim.SGD([W], lr=self.step_size, momentum=0.9) + + # Support mask helper: restrict thresholding to specified indices + sparse_mask = None + if self.sparse_ind is not None: + sparse_mask = torch.zeros((n_targets, n_features), dtype=torch.bool, device=device) + sparse_mask[:, self.sparse_ind] = True + + def loss_fn(W_): + """Compute objective: MSE/N + alpha_l1 * ||W||_1.""" + Y_pred = X @ W_.T # (n_samples, n_targets) + residual = Y_pred - Y + mse = (residual.pow(2)).sum() / n_samples + if self.alpha_l1 > 0: + l1 = self.alpha_l1 * torch.abs(W_).sum() + else: + l1 = torch.zeros((), dtype=dtype, device=device) + return mse + l1 + + last_mask = None + best_obj = None + best_W = None + patience_counter = 0 + + # Simple gradient stepping just as you would do in learning a model on pytorch + for it in range(self.max_iter): + opt.zero_grad(set_to_none=True) + loss = loss_fn(W) + loss.backward() + # Gradient step via optimizer + opt.step() + # Prox for L1 + if self.alpha_l1 > 0: + with torch.no_grad(): + W[:] = _soft_threshold(W, self.alpha_l1 * self.step_size) + # Hard-threshold + with torch.no_grad(): + if self.threshold > 0: + if sparse_mask is None: + W[:] = _hard_threshold(W, self.threshold) + else: + kept = _hard_threshold(W, self.threshold) + W[:] = torch.where(sparse_mask, kept, W) + # Evaluate objective and update best + with torch.no_grad(): + obj = float(loss_fn(W).cpu().numpy()) + if best_obj is None or (best_obj - obj) > self.min_delta: + best_obj = obj + best_W = W.detach().clone() + patience_counter = 0 + else: + patience_counter += 1 + # Track history and early stop if support unchanged or patience exceeded + with torch.no_grad(): + coef_np = W.detach().cpu().numpy() + self.history_.append(coef_np) + mask = np.abs(coef_np) >= max(self.threshold, 1e-14) + if last_mask is not None and np.array_equal(mask, last_mask): + break + last_mask = mask + if 0 < self.early_stopping_patience <= patience_counter: + break + if self.verbose and (it % max(1, self.max_iter // 10) == 0 or it == self.max_iter - 1): + mse_val = float(((X @ W.T - Y).pow(2)).sum().cpu().numpy()) / n_samples + l0 = int((torch.abs(W) >= self.threshold).sum().item()) + print(f"[TorchSINDy] iter={it} mse={mse_val:.4e} L0={l0} obj={obj:.4e}") + + # Final coefficients back to numpy: use best if available + final_W = (best_W if best_W is not None else W).detach().cpu().numpy() + self.coef_ = final_W + # ind_ based on 
tiny threshold + self.ind_ = np.abs(self.coef_) > 1e-14 + + @property + def complexity(self): + """Model complexity measure. + + Returns the number of non-zero coefficients plus the number of non-zero + intercepts (if any). This should allow comparing sparsity across optimizers. + + Returns + ------- + int + Complexity score: count_nonzero(coef_) + count_nonzero(intercept_). + """ + return np.count_nonzero(self.coef_) + np.count_nonzero(self.intercept_) diff --git a/test/test_optimizers/test_torch_optimizer.py b/test/test_optimizers/test_torch_optimizer.py new file mode 100644 index 00000000..13708e46 --- /dev/null +++ b/test/test_optimizers/test_torch_optimizer.py @@ -0,0 +1,52 @@ +import numpy as np +import pytest + +torch = pytest.importorskip("torch") + +from pysindy.optimizers import TorchOptimizer +from pysindy import SINDy +from pysindy.feature_library import PolynomialLibrary + + +def make_synthetic(n_samples=200, noise=0.0, seed=0): + rng = np.random.default_rng(seed) + X = rng.normal(size=(n_samples, 3)) + # True W + W = np.array([[1.0, 0.0, -0.5], [0.0, 2.0, 0.0], [0.0, 0.0, 0.0]]) + Y = X @ W.T + noise * rng.normal(size=(n_samples, 3)) + return X, Y + + +def test_basic_fit_shapes(): + X, Y = make_synthetic() + opt = TorchOptimizer(max_iter=50, threshold=1e-2, alpha_l1=1e-3, seed=1) + opt.fit(X, Y) + assert opt.coef_.shape == (Y.shape[1], X.shape[1]) + assert opt.ind_.shape == (Y.shape[1], X.shape[1]) + assert len(opt.history_) >= 1 + + +def test_unbias_and_sparsity(): + X, Y = make_synthetic(noise=0.01) + opt = TorchOptimizer(max_iter=100, threshold=0.05, alpha_l1=1e-3, seed=2, unbias=True) + opt.fit(X, Y) + # Check some sparsity present + assert np.count_nonzero(opt.coef_) < opt.coef_.size + + +def test_multi_target_with_sindylib(): + # Integration with SINDy and a small library + rng = np.random.default_rng(0) + t = np.linspace(0, 1, 200) + x = np.stack([ + np.sin(2*np.pi*t), + np.cos(2*np.pi*t), + 0.5*np.sin(4*np.pi*t) + ], axis=1) + lib = PolynomialLibrary(degree=2) + opt = TorchOptimizer(max_iter=50, threshold=1e-2, alpha_l1=1e-3, seed=0) + model = SINDy(optimizer=opt, feature_library=lib) + model.fit(x, t=t[1]-t[0]) + score = model.score(x, t=t[1]-t[0]) + assert score > 0.8 + From ee389631f19a63cf0a47ec665472670a4b51e9f8 Mon Sep 17 00:00:00 2001 From: hvoss Date: Thu, 4 Dec 2025 16:11:29 +0100 Subject: [PATCH 02/18] made printing the benchmark a bit nicer and removed matplotlib --- examples/benchmarks/spring_benchmark.py | 33 ++++--------------------- 1 file changed, 5 insertions(+), 28 deletions(-) diff --git a/examples/benchmarks/spring_benchmark.py b/examples/benchmarks/spring_benchmark.py index 2ded9f00..7a64d9e9 100644 --- a/examples/benchmarks/spring_benchmark.py +++ b/examples/benchmarks/spring_benchmark.py @@ -7,14 +7,11 @@ Simulates with RK4, evaluates multiple optimizers, prints metrics and equations, and saves a plot comparing true vs predicted test trajectories. 
""" -import traceback -from pathlib import Path import time import warnings from typing import List, Tuple import numpy as np -import matplotlib.pyplot as plt from pysindy import SINDy from pysindy.feature_library import PolynomialLibrary @@ -144,6 +141,11 @@ def nonlinear_spring_example() -> None: header = f"{'Optimizer':<15} {'Score':>10} {'MSE':>12} {'Time (s)':>12} {'Complexity':>12}" print(header) print("-" * len(header)) + for r in results: + if "error" in r: + print(f"{r['name']:<15} ERROR: {r['error']}") + else: + print(f"{r['name']:<15} {r['score']:>10.4f} {r['mse']:>12.4e} {r['fit_time_s']:>12.4f} {str(r['complexity']):>12}") for r in results: if "error" in r: print(f"{r['name']:<15} ERROR: {r['error']}") @@ -156,30 +158,5 @@ def nonlinear_spring_example() -> None: else: print(" (equations unavailable)") - # Plot true vs predicted for the best (by MSE) model - good = [r for r in results if "error" not in r] - if good: - best = min(good, key=lambda r: r["mse"]) # type: ignore - model = best["model"] - X_te_pred = model.simulate(X_te[0], t[len(X_tr):]) - # Build figure data - fig, axs = plt.subplots(2, 1, figsize=(8, 6), sharex=True) - tt = t[len(X_tr):] - axs[0].plot(tt, X_te[:, 0], label="x true") - axs[0].plot(tt, X_te_pred[:, 0], label="x pred", linestyle="--") - axs[0].set_ylabel("x") - axs[0].legend() - axs[1].plot(tt, X_te[:, 1], label="v true") - axs[1].plot(tt, X_te_pred[:, 1], label="v pred", linestyle="--") - axs[1].set_ylabel("v") - axs[1].set_xlabel("t") - axs[1].legend() - fig.suptitle(f"Nonlinear spring - best: {best['name']} (MSE={best['mse']:.3e}, Score={best['score']:.3f})") - out_dir = Path(__file__).parents[0].joinpath('figures') - out_dir.mkdir(exist_ok=True) - fig.savefig(out_dir.joinpath('nonlinear_spring.svg'), dpi=300) - plt.show() - - if __name__ == "__main__": nonlinear_spring_example() From 468862bb47ac7109e70e0ed5610cc8f13f8aab97 Mon Sep 17 00:00:00 2001 From: hvoss Date: Thu, 4 Dec 2025 16:39:59 +0100 Subject: [PATCH 03/18] better benchmark --- examples/benchmarks/benchmark.py | 415 ++++++++++++++++++ examples/benchmarks/lorenz_benchmark.py | 172 -------- examples/benchmarks/spring_benchmark.py | 162 ------- pysindy/optimizers/__init__.py | 3 +- .../{torch_solver.py => torch_s.py} | 53 ++- 5 files changed, 451 insertions(+), 354 deletions(-) create mode 100644 examples/benchmarks/benchmark.py delete mode 100644 examples/benchmarks/lorenz_benchmark.py delete mode 100644 examples/benchmarks/spring_benchmark.py rename pysindy/optimizers/{torch_solver.py => torch_s.py} (92%) diff --git a/examples/benchmarks/benchmark.py b/examples/benchmarks/benchmark.py new file mode 100644 index 00000000..41854aca --- /dev/null +++ b/examples/benchmarks/benchmark.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python3 +""" +Generalized benchmark for PySINDy optimizers on multiple nonlinear systems. + +Features +- Unified runner for many ODE systems (Lorenz, Rossler, Van der Pol, Duffing, Chua, + Pendulum, Kuramoto, Hindmarsh–Rose (reduced), FitzHugh–Nagumo, Thomas, Sprott A, + and a nonlinear spring). +- Simple RK4 integrator (no external dependencies). +- Evaluates available optimizers (STLSQ, SR3, FROLS, SSR, optional SBR, optional TorchOptimizer). +- Reports runtime, score, MSE against analytic RHS, complexity, and discovered equations. +- CLI flags to choose system, integration params, library degree, and optimizers. +- Supports running all systems in a single invocation with --system all. 
+ +Usage examples +- Run Lorenz with defaults: + python examples/benchmarks/benchmark.py --system lorenz --dt 0.01 --t1 10 +- List systems: + python examples/benchmarks/benchmark.py --list +- Run all systems on all solvers: + python examples/benchmarks/benchmark.py --system all --dt 0.01 --t1 5 +""" +import argparse +import time +import traceback +import warnings +from typing import Callable +from typing import Dict +from typing import List +from typing import Tuple + +import numpy as np + +from pysindy import SINDy +from pysindy.feature_library import PolynomialLibrary +from pysindy.optimizers import FROLS +from pysindy.optimizers import SR3 +from pysindy.optimizers import SSR +from pysindy.optimizers import STLSQ + +try: + from pysindy.optimizers import SBR +except Exception: + SBR = None # type: ignore +try: + from pysindy.optimizers import TorchOptimizer +except Exception: + TorchOptimizer = None # type: ignore + + +# ------------------------------- Systems ------------------------------------ + + +def lorenz_rhs(t, x, sigma=10.0, beta=8.0 / 3.0, rho=28.0): + u, v, w = x + return np.array( + [ + -sigma * (u - v), + rho * u - v - u * w, + -beta * w + u * v, + ], + dtype=float, + ) + + +def rossler_rhs(t, x, a=0.2, b=0.2, c=5.7): + x1, x2, x3 = x + return np.array( + [ + -(x2 + x3), + x1 + a * x2, + b + x3 * (x1 - c), + ], + dtype=float, + ) + + +def vdp_rhs(t, x, mu=3.0): + x1, x2 = x + return np.array( + [ + x2, + mu * (1 - x1**2) * x2 - x1, + ], + dtype=float, + ) + + +def duffing_rhs(t, x, delta=0.2, alpha=-1.0, beta=1.0, gamma=0.3, omega=1.2): + # Unforced variant (set gamma=0) or keep small forcing + x1, x2 = x + return np.array( + [ + x2, + -delta * x2 - alpha * x1 - beta * x1**3 + gamma * np.cos(omega * t), + ], + dtype=float, + ) + + +def chua_rhs(t, x, alpha=15.6, beta=28.0, m0=-1.143, m1=-0.714): + x1, x2, x3 = x + h = m1 * x1 + 0.5 * (m0 - m1) * (np.abs(x1 + 1) - np.abs(x1 - 1)) + return np.array( + [ + alpha * (x2 - x1 - h), + x1 - x2 + x3, + -beta * x2, + ], + dtype=float, + ) + + +def pendulum_rhs(t, x, g=9.81, L=1.0, q=0.0, F=0.0, omega_d=0.0): + # Damped/forced nonlinear pendulum; default undamped, unforced + theta, omega = x + return np.array( + [ + omega, + -(g / L) * np.sin(theta) - q * omega + F * np.sin(omega_d * t), + ], + dtype=float, + ) + + +def kuramoto_rhs(t, x, K=0.5): + # Small network (3 oscillators) with identical natural freq=0 + # x are phases; coupling via sine differences + n = x.shape[0] + dx = np.zeros(n) + for i in range(n): + dx[i] = (K / n) * np.sum(np.sin(x - x[i])) + return dx + + +def hindmarsh_rose_rhs(t, x): + # Reduced 2D form for benchmarking + x1, x2 = x + a = 1.0 + b = 3.0 + c = 1.0 + d = 5.0 + return np.array( + [ + x2 - a * x1**3 + b * x1**2 + 1.0, # simplified variant + c - d * x1**2 - x2, # simplified variant + ], + dtype=float, + ) + + +def fitzhugh_nagumo_rhs(t, x, a=0.7, b=0.8, tau=12.5, Ia=0.5): + v, w = x + return np.array( + [ + v - v**3 / 3 - w + Ia, + (v + a - b * w) / tau, + ], + dtype=float, + ) + + +def thomas_rhs(t, x, b=0.208186): + x1, x2, x3 = x + return np.array( + [ + np.sin(x2) - b * x1, + np.sin(x3) - b * x2, + np.sin(x1) - b * x3, + ], + dtype=float, + ) + + +def sprott_a_rhs(t, x): + x1, x2, x3 = x + return np.array( + [ + x2, + x3, + -x1 + x2**2 - x3, + ], + dtype=float, + ) + + +def myspring_rhs(t, x, k=-4.518, c=0.372, F0=9.123): + return np.array([x[1], k * x[0] - c * x[1] + F0 * np.sin(x[0] ** 2)], dtype=float) + + +SYSTEMS: Dict[str, Tuple[Callable, np.ndarray]] = { + "lorenz": (lorenz_rhs, np.array([1.0, 
1.0, 1.0], dtype=float)), + "rossler": (rossler_rhs, np.array([0.1, 0.1, 0.1], dtype=float)), + "vanderpol": (vdp_rhs, np.array([2.0, 0.0], dtype=float)), + "duffing": (duffing_rhs, np.array([0.1, 0.0], dtype=float)), + "chua": (chua_rhs, np.array([0.1, 0.0, 0.0], dtype=float)), + "pendulum": (pendulum_rhs, np.array([0.5, 0.0], dtype=float)), + "kuramoto3": (kuramoto_rhs, np.array([0.2, -0.3, 0.1], dtype=float)), + "hindmarsh_rose2": (hindmarsh_rose_rhs, np.array([0.0, 0.0], dtype=float)), + "fitzhugh_nagumo": (fitzhugh_nagumo_rhs, np.array([0.0, 0.0], dtype=float)), + "thomas": (thomas_rhs, np.array([0.1, 0.2, 0.3], dtype=float)), + "sprott_a": (sprott_a_rhs, np.array([0.1, 0.1, 0.1], dtype=float)), + "myspring": (myspring_rhs, np.array([0.4, 1.6], dtype=float)), +} + + +def rk4( + f: Callable, x0: np.ndarray, t0: float, t1: float, dt: float +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Integrate an ODE x' = f(t, x) using classical RK4. + + Parameters + - f: callable(t, x) -> np.ndarray, RHS function. + - x0: initial state vector. + - t0: start time. + - t1: end time. + - dt: time step. + + Returns + - t: time array of shape (N,). + - X: state trajectory of shape (N, d). + - dX_dt: analytic RHS evaluated along trajectory, shape (N, d). + """ + t = np.arange(t0, t1 + 1e-12, dt) + X = np.zeros((t.size, x0.size), dtype=float) + X[0, :] = np.array(x0, dtype=float) + for i in range(1, t.size): + ti = t[i - 1] + xi = X[i - 1, :] + k1 = f(ti, xi) + k2 = f(ti + dt / 2.0, xi + dt * k1 / 2.0) + k3 = f(ti + dt / 2.0, xi + dt * k2 / 2.0) + k4 = f(ti + dt, xi + dt * k3) + X[i, :] = xi + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4) + dX_dt = np.vstack([f(tt, xx) for tt, xx in zip(t, X)]) + return t, X, dX_dt + + +def build_library(degree: int = 3) -> PolynomialLibrary: + """Construct a polynomial feature library. + + Parameters + - degree: maximum polynomial degree; interactions enabled. + + Returns + - PolynomialLibrary instance. + """ + return PolynomialLibrary(degree=degree, include_interaction=True) + + +def run_optimizer( + name: str, + optimizer, + X: np.ndarray, + dt: float, + library: PolynomialLibrary, + dX_dt: np.ndarray, +): + """Fit a SINDy model and compute metrics for an optimizer. + + Returns a dict with name, runtime, score, MSE vs analytic RHS, complexity, + equations (if available), and the fitted model. + """ + model = SINDy(optimizer=optimizer, feature_library=library) + t0 = time.perf_counter() + model.fit(X, t=dt) + fit_time = time.perf_counter() - t0 + score = model.score(X, t=dt) + dX_pred = model.predict(X) + mse = float(np.mean((dX_pred - dX_dt) ** 2)) + complexity = optimizer.complexity if hasattr(optimizer, "complexity") else None + try: + equations = model.equations() + except Exception: + equations = None + return { + "name": name, + "fit_time_s": fit_time, + "score": float(score), + "mse": mse, + "complexity": int(complexity) if complexity is not None else None, + "equations": equations, + "model": model, + } + + +def main(): + """Entry point for the generalized benchmark runner. + + Parses CLI, builds the library, iterates over selected systems, runs selected + optimizers, prints a summary table, and highlights the best optimizer per system + (lowest MSE), including its discovered equations. + """ + warnings.filterwarnings("ignore") + parser = argparse.ArgumentParser( + description="Generalized nonlinear systems benchmark for PySINDy." 
+ ) + parser.add_argument( + "--system", + type=str, + default="all", + choices=sorted(list(SYSTEMS.keys()) + ["all"]), + ) + parser.add_argument("--t0", type=float, default=0.0) + parser.add_argument("--t1", type=float, default=10.0) + parser.add_argument("--dt", type=float, default=0.01) + parser.add_argument( + "--degree", type=int, default=3, help="Polynomial library degree" + ) + parser.add_argument( + "--optimizers", type=str, default="all", help="Comma-separated list or 'all'" + ) + parser.add_argument( + "--list", action="store_true", help="List available systems and exit" + ) + args = parser.parse_args() + + if args.list: + print("Available systems:") + for name in sorted(SYSTEMS.keys()): + print(f" - {name}") + print(" - all (run every system)") + return + + systems_to_run = ( + [(args.system, SYSTEMS[args.system])] + if args.system != "all" + else list(SYSTEMS.items()) + ) + library = build_library(args.degree) + + # Select optimizers + opt_defs: List[Tuple[str, object]] = [] + opt_defs.append(("STLSQ", STLSQ(threshold=0.1, alpha=0.05, max_iter=20))) + opt_defs.append( + ( + "SR3-L0", + SR3(reg_weight_lam=0.1, regularizer="L0", relax_coeff_nu=1.0, max_iter=50), + ) + ) + opt_defs.append(("FROLS", FROLS())) + opt_defs.append(("SSR", SSR())) + if SBR is not None: + try: + opt_defs.append(("SBR", SBR())) + except Exception: + pass + if TorchOptimizer is not None: + try: + opt_defs.append( + ( + "TorchOptimizer", + TorchOptimizer( + threshold=0.05, + alpha_l1=1e-3, + step_size=1e-2, + max_iter=200, + optimizer="adam", + seed=0, + unbias=True, + early_stopping_patience=50, + min_delta=1e-8, + ), + ) + ) + except Exception: + traceback.print_exc() + + if args.optimizers != "all": + names = {n.strip() for n in args.optimizers.split(",")} + opt_defs = [pair for pair in opt_defs if pair[0] in names] + + # Run per system + for sys_name, (rhs, x0) in systems_to_run: + print(f"\n=== System: {sys_name} ===") + t, X, dX_dt = rk4(rhs, x0, args.t0, args.t1, args.dt) + results = [] + for name, opt in opt_defs: + try: + res = run_optimizer(name, opt, X, args.dt, library, dX_dt) + except Exception as e: + res = {"name": name, "error": str(e)} + results.append(res) + + header = f"{'Optimizer':<15} {'Score':>10} {'MSE':>12} {'Time (s)':>12} {'Complexity':>12}" + print(header) + print("-" * len(header)) + for r in results: + if "error" in r: + print(f"{r['name']:<15} ERROR: {r['error']}") + else: + print( + f"{r['name']:<15} {r['score']:>10.4f} {r['mse']:>12.4e} {r['fit_time_s']:>12.4f} {str(r['complexity']):>12}" + ) + + # Select and print best optimizer by lowest MSE among successful runs + successful = [r for r in results if "error" not in r] + if successful: + best = min(successful, key=lambda r: r["mse"]) # type: ignore + print( + f"\n>>> Best optimizer: {best['name']} | Score={best['score']:.4f} | MSE={best['mse']:.4e} | Time={best['fit_time_s']:.4f}s | Complexity={best['complexity']}" + ) + eqs = best.get("equations") + if eqs: + print("Discovered equations:") + for eq in eqs: + print(f" {eq}") + else: + print("(equations unavailable)") + else: + print("\n>>> No successful optimizer runs for this system.") + + +if __name__ == "__main__": + main() diff --git a/examples/benchmarks/lorenz_benchmark.py b/examples/benchmarks/lorenz_benchmark.py deleted file mode 100644 index 0bd66898..00000000 --- a/examples/benchmarks/lorenz_benchmark.py +++ /dev/null @@ -1,172 +0,0 @@ -#!/usr/bin/env python3 -""" -Benchmark PySINDy optimizers on the Lorenz system. 
-Runs a quick RK4 simulation, constructs a SINDy model with a PolynomialLibrary, -then evaluates multiple optimizers (including TorchOptimizer if torch is available) -for runtime, model score, and sparsity (complexity). -""" -import time -import warnings -from typing import List, Tuple - -import numpy as np - -from pysindy import SINDy -from pysindy.feature_library import PolynomialLibrary -from pysindy.optimizers import ( - STLSQ, - SR3, - FROLS, - SSR, - WrappedOptimizer, -) - -# Optional optimizers guarded in __init__ (may be None if dependency missing) -try: - from pysindy.optimizers import TorchOptimizer # type: ignore -except Exception: - TorchOptimizer = None # type: ignore - -try: - from pysindy.optimizers import SBR # may require cvxpy or extra deps -except Exception: - SBR = None # type: ignore - - -def lorenz(t, x, sigma: float = 10.0, beta: float = 8.0 / 3.0, rho: float = 28.0): - u, v, w = x - up = -sigma * (u - v) - vp = rho * u - v - u * w - wp = -beta * w + u * v - return np.array([up, vp, wp], dtype=float) - - -def rk4_lorenz(t0: float = 0.0, t1: float = 10.0, dt: float = 0.01) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - t = np.arange(t0, t1 + 1e-12, dt) - X = np.zeros((t.size, 3), dtype=float) - X[0, :] = np.array([1.0, 1.0, 1.0], dtype=float) - # RK4 integrator - for i in range(1, t.size): - ti = t[i - 1] - xi = X[i - 1, :] - k1 = lorenz(ti, xi) - k2 = lorenz(ti + dt / 2.0, xi + dt * k1 / 2.0) - k3 = lorenz(ti + dt / 2.0, xi + dt * k2 / 2.0) - k4 = lorenz(ti + dt, xi + dt * k3) - X[i, :] = xi + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4) - - # Drop a short burn-in to avoid transient initialization effects - burn = int(1.0 / dt) # drop first 1s - t = t[burn:] - X = X[burn:, :] - - # Exact RHS at sample points using analytic lorenz - dX_dt = np.vstack([lorenz(tt, xx) for tt, xx in zip(t, X)]) - return t, X, dX_dt - - -def build_library() -> PolynomialLibrary: - # Standard SINDy polynomial library for Lorenz (degree 2 is typical) - return PolynomialLibrary(degree=2, include_interaction=True) - - -def run_optimizer(name: str, optimizer, X: np.ndarray, dt: float, library: PolynomialLibrary, dX_dt: np.ndarray): - model = SINDy(optimizer=optimizer, feature_library=library) - t0 = time.perf_counter() - model.fit(X, t=dt) - fit_time = time.perf_counter() - t0 - score = model.score(X, t=dt) - # Predict derivatives from learned model, compare to analytic RHS - dX_pred = model.predict(X) - mse = float(np.mean((dX_pred - dX_dt) ** 2)) - complexity = optimizer.complexity if hasattr(optimizer, "complexity") else None - # Gather equations as strings - try: - equations = model.equations() - except Exception: - equations = None - return { - "name": name, - "fit_time_s": fit_time, - "score": float(score), - "mse": mse, - "complexity": int(complexity) if complexity is not None else None, - "equations": equations, - } - - -def system_lorenz_example() -> None: - """Simulate the Lorenz attractor and test PySINDy + our solvers on it. - - Uses a simple RK4 integrator (no extra deps). To keep runtime reasonable, - solver options are reduced compared to defaults but can be adjusted. 
- """ - warnings.filterwarnings("ignore") - t, X, dX_dt = rk4_lorenz() - dt = t[1] - t[0] - library = build_library() - - optimizers = [] # type: List[Tuple[str, object]] - - # STLSQ: classic baseline - optimizers.append(("STLSQ", STLSQ(threshold=0.1, alpha=0.05, max_iter=20))) - - # SR3: relaxed regularized regression, L0 prox behaves like hard thresholding - optimizers.append(("SR3-L0", SR3(reg_weight_lam=0.1, regularizer="L0", relax_coeff_nu=1.0, max_iter=50))) - - # FROLS & SSR: forward regression variants - optimizers.append(("FROLS", FROLS())) - optimizers.append(("SSR", SSR())) - - # SBR if available - if SBR is not None: - try: - optimizers.append(("SBR", SBR())) - except Exception: - pass - - # Torch-based optimizer if available - if TorchOptimizer is not None: - try: - optimizers.append(( - "TorchOptimizer", - TorchOptimizer( - seed=0, - ), - )) - except Exception: - pass - - results = [] - for name, opt in optimizers: - try: - res = run_optimizer(name, opt, X, dt, library, dX_dt) - except Exception as e: - res = {"name": name, "error": str(e)} - results.append(res) - - # Pretty print summary - header = f"{'Optimizer':<15} {'Score':>10} {'MSE':>12} {'Time (s)':>12} {'Complexity':>12}" - print(header) - print("-" * len(header)) - for r in results: - if "error" in r: - print(f"{r['name']:<15} ERROR: {r['error']}") - else: - print(f"{r['name']:<15} {r['score']:>10.4f} {r['mse']:>12.4e} {r['fit_time_s']:>12.4f} {str(r['complexity']):>12}") - # Print discovered system equations - for r in results: - if "error" in r: - print(f"{r['name']:<15} ERROR: {r['error']}") - else: - print(f"{r['name']:<15} {r['score']:>10.4f} {r['mse']:>12.4e} {r['fit_time_s']:>12.4f} {str(r['complexity']):>12}") - eqs = r.get("equations") - if eqs: - for eq in eqs: - print(f" {eq}") - else: - print(" (equations unavailable)") - - -if __name__ == "__main__": - system_lorenz_example() diff --git a/examples/benchmarks/spring_benchmark.py b/examples/benchmarks/spring_benchmark.py deleted file mode 100644 index 7a64d9e9..00000000 --- a/examples/benchmarks/spring_benchmark.py +++ /dev/null @@ -1,162 +0,0 @@ -#!/usr/bin/env python3 -""" -Benchmark PySINDy optimizers on a nonlinear spring system. -System: - xdot = v - vdot = -k * x - c * v + F * sin(x**2) -Simulates with RK4, evaluates multiple optimizers, prints metrics and equations, -and saves a plot comparing true vs predicted test trajectories. -""" -import time -import warnings -from typing import List, Tuple - -import numpy as np - -from pysindy import SINDy -from pysindy.feature_library import PolynomialLibrary -from pysindy.optimizers import ( - STLSQ, - SR3, - FROLS, - SSR, -) -# Optional optimizers -try: - from pysindy.optimizers import TorchOptimizer # type: ignore -except Exception: - TorchOptimizer = None # type: ignore - -try: - from pysindy.optimizers import SBR -except Exception: - SBR = None # type: ignore - - -def myspring(t, x, k=-4.518, c=0.372, F0=9.123): - """ - Example nonlinear dynamical system. 
- xdot = v - vdot = - k x - v c + F sin(x**2) - """ - return np.array([x[1], k * x[0] - c * x[1] + F0 * np.sin(x[0] ** 2)]) - - -def rk4_system(f, x0: np.ndarray, t0: float, t1: float, dt: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - t = np.arange(t0, t1 + 1e-12, dt) - X = np.zeros((t.size, x0.size), dtype=float) - X[0, :] = np.array(x0, dtype=float) - for i in range(1, t.size): - ti = t[i - 1] - xi = X[i - 1, :] - k1 = f(ti, xi) - k2 = f(ti + dt / 2.0, xi + dt * k1 / 2.0) - k3 = f(ti + dt / 2.0, xi + dt * k2 / 2.0) - k4 = f(ti + dt, xi + dt * k3) - X[i, :] = xi + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4) - dX_dt = np.vstack([f(tt, xx) for tt, xx in zip(t, X)]) - return t, X, dX_dt - - -def split(arr: np.ndarray, ratio: float) -> Tuple[np.ndarray, np.ndarray]: - n = arr.shape[0] - n_tr = int(np.floor(ratio * n)) - return arr[:n_tr], arr[n_tr:] - - -def build_library() -> PolynomialLibrary: - # Include sin via generalized or custom library? Approximate with polynomials up to degree 5 - return PolynomialLibrary(degree=5, include_interaction=True) - - -def run_optimizer(name: str, optimizer, X_tr: np.ndarray, dt: float, library: PolynomialLibrary, - X_te: np.ndarray, dX_dt_te: np.ndarray): - model = SINDy(optimizer=optimizer, feature_library=library) - t0 = time.perf_counter() - model.fit(X_tr, t=dt) - fit_time = time.perf_counter() - t0 - score = model.score(X_te, t=dt) - dX_pred_te = model.predict(X_te) - mse = float(np.mean((dX_pred_te - dX_dt_te) ** 2)) - complexity = getattr(optimizer, 'complexity', None) - try: - equations = model.equations() - except Exception: - equations = None - return { - "name": name, - "fit_time_s": fit_time, - "score": float(score), - "mse": mse, - "complexity": int(complexity) if complexity is not None else None, - "equations": equations, - "model": model, - } - - -def nonlinear_spring_example() -> None: - warnings.filterwarnings("ignore") - # Simulate - t0, t1, dt = 0.0, 10.0, 0.01 - x0 = np.array([0.4, 1.6], dtype=float) - t, X, dX_dt = rk4_system(myspring, x0, t0, t1, dt) - # Train/test split - ratio = 0.67 - X_tr, X_te = split(X, ratio) - dX_dt_tr, dX_dt_te = split(dX_dt, ratio) - # Build library - library = build_library() - - optimizers: List[Tuple[str, object]] = [] - optimizers.append(("STLSQ", STLSQ(threshold=0.1, alpha=0.01, max_iter=30))) - optimizers.append(("SR3-L0", SR3(reg_weight_lam=0.05, regularizer="L0", relax_coeff_nu=1.0, max_iter=100))) - optimizers.append(("FROLS", FROLS())) - optimizers.append(("SSR", SSR())) - if SBR is not None: - try: - optimizers.append(("SBR", SBR())) - except Exception: - pass - if TorchOptimizer is not None: - try: - optimizers.append(( - "TorchOptimizer", - TorchOptimizer( - seed=0, - ), - )) - except Exception: - pass - - dt_scalar = dt - results = [] - for name, opt in optimizers: - try: - res = run_optimizer(name, opt, X_tr, dt_scalar, library, X_te, dX_dt_te) - except Exception as e: - res = {"name": name, "error": str(e)} - results.append(res) - - # Print summary - header = f"{'Optimizer':<15} {'Score':>10} {'MSE':>12} {'Time (s)':>12} {'Complexity':>12}" - print(header) - print("-" * len(header)) - for r in results: - if "error" in r: - print(f"{r['name']:<15} ERROR: {r['error']}") - else: - print(f"{r['name']:<15} {r['score']:>10.4f} {r['mse']:>12.4e} {r['fit_time_s']:>12.4f} {str(r['complexity']):>12}") - for r in results: - if "error" in r: - print(f"{r['name']:<15} ERROR: {r['error']}") - else: - print(f"{r['name']:<15} {r['score']:>10.4f} {r['mse']:>12.4e} {r['fit_time_s']:>12.4f} 
{str(r['complexity']):>12}") - eqs = r.get("equations") - if eqs: - for eq in eqs: - print(f" {eq}") - else: - print(" (equations unavailable)") - -if __name__ == "__main__": - nonlinear_spring_example() diff --git a/pysindy/optimizers/__init__.py b/pysindy/optimizers/__init__.py index 85313f34..2cef69c2 100644 --- a/pysindy/optimizers/__init__.py +++ b/pysindy/optimizers/__init__.py @@ -31,9 +31,8 @@ except (ImportError, NameError): pass try: - from .torch_solver import TorchOptimizer + from .torch_s import TorchOptimizer except Exception: - # Optional dependency torch may be missing or cause import-time errors TorchOptimizer = None diff --git a/pysindy/optimizers/torch_solver.py b/pysindy/optimizers/torch_s.py similarity index 92% rename from pysindy/optimizers/torch_solver.py rename to pysindy/optimizers/torch_s.py index ea91c24f..439529f6 100644 --- a/pysindy/optimizers/torch_solver.py +++ b/pysindy/optimizers/torch_s.py @@ -29,7 +29,7 @@ >>> rng = np.random.default_rng(0) >>> X = rng.standard_normal((500, 3)) >>> Y = X @ np.array([[1.0, 0.0, -0.5], [0.0, 2.0, 0.0], [0.0, 0.0, 0.0]]).T - >>> opt = TorchOptimizer(threshold=0.05, alpha_l1=1e-3, step_size=1e-2, max_iter=500, early_stopping_patience=50) + >>> opt = TorchOptimizer() >>> lib = PolynomialLibrary(degree=2) >>> model = SINDy(optimizer=opt, feature_library=lib) >>> model.fit(X, t=0.01) @@ -45,9 +45,9 @@ `min_delta` for `early_stopping_patience` consecutive steps. - The optimizer tracks and restores the best solution observed across iterations. """ - import warnings -from typing import Optional, TYPE_CHECKING +from typing import Optional +from typing import TYPE_CHECKING import numpy as np @@ -69,15 +69,18 @@ from torch import nn from torch.optim import Optimizer + # Taken from https://github.com/kyleliang919/C-Optim/blob/main/c_adamw.py class CAdamW(Optimizer): """ - Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay + Implements Adam algorithm with weight decay fix + as introduced in [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101). Parameters: params (`Iterable[nn.parameter.Parameter]`): - Iterable of parameters to optimize or dictionaries defining parameter groups. + Iterable of parameters to optimize or dictionaries + defining parameter groups. lr (`float`, *optional*, defaults to 0.001): The learning rate to use. betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`): @@ -87,9 +90,11 @@ class CAdamW(Optimizer): weight_decay (`float`, *optional*, defaults to 0.0): Decoupled weight decay to apply. correct_bias (`bool`, *optional*, defaults to `True`): - Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`). + Whether or not to correct bias in Adam + (for instance, in Bert TF repository they use `False`). no_deprecation_warning (`bool`, *optional*, defaults to `False`): - A flag used to disable the deprecation warning (set to `True` to disable the warning). + A flag used to disable the deprecation warning + (set to `True` to disable the warning). """ def __init__( @@ -129,7 +134,8 @@ def step(self, closure: Callable = None): Performs a single optimization step. Arguments: - closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. + closure (`Callable`, *optional*): + A closure that reevaluates the model and returns the loss. 
""" loss = None if closure is not None: @@ -178,10 +184,12 @@ def step(self, closure: Callable = None): # compute norm gradient mask = (exp_avg * grad > 0).to(grad.dtype) - # mask = mask * (mask.numel() / (mask.sum() + 1)) ## original implementation, leaving it here for record + # mask = mask * (mask.numel() / (mask.sum() + 1)) + # ## original implementation, leaving it here for record mask.div_( mask.mean().clamp_(min=1e-3) - ) # https://huggingface.co/rwightman/timm-optim-caution found this implementation is more favourable in many cases + ) # https://huggingface.co/rwightman/timm-optim-caution + # found this implementation is more favourable in many cases norm_grad = (exp_avg * mask) / denom p.add_(norm_grad, alpha=-step_size) return loss @@ -298,7 +306,7 @@ def __init__( alpha_l1: float = 0.0, step_size: float = 1e-1, max_iter: int = 1000, - optimizer: str = "adam", + optimizer: str = "cadamw", normalize_columns: bool = False, copy_X: bool = True, initial_guess: Optional[np.ndarray] = None, @@ -307,8 +315,8 @@ def __init__( seed: Optional[int] = None, sparse_ind: Optional[list[int]] = None, unbias: bool = True, - early_stopping_patience: int = 0, - min_delta: float = 0.0, + early_stopping_patience: int = 100, + min_delta: float = 1e-10, ): super().__init__( max_iter=max_iter, @@ -340,9 +348,12 @@ def __init__( self.early_stopping_patience = int(early_stopping_patience) self.min_delta = float(min_delta) if torch is None: - # Delay hard failure to fit-time to allow import of module without torch + # Delay hard failure to fit-time to + # allow import of module without torch warnings.warn( - "PyTorch is not installed; TorchSINDyOptimizer will not run until torch is available.") + "PyTorch is not installed; " + "TorchOptimizer will not run until torch is available." + ) def _reduce(self, x: np.ndarray, y: np.ndarray) -> None: """Core optimization loop. @@ -372,7 +383,9 @@ def _reduce(self, x: np.ndarray, y: np.ndarray) -> None: If PyTorch is not installed at run time. """ if torch is None: - raise ImportError("PyTorch is required for TorchSINDyOptimizer. Please install torch.") + raise ImportError( + "PyTorch is required for TorchOptimizer. Please install torch." 
+ ) # Select device if self.torch_device == "cuda" and not torch.cuda.is_available(): warnings.warn("CUDA not available; falling back to CPU.") @@ -410,7 +423,9 @@ def _reduce(self, x: np.ndarray, y: np.ndarray) -> None: # Support mask helper: restrict thresholding to specified indices sparse_mask = None if self.sparse_ind is not None: - sparse_mask = torch.zeros((n_targets, n_features), dtype=torch.bool, device=device) + sparse_mask = torch.zeros( + (n_targets, n_features), dtype=torch.bool, device=device + ) sparse_mask[:, self.sparse_ind] = True def loss_fn(W_): @@ -467,7 +482,9 @@ def loss_fn(W_): last_mask = mask if 0 < self.early_stopping_patience <= patience_counter: break - if self.verbose and (it % max(1, self.max_iter // 10) == 0 or it == self.max_iter - 1): + if self.verbose and ( + it % max(1, self.max_iter // 10) == 0 or it == self.max_iter - 1 + ): mse_val = float(((X @ W.T - Y).pow(2)).sum().cpu().numpy()) / n_samples l0 = int((torch.abs(W) >= self.threshold).sum().item()) print(f"[TorchSINDy] iter={it} mse={mse_val:.4e} L0={l0} obj={obj:.4e}") From 9cf5669948fb01ebc44ff64978fd1050a83703c5 Mon Sep 17 00:00:00 2001 From: hvoss Date: Thu, 4 Dec 2025 16:40:29 +0100 Subject: [PATCH 04/18] renamed back to torch_solver --- pysindy/optimizers/__init__.py | 2 +- pysindy/optimizers/{torch_s.py => torch_solver.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename pysindy/optimizers/{torch_s.py => torch_solver.py} (100%) diff --git a/pysindy/optimizers/__init__.py b/pysindy/optimizers/__init__.py index 2cef69c2..03dc6bf7 100644 --- a/pysindy/optimizers/__init__.py +++ b/pysindy/optimizers/__init__.py @@ -31,7 +31,7 @@ except (ImportError, NameError): pass try: - from .torch_s import TorchOptimizer + from .torch_solver import TorchOptimizer except Exception: TorchOptimizer = None diff --git a/pysindy/optimizers/torch_s.py b/pysindy/optimizers/torch_solver.py similarity index 100% rename from pysindy/optimizers/torch_s.py rename to pysindy/optimizers/torch_solver.py From bd594dff3cd55885b629f7670d92a69bb5a8e581 Mon Sep 17 00:00:00 2001 From: hvoss Date: Thu, 4 Dec 2025 16:42:52 +0100 Subject: [PATCH 05/18] test fix --- test/test_optimizers/test_torch_optimizer.py | 25 ++++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/test/test_optimizers/test_torch_optimizer.py b/test/test_optimizers/test_torch_optimizer.py index 13708e46..8527adb9 100644 --- a/test/test_optimizers/test_torch_optimizer.py +++ b/test/test_optimizers/test_torch_optimizer.py @@ -1,11 +1,11 @@ import numpy as np import pytest -torch = pytest.importorskip("torch") - -from pysindy.optimizers import TorchOptimizer from pysindy import SINDy from pysindy.feature_library import PolynomialLibrary +from pysindy.optimizers import TorchOptimizer + +torch = pytest.importorskip("torch") def make_synthetic(n_samples=200, noise=0.0, seed=0): @@ -28,7 +28,9 @@ def test_basic_fit_shapes(): def test_unbias_and_sparsity(): X, Y = make_synthetic(noise=0.01) - opt = TorchOptimizer(max_iter=100, threshold=0.05, alpha_l1=1e-3, seed=2, unbias=True) + opt = TorchOptimizer( + max_iter=100, threshold=0.05, alpha_l1=1e-3, seed=2, unbias=True + ) opt.fit(X, Y) # Check some sparsity present assert np.count_nonzero(opt.coef_) < opt.coef_.size @@ -36,17 +38,14 @@ def test_unbias_and_sparsity(): def test_multi_target_with_sindylib(): # Integration with SINDy and a small library - rng = np.random.default_rng(0) t = np.linspace(0, 1, 200) - x = np.stack([ - np.sin(2*np.pi*t), - np.cos(2*np.pi*t), - 
0.5*np.sin(4*np.pi*t) - ], axis=1) + x = np.stack( + [np.sin(2 * np.pi * t), np.cos(2 * np.pi * t), 0.5 * np.sin(4 * np.pi * t)], + axis=1, + ) lib = PolynomialLibrary(degree=2) opt = TorchOptimizer(max_iter=50, threshold=1e-2, alpha_l1=1e-3, seed=0) model = SINDy(optimizer=opt, feature_library=lib) - model.fit(x, t=t[1]-t[0]) - score = model.score(x, t=t[1]-t[0]) + model.fit(x, t=t[1] - t[0]) + score = model.score(x, t=t[1] - t[0]) assert score > 0.8 - From 874ce0855ad5df863559289a2d91246a18308c87 Mon Sep 17 00:00:00 2001 From: hvoss Date: Fri, 5 Dec 2025 11:59:00 +0100 Subject: [PATCH 06/18] added torch as optional dependency --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a018af80..44b2a6a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,6 @@ dependencies = [ "derivative>=0.6.2", "scipy", "typing_extensions", - "torch", ] [project.optional-dependencies] @@ -50,7 +49,8 @@ dev = [ "jupytext", "pre-commit", "hypothesis", - "jupyter-contrib-nbextensions" + "jupyter-contrib-nbextensions", + "torch", ] docs = [ "ipython", From 4def959988025881b1b82904ed2eae7bd316e0b4 Mon Sep 17 00:00:00 2001 From: Hendric Voss <37121894+hvoss-techfak@users.noreply.github.com> Date: Fri, 5 Dec 2025 15:55:39 +0100 Subject: [PATCH 07/18] Update examples/benchmarks/benchmark.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- examples/benchmarks/benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmarks/benchmark.py b/examples/benchmarks/benchmark.py index 41854aca..2e60f4b9 100644 --- a/examples/benchmarks/benchmark.py +++ b/examples/benchmarks/benchmark.py @@ -355,7 +355,7 @@ def main(): alpha_l1=1e-3, step_size=1e-2, max_iter=200, - optimizer="adam", + optimizer="cadamw", seed=0, unbias=True, early_stopping_patience=50, From adf8a59ecf9832b7e8e9fb03065718e4fae383ee Mon Sep 17 00:00:00 2001 From: Hendric Voss <37121894+hvoss-techfak@users.noreply.github.com> Date: Fri, 5 Dec 2025 15:56:49 +0100 Subject: [PATCH 08/18] Update pysindy/optimizers/torch_solver.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pysindy/optimizers/torch_solver.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pysindy/optimizers/torch_solver.py b/pysindy/optimizers/torch_solver.py index 439529f6..d8a526dc 100644 --- a/pysindy/optimizers/torch_solver.py +++ b/pysindy/optimizers/torch_solver.py @@ -59,7 +59,6 @@ torch = None # type: ignore if TYPE_CHECKING: # only for type checkers - import torch as torch_types # noqa: F401 import math From 424483bf82f0756e7e2dbf8769f74d8a9c8e5ed5 Mon Sep 17 00:00:00 2001 From: Hendric Voss <37121894+hvoss-techfak@users.noreply.github.com> Date: Fri, 5 Dec 2025 16:01:15 +0100 Subject: [PATCH 09/18] Update pysindy/optimizers/torch_solver.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pysindy/optimizers/torch_solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysindy/optimizers/torch_solver.py b/pysindy/optimizers/torch_solver.py index d8a526dc..5edd244a 100644 --- a/pysindy/optimizers/torch_solver.py +++ b/pysindy/optimizers/torch_solver.py @@ -14,7 +14,7 @@ - Compatible with PySINDy BaseOptimizer interface and ensembling. Optional dependencies -- PyTorch is optional at import time; a RuntimeError will be raised during fit if +- PyTorch is optional at import time; an ImportError will be raised during fit if torch is not available. 
Code paths and annotations avoid import-time failures. - CAdamW is a custom optimizer shipped with the project and used when requested. From 31db7a570008988ad794929fc97012c4ceea21cb Mon Sep 17 00:00:00 2001 From: Hendric Voss <37121894+hvoss-techfak@users.noreply.github.com> Date: Fri, 5 Dec 2025 16:03:57 +0100 Subject: [PATCH 10/18] Update examples/benchmarks/benchmark.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- examples/benchmarks/benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmarks/benchmark.py b/examples/benchmarks/benchmark.py index 2e60f4b9..68a50651 100644 --- a/examples/benchmarks/benchmark.py +++ b/examples/benchmarks/benchmark.py @@ -221,7 +221,7 @@ def rk4( - X: state trajectory of shape (N, d). - dX_dt: analytic RHS evaluated along trajectory, shape (N, d). """ - t = np.arange(t0, t1 + 1e-12, dt) + t = np.linspace(t0, t1, int(round((t1 - t0) / dt)) + 1) X = np.zeros((t.size, x0.size), dtype=float) X[0, :] = np.array(x0, dtype=float) for i in range(1, t.size): From 9c97a4be20d8005037c8b1bcb9330092fdf718ff Mon Sep 17 00:00:00 2001 From: hvoss Date: Fri, 5 Dec 2025 16:07:36 +0100 Subject: [PATCH 11/18] added multiple new tests and removed the magic number --- pysindy/optimizers/torch_solver.py | 19 ++-- test/test_optimizers/test_torch_optimizer.py | 100 +++++++++++++++++++ 2 files changed, 110 insertions(+), 9 deletions(-) diff --git a/pysindy/optimizers/torch_solver.py b/pysindy/optimizers/torch_solver.py index d8a526dc..7eddc8ce 100644 --- a/pysindy/optimizers/torch_solver.py +++ b/pysindy/optimizers/torch_solver.py @@ -47,7 +47,6 @@ """ import warnings from typing import Optional -from typing import TYPE_CHECKING import numpy as np @@ -58,13 +57,9 @@ except Exception: # pragma: no cover - optional dependency torch = None # type: ignore -if TYPE_CHECKING: # only for type checkers - - import math from typing import Callable, Iterable, Tuple -import torch from torch import nn from torch.optim import Optimizer @@ -346,6 +341,8 @@ def __init__( self.sparse_ind = sparse_ind self.early_stopping_patience = int(early_stopping_patience) self.min_delta = float(min_delta) + self.stability_eps = 1e-14 + if torch is None: # Delay hard failure to fit-time to # allow import of module without torch @@ -406,7 +403,8 @@ def _reduce(self, x: np.ndarray, y: np.ndarray) -> None: if self.coef_ is None: W = torch.zeros((n_targets, n_features), dtype=dtype, device=device) else: - W = torch.as_tensor(self.coef_, dtype=dtype, device=device) + # Decouple from numpy memory to avoid mutating BaseOptimizer.history_[0] + W = torch.tensor(self.coef_, dtype=dtype, device=device).clone() W.requires_grad_(True) # Optimizer for smooth loss @@ -475,7 +473,7 @@ def loss_fn(W_): with torch.no_grad(): coef_np = W.detach().cpu().numpy() self.history_.append(coef_np) - mask = np.abs(coef_np) >= max(self.threshold, 1e-14) + mask = np.abs(coef_np) >= max(self.threshold, self.stability_eps) if last_mask is not None and np.array_equal(mask, last_mask): break last_mask = mask @@ -484,7 +482,10 @@ def loss_fn(W_): if self.verbose and ( it % max(1, self.max_iter // 10) == 0 or it == self.max_iter - 1 ): - mse_val = float(((X @ W.T - Y).pow(2)).sum().cpu().numpy()) / n_samples + mse_val = ( + float(((X @ W.T - Y).pow(2)).sum().cpu().detach().numpy()) + / n_samples + ) l0 = int((torch.abs(W) >= self.threshold).sum().item()) print(f"[TorchSINDy] iter={it} mse={mse_val:.4e} L0={l0} obj={obj:.4e}") @@ -492,7 +493,7 @@ def loss_fn(W_): final_W = (best_W if 
best_W is not None else W).detach().cpu().numpy() self.coef_ = final_W # ind_ based on tiny threshold - self.ind_ = np.abs(self.coef_) > 1e-14 + self.ind_ = np.abs(self.coef_) > self.stability_eps @property def complexity(self): diff --git a/test/test_optimizers/test_torch_optimizer.py b/test/test_optimizers/test_torch_optimizer.py index 8527adb9..6fe85477 100644 --- a/test/test_optimizers/test_torch_optimizer.py +++ b/test/test_optimizers/test_torch_optimizer.py @@ -49,3 +49,103 @@ def test_multi_target_with_sindylib(): model.fit(x, t=t[1] - t[0]) score = model.score(x, t=t[1] - t[0]) assert score > 0.8 + + +# New tests for broader coverage + + +def test_sparse_ind_hard_thresholding_effect(): + # With extremely large threshold and one iteration, only columns listed in + # sparse_ind should be hard-thresholded to zero; others should retain small + # non-zero values from the single gradient step. + X, Y = make_synthetic(n_samples=120, noise=0.0, seed=42) + thr = 1e9 + opt = TorchOptimizer( + max_iter=1, + step_size=1e-3, + threshold=thr, + alpha_l1=0.0, + seed=0, + sparse_ind=[0], # only first feature is forced to zero by hard threshold + ) + opt.fit(X, Y) + # Column 0 should be exactly zero (hard-thresholded) + assert np.allclose(opt.coef_[:, 0], 0.0) + # At least one coefficient outside column 0 should remain non-zero + assert np.any(np.abs(opt.coef_[:, 1:]) > 0.0) + + +@pytest.mark.parametrize("opt_name", ["sgd", "adam", "adamw", "cadamw"]) +def test_optimizer_variants_run(opt_name): + X, Y = make_synthetic(n_samples=100, noise=0.01, seed=3) + opt = TorchOptimizer( + optimizer=opt_name, max_iter=30, threshold=1e-3, alpha_l1=1e-4, seed=1 + ) + opt.fit(X, Y) + assert opt.coef_.shape == (Y.shape[1], X.shape[1]) + + +def test_early_stopping_via_patience_and_min_delta(): + # Use a huge min_delta so improvements never count; should stop after + # early_stopping_patience iterations instead of max_iter. 
+ X, Y = make_synthetic(n_samples=150, noise=0.01, seed=4) + patience = 2 + opt = TorchOptimizer( + max_iter=200, + threshold=0.0, + alpha_l1=0.0, + seed=0, + early_stopping_patience=patience, + min_delta=1e9, + ) + opt.fit(X, Y) + # history_ includes the initial entry (from BaseOptimizer) + per-iteration appends + # Ensure we stopped well before max_iter due to patience + assert len(opt.history_) <= 1 + patience + 1 # a few iterations at most + + +def test_cuda_device_selection_warning_or_success(): + X, Y = make_synthetic(n_samples=50, noise=0.0, seed=5) + if not torch.cuda.is_available(): + with pytest.warns(UserWarning, match="CUDA not available; falling back to CPU"): + opt = TorchOptimizer(device="cuda", max_iter=5, threshold=0.0, seed=0) + opt.fit(X, Y) + assert opt.coef_.shape == (Y.shape[1], X.shape[1]) + else: + opt = TorchOptimizer(device="cuda", max_iter=10, threshold=1e-4, seed=0) + opt.fit(X, Y) + assert opt.coef_.shape == (Y.shape[1], X.shape[1]) + + +def test_complexity_property_matches_manual(): + X, Y = make_synthetic(n_samples=120, noise=0.02, seed=6) + opt = TorchOptimizer(max_iter=60, threshold=5e-2, alpha_l1=1e-3, seed=6) + opt.fit(X, Y) + manual = np.count_nonzero(opt.coef_) + np.count_nonzero(opt.intercept_) + assert opt.complexity == manual + + +def test_verbose_prints_progress(capsys): + X, Y = make_synthetic(n_samples=60, noise=0.0, seed=7) + opt = TorchOptimizer(max_iter=5, threshold=1e-4, alpha_l1=0.0, verbose=True, seed=0) + opt.fit(X, Y) + captured = capsys.readouterr().out + assert "[TorchSINDy]" in captured and "iter=" in captured + + +def test_history_tracking_shapes_and_length(): + X, Y = make_synthetic(n_samples=80, noise=0.0, seed=8) + # Provide a zero initial guess so the first history entry (initial) certainly + # differs from later iterates after optimization updates. 
+ init = np.zeros((Y.shape[1], X.shape[1])) + opt = TorchOptimizer( + max_iter=8, threshold=1e-4, alpha_l1=1e-4, seed=0, initial_guess=init + ) + opt.fit(X, Y) + # History should contain the initial coefficients + at least one update + assert isinstance(opt.history_, list) + assert len(opt.history_) >= 2 + shapes_ok = all(h.shape == opt.coef_.shape for h in opt.history_) + assert shapes_ok + # The first history entry (initial guess) should differ from a later iterate + assert not np.allclose(opt.history_[0], opt.history_[-1]) From fd330f3f41faf788259402504ad97348faee06d0 Mon Sep 17 00:00:00 2001 From: Hendric Voss <37121894+hvoss-techfak@users.noreply.github.com> Date: Fri, 5 Dec 2025 16:08:30 +0100 Subject: [PATCH 12/18] Update pysindy/optimizers/torch_solver.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pysindy/optimizers/torch_solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysindy/optimizers/torch_solver.py b/pysindy/optimizers/torch_solver.py index 5edd244a..85ce871e 100644 --- a/pysindy/optimizers/torch_solver.py +++ b/pysindy/optimizers/torch_solver.py @@ -312,7 +312,7 @@ def __init__( verbose: bool = False, device: Optional[str] = None, seed: Optional[int] = None, - sparse_ind: Optional[list[int]] = None, + sparse_ind: Optional[List[int]] = None, unbias: bool = True, early_stopping_patience: int = 100, min_delta: float = 1e-10, From d277af78ddf865167f4b79274bcebaeca82a39f8 Mon Sep 17 00:00:00 2001 From: hvoss Date: Fri, 5 Dec 2025 16:12:37 +0100 Subject: [PATCH 13/18] fixed missing import --- pysindy/optimizers/torch_solver.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pysindy/optimizers/torch_solver.py b/pysindy/optimizers/torch_solver.py index 7b094ba7..bc56ef87 100644 --- a/pysindy/optimizers/torch_solver.py +++ b/pysindy/optimizers/torch_solver.py @@ -46,6 +46,7 @@ - The optimizer tracks and restores the best solution observed across iterations. 
""" import warnings +from typing import List from typing import Optional import numpy as np From 7fa86045c77a963d01fb5a8b44b1000b03ac8fc2 Mon Sep 17 00:00:00 2001 From: hvoss Date: Fri, 5 Dec 2025 16:16:45 +0100 Subject: [PATCH 14/18] fixed pass in except block --- examples/benchmarks/benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmarks/benchmark.py b/examples/benchmarks/benchmark.py index 68a50651..0b3facd9 100644 --- a/examples/benchmarks/benchmark.py +++ b/examples/benchmarks/benchmark.py @@ -344,7 +344,7 @@ def main(): try: opt_defs.append(("SBR", SBR())) except Exception: - pass + traceback.print_exc() if TorchOptimizer is not None: try: opt_defs.append( From 187065731c95373c8cf7e99b70bf4181ff575d1f Mon Sep 17 00:00:00 2001 From: hvoss Date: Fri, 5 Dec 2025 16:18:24 +0100 Subject: [PATCH 15/18] fixed pass in except block --- pysindy/optimizers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysindy/optimizers/__init__.py b/pysindy/optimizers/__init__.py index 03dc6bf7..af94cb27 100644 --- a/pysindy/optimizers/__init__.py +++ b/pysindy/optimizers/__init__.py @@ -33,7 +33,7 @@ try: from .torch_solver import TorchOptimizer except Exception: - TorchOptimizer = None + pass __all__ = [ From e7604fdeac4aaf68a03d5dff3912610cf684cd52 Mon Sep 17 00:00:00 2001 From: hvoss Date: Fri, 5 Dec 2025 16:20:50 +0100 Subject: [PATCH 16/18] fixed pass in except block --- pysindy/optimizers/torch_solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysindy/optimizers/torch_solver.py b/pysindy/optimizers/torch_solver.py index bc56ef87..e5177abb 100644 --- a/pysindy/optimizers/torch_solver.py +++ b/pysindy/optimizers/torch_solver.py @@ -464,7 +464,7 @@ def loss_fn(W_): # Evaluate objective and update best with torch.no_grad(): obj = float(loss_fn(W).cpu().numpy()) - if best_obj is None or (best_obj - obj) > self.min_delta: + if best_obj is None or obj < (best_obj - self.min_delta): best_obj = obj best_W = W.detach().clone() patience_counter = 0 From f3c3f9c664e224c6d25963f4a1fd0ba5f8fe704d Mon Sep 17 00:00:00 2001 From: hvoss Date: Fri, 5 Dec 2025 16:22:10 +0100 Subject: [PATCH 17/18] fixed variable --- examples/benchmarks/benchmark.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/benchmarks/benchmark.py b/examples/benchmarks/benchmark.py index 0b3facd9..06c036ec 100644 --- a/examples/benchmarks/benchmark.py +++ b/examples/benchmarks/benchmark.py @@ -266,8 +266,8 @@ def run_optimizer( model.fit(X, t=dt) fit_time = time.perf_counter() - t0 score = model.score(X, t=dt) - dX_pred = model.predict(X) - mse = float(np.mean((dX_pred - dX_dt) ** 2)) + dXdt_pred = model.predict(X) + mse = float(np.mean((dXdt_pred - dX_dt) ** 2)) complexity = optimizer.complexity if hasattr(optimizer, "complexity") else None try: equations = model.equations() From 37e007d86d53386fdafc9e98c44ec4c006aaa1d1 Mon Sep 17 00:00:00 2001 From: hvoss Date: Sat, 6 Dec 2025 14:24:01 +0100 Subject: [PATCH 18/18] Added jax implementation of the solver and wired it into everything. Also added new tests for the jax implementation. 
--- examples/benchmarks/benchmark.py | 42 ++- pysindy/optimizers/__init__.py | 5 + pysindy/optimizers/jax_solver.py | 294 +++++++++++++++++++ test/test_optimizers/test_jax_optimizer.py | 131 +++++++++ test/test_optimizers/test_torch_optimizer.py | 26 +- 5 files changed, 483 insertions(+), 15 deletions(-) create mode 100644 pysindy/optimizers/jax_solver.py create mode 100644 test/test_optimizers/test_jax_optimizer.py diff --git a/examples/benchmarks/benchmark.py b/examples/benchmarks/benchmark.py index 06c036ec..75652835 100644 --- a/examples/benchmarks/benchmark.py +++ b/examples/benchmarks/benchmark.py @@ -31,7 +31,9 @@ import numpy as np +from pysindy import ConstrainedSR3 from pysindy import SINDy +from pysindy import StableLinearSR3 from pysindy.feature_library import PolynomialLibrary from pysindy.optimizers import FROLS from pysindy.optimizers import SR3 @@ -46,6 +48,10 @@ from pysindy.optimizers import TorchOptimizer except Exception: TorchOptimizer = None # type: ignore +try: + from pysindy.optimizers import JaxOptimizer +except Exception: + JaxOptimizer = None # type: ignore # ------------------------------- Systems ------------------------------------ @@ -331,11 +337,23 @@ def main(): # Select optimizers opt_defs: List[Tuple[str, object]] = [] - opt_defs.append(("STLSQ", STLSQ(threshold=0.1, alpha=0.05, max_iter=20))) + opt_defs.append(("STLSQ", STLSQ())) opt_defs.append( ( "SR3-L0", - SR3(reg_weight_lam=0.1, regularizer="L0", relax_coeff_nu=1.0, max_iter=50), + SR3(), + ) + ) + opt_defs.append( + ( + "SR3-constrained", + ConstrainedSR3(), + ) + ) + opt_defs.append( + ( + "SR3-stable", + StableLinearSR3(), ) ) opt_defs.append(("FROLS", FROLS())) @@ -365,6 +383,26 @@ def main(): ) except Exception: traceback.print_exc() + if JaxOptimizer is not None: + try: + opt_defs.append( + ( + "JaxOptimizer", + JaxOptimizer( + threshold=0.05, + alpha_l1=1e-3, + step_size=1e-2, + max_iter=200, + optimizer="adam", + seed=0, + unbias=True, + early_stopping_patience=50, + min_delta=1e-8, + ), + ) + ) + except Exception: + traceback.print_exc() if args.optimizers != "all": names = {n.strip() for n in args.optimizers.split(",")} diff --git a/pysindy/optimizers/__init__.py b/pysindy/optimizers/__init__.py index af94cb27..f7e2af64 100644 --- a/pysindy/optimizers/__init__.py +++ b/pysindy/optimizers/__init__.py @@ -34,6 +34,10 @@ from .torch_solver import TorchOptimizer except Exception: pass +try: + from .jax_solver import JaxOptimizer +except Exception: + pass __all__ = [ @@ -51,4 +55,5 @@ "MIOSR", "SBR", "TorchOptimizer", + "JaxOptimizer", ] diff --git a/pysindy/optimizers/jax_solver.py b/pysindy/optimizers/jax_solver.py new file mode 100644 index 00000000..29a57cc5 --- /dev/null +++ b/pysindy/optimizers/jax_solver.py @@ -0,0 +1,294 @@ +""" +JAX-based SINDy optimizer using proximal gradient + iterative hard-thresholding. + +This module provides a high-performance optimizer implemented with JAX to solve +Sparse Identification of Nonlinear Dynamics (SINDy) regression problems. It +minimizes a smooth data-fit term and applies proximal/thresholding operations to +promote sparsity in the discovered dynamical system. + +Key features +- Batched multi-target regression on CPU or GPU/TPU (if available). +- Proximal L1 shrinkage and hard thresholding to encourage sparse models. +- Optimizers: SGD, Adam, AdamW with minimal implementations in JAX. +- Note: The CAdamW variant is currently not available in this implementation. +- Best-solution tracking across iterations and early-stopping support. 
+- Compatible with PySINDy BaseOptimizer interface and ensembling.
+
+Optional dependencies
+- JAX is optional at import time; an ImportError will be raised during fit if
+  jax is not available. Code paths and annotations avoid import-time failures.
+
+Notes
+-----
+- Thresholding and proximal operations operate on coefficient magnitudes. A small
+  numerical threshold (1e-14) is used to derive support masks for `ind_`.
+- When `sparse_ind` is provided, thresholding affects only the specified columns.
+- Early stopping halts iterations when the objective fails to improve by at least
+  `min_delta` for `early_stopping_patience` consecutive steps.
+- The optimizer tracks and restores the best solution observed across iterations.
+"""
+import warnings
+from typing import List
+from typing import Optional
+
+import numpy as np
+
+from .base import BaseOptimizer
+
+try:
+    import jax  # type: ignore
+    import jax.numpy as jnp  # type: ignore
+except Exception:  # pragma: no cover - optional dependency
+    jax = None  # type: ignore
+    jnp = None  # type: ignore
+
+
+# Annotations are strings so the module can be imported without jax installed.
+def _soft_threshold(t: "jnp.ndarray", lam: float):
+    if lam <= 0:
+        return t
+    return jnp.sign(t) * jnp.maximum(jnp.abs(t) - lam, 0.0)
+
+
+def _hard_threshold(t: "jnp.ndarray", thr: float):
+    if thr <= 0:
+        return t
+    return t * (jnp.abs(t) >= thr)
+
+
+class JaxOptimizer(BaseOptimizer):
+    """JAX-powered optimizer for sparse SINDy regression.
+
+    Objective
+        J(W) = (1/N) * ||Y - X W^T||_F^2 + alpha_l1 * ||W||_1
+
+    Parameters
+    ----------
+    threshold : float, default 1e-1
+        Minimum magnitude for a coefficient. Values with |coef| < threshold
+        are set to zero after each iteration (hard-threshold).
+    alpha_l1 : float, default 0.0
+        L1 penalty weight. If > 0, soft-thresholding is applied after the
+        gradient step to shrink coefficients.
+    step_size : float, default 1e-1
+        Learning rate for the chosen optimizer.
+    max_iter : int, default 1000
+        Maximum number of iterations.
+    optimizer : {"sgd", "adam", "adamw"}, default "adam"
+        Which optimizer to use for the smooth part of the objective (the
+        CAdamW variant of the torch solver is not available here).
+    normalize_columns : bool, default False
+        See BaseOptimizer; if True, columns of X are normalized before fitting.
+    copy_X : bool, default True
+        See BaseOptimizer; controls whether X is copied or may be overwritten.
+    initial_guess : np.ndarray or None, default None
+        Warm-start coefficients; shape (n_targets, n_features).
+    verbose : bool, default False
+        If True, prints periodic progress including loss and sparsity.
+    seed : int or None, default None
+        Random seed for reproducibility; only NumPy's global RNG is seeded.
+    sparse_ind : list[int] or None, default None
+        If provided, thresholding only applies to these feature indices. Other
+        indices remain unaffected by hard-thresholding.
+    unbias : bool, default True
+        See BaseOptimizer; when True, performs an unbiased refit on the selected
+        support after optimization.
+    early_stopping_patience : int, default 100
+        If > 0, stop early when the objective has not improved by at least
+        `min_delta` for this many consecutive iterations.
+    min_delta : float, default 1e-10
+        Minimum improvement to reset patience; small positive values help
+        prevent stopping on floating-point noise.
+
+    Attributes
+    ----------
+    coef_ : np.ndarray, shape (n_targets, n_features)
+        Optimized SINDy coefficients.
+ history_ : list of np.ndarray + Coefficient history; `history_[k]` is the coefficient matrix after + iteration k. + ind_ : np.ndarray, shape (n_targets, n_features) + Boolean mask of nonzero coefficients (|coef| > 1e-14). + Examples + -------- + >>> import numpy as np + >>> from scipy.integrate import odeint + >>> from pysindy import SINDy + >>> from pysindy.optimizers import JaxOptimizer + >>> lorenz = lambda z,t : [10 * (z[1] - z[0]), + >>> z[0] * (28 - z[2]) - z[1], + >>> z[0] * z[1] - 8 / 3 * z[2]] + >>> t = np.arange(0, 2, .002) + >>> x = odeint(lorenz, [-8, 8, 27], t) + >>> opt = JaxOptimizer(threshold=.1, alpha_l1=.01, max_iter=1000) + >>> model = SINDy(optimizer=opt) + >>> model.fit(x, t=t[1] - t[0]) + >>> model.print() + + x0' = -9.973 x0 + 9.973 x1 + x1' = -0.129 1 + 27.739 x0 + -0.949 x1 + -0.993 x0 x2 + x2' = -2.656 x2 + 0.996 x0 x1 + """ + + def __init__( + self, + threshold: float = 1e-1, + alpha_l1: float = 0.0, + step_size: float = 1e-1, + max_iter: int = 1000, + optimizer: str = "adam", + normalize_columns: bool = False, + copy_X: bool = True, + initial_guess: Optional[np.ndarray] = None, + verbose: bool = False, + seed: Optional[int] = None, + sparse_ind: Optional[List[int]] = None, + unbias: bool = True, + early_stopping_patience: int = 100, + min_delta: float = 1e-10, + ): + super().__init__( + max_iter=max_iter, + normalize_columns=normalize_columns, + initial_guess=initial_guess, + copy_X=copy_X, + unbias=unbias, + ) + if threshold < 0: + raise ValueError("threshold cannot be negative") + if alpha_l1 < 0: + raise ValueError("alpha_l1 cannot be negative") + if step_size <= 0: + raise ValueError("step_size must be positive") + if optimizer not in ("sgd", "adam", "adamw"): + raise ValueError("optimizer must be 'sgd', 'adam', or 'adamw'") + if early_stopping_patience < 0: + raise ValueError("early_stopping_patience cannot be negative") + if min_delta < 0: + raise ValueError("min_delta cannot be negative") + self.threshold = float(threshold) + self.alpha_l1 = float(alpha_l1) + self.step_size = float(step_size) + self.verbose = bool(verbose) + self.seed = seed + self.opt_name = optimizer + self.sparse_ind = sparse_ind + self.early_stopping_patience = int(early_stopping_patience) + self.min_delta = float(min_delta) + self.stability_eps = 1e-14 + + if jax is None: + warnings.warn( + "JAX is not installed; " + "JaxOptimizer will not run until jax is available." + ) + + def _reduce(self, x: np.ndarray, y: np.ndarray, **kwargs) -> None: + if jax is None: + raise ImportError("JAX is required for JaxOptimizer. Please install jax.") + + # Seed control + if self.seed is not None: + # jax uses PRNG keys. 
+ # We just seed numpy for determinism in thresholds/history + np.random.seed(self.seed) + + X = jnp.asarray(x, dtype=jnp.float64) + Y = jnp.asarray(y, dtype=jnp.float64) + n_samples, n_features = X.shape + n_targets = Y.shape[1] + + if self.coef_ is None: + W = jnp.zeros((n_targets, n_features), dtype=jnp.float64) + else: + W = jnp.asarray(self.coef_, dtype=jnp.float64) + + # sparse mask + sparse_mask = None + if self.sparse_ind is not None: + sparse_mask = jnp.zeros((n_targets, n_features), dtype=bool) + sparse_mask = sparse_mask.at[:, self.sparse_ind].set(True) + + def loss_fn(W_): + Y_pred = X @ W_.T + residual = Y_pred - Y + mse = jnp.sum(residual**2) / n_samples + l1 = self.alpha_l1 * jnp.sum(jnp.abs(W_)) if self.alpha_l1 > 0 else 0.0 + return mse + l1 + + grad_fn = jax.grad(lambda W_: loss_fn(W_)) + + # Optimizer states for Adam/AdamW + m = jnp.zeros_like(W) + v = jnp.zeros_like(W) + beta1, beta2, eps, wd = 0.9, 0.999, 1e-8, 0.0 + if self.opt_name == "adamw": + wd = 1e-4 # small default weight decay + + last_mask = None + best_obj = None + best_W = None + patience_counter = 0 + + def step_adam(W, m, v, g, t): + m = beta1 * m + (1 - beta1) * g + v = beta2 * v + (1 - beta2) * (g * g) + m_hat = m / (1 - beta1**t) + v_hat = v / (1 - beta2**t) + W = W - self.step_size * m_hat / (jnp.sqrt(v_hat) + eps) + if wd > 0: + W = W - self.step_size * wd * W + return W, m, v + + # Loop + for it in range(self.max_iter): + g = grad_fn(W) + if self.opt_name == "sgd": + W = W - self.step_size * g + else: + W, m, v = step_adam(W, m, v, g, it + 1) + + # proximal L1 + if self.alpha_l1 > 0: + W = _soft_threshold(W, self.alpha_l1 * self.step_size) + # hard threshold + if self.threshold > 0: + kept = _hard_threshold(W, self.threshold) + if sparse_mask is None: + W = kept + else: + W = jnp.where(sparse_mask, kept, W) + + # evaluate objective + obj = float(loss_fn(W)) + if best_obj is None or obj < (best_obj - self.min_delta): + best_obj = obj + best_W = W + patience_counter = 0 + else: + patience_counter += 1 + + # track history and early stop + coef_np = np.array(W) + self.history_.append(coef_np) + mask = np.abs(coef_np) >= max(self.threshold, self.stability_eps) + if last_mask is not None and np.array_equal(mask, last_mask): + break + last_mask = mask + if 0 < self.early_stopping_patience <= patience_counter: + break + + if self.verbose and ( + it % max(1, self.max_iter // 10) == 0 or it == self.max_iter - 1 + ): + mse_val = float(np.sum((np.array(X @ W.T - Y)) ** 2) / n_samples) + l0 = int(np.sum(np.abs(coef_np) >= self.threshold)) + print(f"[JaxSINDy] iter={it} mse={mse_val:.4e} L0={l0} obj={obj:.4e}") + + final_W = np.array(best_W if best_W is not None else W) + self.coef_ = final_W + self.ind_ = np.abs(self.coef_) > self.stability_eps + + @property + def complexity(self): + return np.count_nonzero(self.coef_) + np.count_nonzero(self.intercept_) diff --git a/test/test_optimizers/test_jax_optimizer.py b/test/test_optimizers/test_jax_optimizer.py new file mode 100644 index 00000000..88ff8edc --- /dev/null +++ b/test/test_optimizers/test_jax_optimizer.py @@ -0,0 +1,131 @@ +import numpy as np +import pytest + +from pysindy import SINDy +from pysindy.feature_library import PolynomialLibrary + +jax = pytest.importorskip("jax") + +jax_available = True +try: + from pysindy.optimizers import JaxOptimizer +except Exception: + jax_available = False + + +def make_synthetic(n_samples=200, noise=0.0, seed=0): + rng = np.random.default_rng(seed) + X = rng.normal(size=(n_samples, 3)) + W = np.array([[1.0, 0.0, -0.5], 
[0.0, 2.0, 0.0], [0.0, 0.0, 0.0]]) + Y = X @ W.T + noise * rng.normal(size=(n_samples, 3)) + return X, Y + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_basic_fit_shapes(): + X, Y = make_synthetic() + opt = JaxOptimizer(max_iter=50, threshold=1e-2, alpha_l1=1e-3, seed=1) + opt.fit(X, Y) + assert opt.coef_.shape == (Y.shape[1], X.shape[1]) + assert opt.ind_.shape == (Y.shape[1], X.shape[1]) + assert len(opt.history_) >= 1 + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_unbias_and_sparsity(): + X, Y = make_synthetic(noise=0.01) + opt = JaxOptimizer(max_iter=100, threshold=0.05, alpha_l1=1e-3, seed=2, unbias=True) + opt.fit(X, Y) + assert np.count_nonzero(opt.coef_) < opt.coef_.size + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_multi_target_with_sindylib(): + t = np.linspace(0, 1, 200) + x = np.stack( + [np.sin(2 * np.pi * t), np.cos(2 * np.pi * t), 0.5 * np.sin(4 * np.pi * t)], + axis=1, + ) + lib = PolynomialLibrary(degree=2) + opt = JaxOptimizer(max_iter=50, threshold=1e-2, alpha_l1=1e-3, seed=0) + model = SINDy(optimizer=opt, feature_library=lib) + model.fit(x, t=t[1] - t[0]) + score = model.score(x, t=t[1] - t[0]) + assert score > 0.8 + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +@pytest.mark.parametrize("opt_name", ["sgd", "adam", "adamw"]) +def test_optimizer_variants_run(opt_name): + X, Y = make_synthetic(n_samples=100, noise=0.01, seed=3) + opt = JaxOptimizer( + optimizer=opt_name, max_iter=30, threshold=1e-3, alpha_l1=1e-4, seed=1 + ) + opt.fit(X, Y) + assert opt.coef_.shape == (Y.shape[1], X.shape[1]) + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_early_stopping_via_patience_and_min_delta(): + X, Y = make_synthetic(n_samples=150, noise=0.01, seed=4) + patience = 2 + opt = JaxOptimizer( + max_iter=200, + threshold=0.0, + alpha_l1=0.0, + seed=0, + early_stopping_patience=patience, + min_delta=1e9, + ) + opt.fit(X, Y) + assert len(opt.history_) <= 1 + patience + 1 + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_complexity_property_matches_manual(): + X, Y = make_synthetic(n_samples=120, noise=0.02, seed=6) + opt = JaxOptimizer(max_iter=60, threshold=5e-2, alpha_l1=1e-3, seed=6) + opt.fit(X, Y) + manual = np.count_nonzero(opt.coef_) + np.count_nonzero(opt.intercept_) + assert opt.complexity == manual + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_verbose_prints_progress(capsys): + X, Y = make_synthetic(n_samples=60, noise=0.0, seed=7) + opt = JaxOptimizer(max_iter=5, threshold=1e-4, alpha_l1=0.0, verbose=True, seed=0) + opt.fit(X, Y) + captured = capsys.readouterr().out + assert "[JaxSINDy]" in captured and "iter=" in captured + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_history_tracking_shapes_and_length(): + X, Y = make_synthetic(n_samples=80, noise=0.0, seed=8) + init = np.zeros((Y.shape[1], X.shape[1])) + opt = JaxOptimizer( + max_iter=8, threshold=1e-4, alpha_l1=1e-4, seed=0, initial_guess=init + ) + opt.fit(X, Y) + assert isinstance(opt.history_, list) + assert len(opt.history_) >= 2 + shapes_ok = all(h.shape == opt.coef_.shape for h in opt.history_) + assert shapes_ok + assert not np.allclose(opt.history_[0], opt.history_[-1]) + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_sparse_ind_hard_thresholding_effect(): + X, Y = make_synthetic(n_samples=120, noise=0.0, seed=42) + thr = 
1e9 + opt = JaxOptimizer( + max_iter=1, + step_size=1e-3, + threshold=thr, + alpha_l1=0.0, + seed=0, + sparse_ind=[0], + ) + opt.fit(X, Y) + assert np.allclose(opt.coef_[:, 0], 0.0) + assert np.any(np.abs(opt.coef_[:, 1:]) > 0.0) diff --git a/test/test_optimizers/test_torch_optimizer.py b/test/test_optimizers/test_torch_optimizer.py index 6fe85477..f8a32c9a 100644 --- a/test/test_optimizers/test_torch_optimizer.py +++ b/test/test_optimizers/test_torch_optimizer.py @@ -104,19 +104,6 @@ def test_early_stopping_via_patience_and_min_delta(): assert len(opt.history_) <= 1 + patience + 1 # a few iterations at most -def test_cuda_device_selection_warning_or_success(): - X, Y = make_synthetic(n_samples=50, noise=0.0, seed=5) - if not torch.cuda.is_available(): - with pytest.warns(UserWarning, match="CUDA not available; falling back to CPU"): - opt = TorchOptimizer(device="cuda", max_iter=5, threshold=0.0, seed=0) - opt.fit(X, Y) - assert opt.coef_.shape == (Y.shape[1], X.shape[1]) - else: - opt = TorchOptimizer(device="cuda", max_iter=10, threshold=1e-4, seed=0) - opt.fit(X, Y) - assert opt.coef_.shape == (Y.shape[1], X.shape[1]) - - def test_complexity_property_matches_manual(): X, Y = make_synthetic(n_samples=120, noise=0.02, seed=6) opt = TorchOptimizer(max_iter=60, threshold=5e-2, alpha_l1=1e-3, seed=6) @@ -149,3 +136,16 @@ def test_history_tracking_shapes_and_length(): assert shapes_ok # The first history entry (initial guess) should differ from a later iterate assert not np.allclose(opt.history_[0], opt.history_[-1]) + + +def test_device_selection_warning_or_success(): + X, Y = make_synthetic(n_samples=50, noise=0.0, seed=5) + if not torch.cuda.is_available(): + with pytest.warns(UserWarning, match="CUDA not available; falling back to CPU"): + opt = TorchOptimizer(device="cuda", max_iter=5, threshold=0.0, seed=0) + opt.fit(X, Y) + assert opt.coef_.shape == (Y.shape[1], X.shape[1]) + else: + opt = TorchOptimizer(device="cuda", max_iter=10, threshold=1e-4, seed=0) + opt.fit(X, Y) + assert opt.coef_.shape == (Y.shape[1], X.shape[1])