diff --git a/examples/benchmarks/benchmark.py b/examples/benchmarks/benchmark.py new file mode 100644 index 00000000..75652835 --- /dev/null +++ b/examples/benchmarks/benchmark.py @@ -0,0 +1,453 @@ +#!/usr/bin/env python3 +""" +Generalized benchmark for PySINDy optimizers on multiple nonlinear systems. + +Features +- Unified runner for many ODE systems (Lorenz, Rossler, Van der Pol, Duffing, Chua, + Pendulum, Kuramoto, Hindmarsh–Rose (reduced), FitzHugh–Nagumo, Thomas, Sprott A, + and a nonlinear spring). +- Simple RK4 integrator (no external dependencies). +- Evaluates available optimizers (STLSQ, SR3, FROLS, SSR, optional SBR, optional TorchOptimizer). +- Reports runtime, score, MSE against analytic RHS, complexity, and discovered equations. +- CLI flags to choose system, integration params, library degree, and optimizers. +- Supports running all systems in a single invocation with --system all. + +Usage examples +- Run Lorenz with defaults: + python examples/benchmarks/benchmark.py --system lorenz --dt 0.01 --t1 10 +- List systems: + python examples/benchmarks/benchmark.py --list +- Run all systems on all solvers: + python examples/benchmarks/benchmark.py --system all --dt 0.01 --t1 5 +""" +import argparse +import time +import traceback +import warnings +from typing import Callable +from typing import Dict +from typing import List +from typing import Tuple + +import numpy as np + +from pysindy import ConstrainedSR3 +from pysindy import SINDy +from pysindy import StableLinearSR3 +from pysindy.feature_library import PolynomialLibrary +from pysindy.optimizers import FROLS +from pysindy.optimizers import SR3 +from pysindy.optimizers import SSR +from pysindy.optimizers import STLSQ + +try: + from pysindy.optimizers import SBR +except Exception: + SBR = None # type: ignore +try: + from pysindy.optimizers import TorchOptimizer +except Exception: + TorchOptimizer = None # type: ignore +try: + from pysindy.optimizers import JaxOptimizer +except Exception: + JaxOptimizer = None # type: ignore + + +# ------------------------------- Systems ------------------------------------ + + +def lorenz_rhs(t, x, sigma=10.0, beta=8.0 / 3.0, rho=28.0): + u, v, w = x + return np.array( + [ + -sigma * (u - v), + rho * u - v - u * w, + -beta * w + u * v, + ], + dtype=float, + ) + + +def rossler_rhs(t, x, a=0.2, b=0.2, c=5.7): + x1, x2, x3 = x + return np.array( + [ + -(x2 + x3), + x1 + a * x2, + b + x3 * (x1 - c), + ], + dtype=float, + ) + + +def vdp_rhs(t, x, mu=3.0): + x1, x2 = x + return np.array( + [ + x2, + mu * (1 - x1**2) * x2 - x1, + ], + dtype=float, + ) + + +def duffing_rhs(t, x, delta=0.2, alpha=-1.0, beta=1.0, gamma=0.3, omega=1.2): + # Unforced variant (set gamma=0) or keep small forcing + x1, x2 = x + return np.array( + [ + x2, + -delta * x2 - alpha * x1 - beta * x1**3 + gamma * np.cos(omega * t), + ], + dtype=float, + ) + + +def chua_rhs(t, x, alpha=15.6, beta=28.0, m0=-1.143, m1=-0.714): + x1, x2, x3 = x + h = m1 * x1 + 0.5 * (m0 - m1) * (np.abs(x1 + 1) - np.abs(x1 - 1)) + return np.array( + [ + alpha * (x2 - x1 - h), + x1 - x2 + x3, + -beta * x2, + ], + dtype=float, + ) + + +def pendulum_rhs(t, x, g=9.81, L=1.0, q=0.0, F=0.0, omega_d=0.0): + # Damped/forced nonlinear pendulum; default undamped, unforced + theta, omega = x + return np.array( + [ + omega, + -(g / L) * np.sin(theta) - q * omega + F * np.sin(omega_d * t), + ], + dtype=float, + ) + + +def kuramoto_rhs(t, x, K=0.5): + # Small network (3 oscillators) with identical natural freq=0 + # x are phases; coupling via sine differences + n = 
x.shape[0] + dx = np.zeros(n) + for i in range(n): + dx[i] = (K / n) * np.sum(np.sin(x - x[i])) + return dx + + +def hindmarsh_rose_rhs(t, x): + # Reduced 2D form for benchmarking + x1, x2 = x + a = 1.0 + b = 3.0 + c = 1.0 + d = 5.0 + return np.array( + [ + x2 - a * x1**3 + b * x1**2 + 1.0, # simplified variant + c - d * x1**2 - x2, # simplified variant + ], + dtype=float, + ) + + +def fitzhugh_nagumo_rhs(t, x, a=0.7, b=0.8, tau=12.5, Ia=0.5): + v, w = x + return np.array( + [ + v - v**3 / 3 - w + Ia, + (v + a - b * w) / tau, + ], + dtype=float, + ) + + +def thomas_rhs(t, x, b=0.208186): + x1, x2, x3 = x + return np.array( + [ + np.sin(x2) - b * x1, + np.sin(x3) - b * x2, + np.sin(x1) - b * x3, + ], + dtype=float, + ) + + +def sprott_a_rhs(t, x): + x1, x2, x3 = x + return np.array( + [ + x2, + x3, + -x1 + x2**2 - x3, + ], + dtype=float, + ) + + +def myspring_rhs(t, x, k=-4.518, c=0.372, F0=9.123): + return np.array([x[1], k * x[0] - c * x[1] + F0 * np.sin(x[0] ** 2)], dtype=float) + + +SYSTEMS: Dict[str, Tuple[Callable, np.ndarray]] = { + "lorenz": (lorenz_rhs, np.array([1.0, 1.0, 1.0], dtype=float)), + "rossler": (rossler_rhs, np.array([0.1, 0.1, 0.1], dtype=float)), + "vanderpol": (vdp_rhs, np.array([2.0, 0.0], dtype=float)), + "duffing": (duffing_rhs, np.array([0.1, 0.0], dtype=float)), + "chua": (chua_rhs, np.array([0.1, 0.0, 0.0], dtype=float)), + "pendulum": (pendulum_rhs, np.array([0.5, 0.0], dtype=float)), + "kuramoto3": (kuramoto_rhs, np.array([0.2, -0.3, 0.1], dtype=float)), + "hindmarsh_rose2": (hindmarsh_rose_rhs, np.array([0.0, 0.0], dtype=float)), + "fitzhugh_nagumo": (fitzhugh_nagumo_rhs, np.array([0.0, 0.0], dtype=float)), + "thomas": (thomas_rhs, np.array([0.1, 0.2, 0.3], dtype=float)), + "sprott_a": (sprott_a_rhs, np.array([0.1, 0.1, 0.1], dtype=float)), + "myspring": (myspring_rhs, np.array([0.4, 1.6], dtype=float)), +} + + +def rk4( + f: Callable, x0: np.ndarray, t0: float, t1: float, dt: float +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Integrate an ODE x' = f(t, x) using classical RK4. + + Parameters + - f: callable(t, x) -> np.ndarray, RHS function. + - x0: initial state vector. + - t0: start time. + - t1: end time. + - dt: time step. + + Returns + - t: time array of shape (N,). + - X: state trajectory of shape (N, d). + - dX_dt: analytic RHS evaluated along trajectory, shape (N, d). + """ + t = np.linspace(t0, t1, int(round((t1 - t0) / dt)) + 1) + X = np.zeros((t.size, x0.size), dtype=float) + X[0, :] = np.array(x0, dtype=float) + for i in range(1, t.size): + ti = t[i - 1] + xi = X[i - 1, :] + k1 = f(ti, xi) + k2 = f(ti + dt / 2.0, xi + dt * k1 / 2.0) + k3 = f(ti + dt / 2.0, xi + dt * k2 / 2.0) + k4 = f(ti + dt, xi + dt * k3) + X[i, :] = xi + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4) + dX_dt = np.vstack([f(tt, xx) for tt, xx in zip(t, X)]) + return t, X, dX_dt + + +def build_library(degree: int = 3) -> PolynomialLibrary: + """Construct a polynomial feature library. + + Parameters + - degree: maximum polynomial degree; interactions enabled. + + Returns + - PolynomialLibrary instance. + """ + return PolynomialLibrary(degree=degree, include_interaction=True) + + +def run_optimizer( + name: str, + optimizer, + X: np.ndarray, + dt: float, + library: PolynomialLibrary, + dX_dt: np.ndarray, +): + """Fit a SINDy model and compute metrics for an optimizer. + + Returns a dict with name, runtime, score, MSE vs analytic RHS, complexity, + equations (if available), and the fitted model. 
+ """ + model = SINDy(optimizer=optimizer, feature_library=library) + t0 = time.perf_counter() + model.fit(X, t=dt) + fit_time = time.perf_counter() - t0 + score = model.score(X, t=dt) + dXdt_pred = model.predict(X) + mse = float(np.mean((dXdt_pred - dX_dt) ** 2)) + complexity = optimizer.complexity if hasattr(optimizer, "complexity") else None + try: + equations = model.equations() + except Exception: + equations = None + return { + "name": name, + "fit_time_s": fit_time, + "score": float(score), + "mse": mse, + "complexity": int(complexity) if complexity is not None else None, + "equations": equations, + "model": model, + } + + +def main(): + """Entry point for the generalized benchmark runner. + + Parses CLI, builds the library, iterates over selected systems, runs selected + optimizers, prints a summary table, and highlights the best optimizer per system + (lowest MSE), including its discovered equations. + """ + warnings.filterwarnings("ignore") + parser = argparse.ArgumentParser( + description="Generalized nonlinear systems benchmark for PySINDy." + ) + parser.add_argument( + "--system", + type=str, + default="all", + choices=sorted(list(SYSTEMS.keys()) + ["all"]), + ) + parser.add_argument("--t0", type=float, default=0.0) + parser.add_argument("--t1", type=float, default=10.0) + parser.add_argument("--dt", type=float, default=0.01) + parser.add_argument( + "--degree", type=int, default=3, help="Polynomial library degree" + ) + parser.add_argument( + "--optimizers", type=str, default="all", help="Comma-separated list or 'all'" + ) + parser.add_argument( + "--list", action="store_true", help="List available systems and exit" + ) + args = parser.parse_args() + + if args.list: + print("Available systems:") + for name in sorted(SYSTEMS.keys()): + print(f" - {name}") + print(" - all (run every system)") + return + + systems_to_run = ( + [(args.system, SYSTEMS[args.system])] + if args.system != "all" + else list(SYSTEMS.items()) + ) + library = build_library(args.degree) + + # Select optimizers + opt_defs: List[Tuple[str, object]] = [] + opt_defs.append(("STLSQ", STLSQ())) + opt_defs.append( + ( + "SR3-L0", + SR3(), + ) + ) + opt_defs.append( + ( + "SR3-constrained", + ConstrainedSR3(), + ) + ) + opt_defs.append( + ( + "SR3-stable", + StableLinearSR3(), + ) + ) + opt_defs.append(("FROLS", FROLS())) + opt_defs.append(("SSR", SSR())) + if SBR is not None: + try: + opt_defs.append(("SBR", SBR())) + except Exception: + traceback.print_exc() + if TorchOptimizer is not None: + try: + opt_defs.append( + ( + "TorchOptimizer", + TorchOptimizer( + threshold=0.05, + alpha_l1=1e-3, + step_size=1e-2, + max_iter=200, + optimizer="cadamw", + seed=0, + unbias=True, + early_stopping_patience=50, + min_delta=1e-8, + ), + ) + ) + except Exception: + traceback.print_exc() + if JaxOptimizer is not None: + try: + opt_defs.append( + ( + "JaxOptimizer", + JaxOptimizer( + threshold=0.05, + alpha_l1=1e-3, + step_size=1e-2, + max_iter=200, + optimizer="adam", + seed=0, + unbias=True, + early_stopping_patience=50, + min_delta=1e-8, + ), + ) + ) + except Exception: + traceback.print_exc() + + if args.optimizers != "all": + names = {n.strip() for n in args.optimizers.split(",")} + opt_defs = [pair for pair in opt_defs if pair[0] in names] + + # Run per system + for sys_name, (rhs, x0) in systems_to_run: + print(f"\n=== System: {sys_name} ===") + t, X, dX_dt = rk4(rhs, x0, args.t0, args.t1, args.dt) + results = [] + for name, opt in opt_defs: + try: + res = run_optimizer(name, opt, X, args.dt, library, dX_dt) + except 
Exception as e: + res = {"name": name, "error": str(e)} + results.append(res) + + header = f"{'Optimizer':<15} {'Score':>10} {'MSE':>12} {'Time (s)':>12} {'Complexity':>12}" + print(header) + print("-" * len(header)) + for r in results: + if "error" in r: + print(f"{r['name']:<15} ERROR: {r['error']}") + else: + print( + f"{r['name']:<15} {r['score']:>10.4f} {r['mse']:>12.4e} {r['fit_time_s']:>12.4f} {str(r['complexity']):>12}" + ) + + # Select and print best optimizer by lowest MSE among successful runs + successful = [r for r in results if "error" not in r] + if successful: + best = min(successful, key=lambda r: r["mse"]) # type: ignore + print( + f"\n>>> Best optimizer: {best['name']} | Score={best['score']:.4f} | MSE={best['mse']:.4e} | Time={best['fit_time_s']:.4f}s | Complexity={best['complexity']}" + ) + eqs = best.get("equations") + if eqs: + print("Discovered equations:") + for eq in eqs: + print(f" {eq}") + else: + print("(equations unavailable)") + else: + print("\n>>> No successful optimizer runs for this system.") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 4ff50ec1..97989bb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,8 @@ dev = [ "jupytext", "pre-commit", "hypothesis", - "jupyter-contrib-nbextensions" + "jupyter-contrib-nbextensions", + "torch", ] docs = [ "ipython", diff --git a/pysindy/optimizers/__init__.py b/pysindy/optimizers/__init__.py index d4346b4f..f7e2af64 100644 --- a/pysindy/optimizers/__init__.py +++ b/pysindy/optimizers/__init__.py @@ -30,6 +30,14 @@ from .sbr import SBR except (ImportError, NameError): pass +try: + from .torch_solver import TorchOptimizer +except Exception: + pass +try: + from .jax_solver import JaxOptimizer +except Exception: + pass __all__ = [ @@ -46,4 +54,6 @@ "SINDyPI", "MIOSR", "SBR", + "TorchOptimizer", + "JaxOptimizer", ] diff --git a/pysindy/optimizers/jax_solver.py b/pysindy/optimizers/jax_solver.py new file mode 100644 index 00000000..29a57cc5 --- /dev/null +++ b/pysindy/optimizers/jax_solver.py @@ -0,0 +1,294 @@ +""" +JAX-based SINDy optimizer using proximal gradient + iterative hard-thresholding. + +This module provides a high-performance optimizer implemented with JAX to solve +Sparse Identification of Nonlinear Dynamics (SINDy) regression problems. It +minimizes a smooth data-fit term and applies proximal/thresholding operations to +promote sparsity in the discovered dynamical system. + +Key features +- Batched multi-target regression on CPU or GPU/TPU (if available). +- Proximal L1 shrinkage and hard thresholding to encourage sparse models. +- Optimizers: SGD, Adam, AdamW with minimal implementations in JAX. +- Note: The CAdamW variant is currently not available in this implementation. +- Best-solution tracking across iterations and early-stopping support. +- Compatible with PySINDy BaseOptimizer interface and ensembling. + +Optional dependencies +- JAX is optional at import time; an ImportError will be raised during fit if + jax is not available. Code paths and annotations avoid import-time failures. + +Notes +----- +- Thresholding and proximal operations operate on coefficient magnitudes. A small + numerical threshold (1e-14) is used to derive support masks for `ind_`. +- When `sparse_ind` is provided, thresholding affects only the specified columns. +- Early stopping halts iterations when the objective fails to improve by at least + `min_delta` for `early_stopping_patience` consecutive steps. 
+- The optimizer tracks and restores the best solution observed across iterations.
+"""
+import warnings
+from typing import List
+from typing import Optional
+
+import numpy as np
+
+from .base import BaseOptimizer
+
+try:
+    import jax  # type: ignore
+    import jax.numpy as jnp  # type: ignore
+except Exception:  # pragma: no cover - optional dependency
+    jax = None  # type: ignore
+    jnp = None  # type: ignore
+
+
+def _soft_threshold(t, lam: float):
+    if lam <= 0:
+        return t
+    return jnp.sign(t) * jnp.maximum(jnp.abs(t) - lam, 0.0)
+
+
+def _hard_threshold(t, thr: float):
+    if thr <= 0:
+        return t
+    return t * (jnp.abs(t) >= thr)
+
+
+class JaxOptimizer(BaseOptimizer):
+    """JAX-powered optimizer for sparse SINDy regression.
+
+    Objective
+        J(W) = (1/N) * ||Y - X W^T||_F^2 + alpha_l1 * ||W||_1
+
+    Parameters
+    ----------
+    threshold : float, default 1e-1
+        Minimum magnitude for a coefficient. Values with |coef| < threshold
+        are set to zero after each iteration (hard-threshold).
+    alpha_l1 : float, default 0.0
+        L1 penalty weight. If > 0, soft-thresholding is applied after the
+        gradient step to shrink coefficients.
+    step_size : float, default 1e-1
+        Learning rate for the gradient updates on the smooth part of the
+        objective.
+    max_iter : int, default 1000
+        Maximum number of iterations.
+    optimizer : {"sgd", "adam", "adamw"}, default "adam"
+        Which optimizer to use for the smooth part of the objective.
+        (The CAdamW variant of the Torch solver is not available here.)
+    normalize_columns : bool, default False
+        See BaseOptimizer; if True, columns of X are normalized before fitting.
+    copy_X : bool, default True
+        See BaseOptimizer; controls whether X is copied or may be overwritten.
+    initial_guess : np.ndarray or None, default None
+        Warm-start coefficients; shape (n_targets, n_features).
+    verbose : bool, default False
+        If True, prints periodic progress including loss and sparsity.
+    seed : int or None, default None
+        Random seed; used to seed NumPy for reproducibility.
+    sparse_ind : list[int] or None, default None
+        If provided, thresholding only applies to these feature indices. Other
+        indices remain unaffected by hard-thresholding.
+    unbias : bool, default True
+        See BaseOptimizer; when True, performs an unbiased refit on the selected
+        support after optimization.
+    early_stopping_patience : int, default 100
+        If > 0, stop early when the objective has not improved by at least
+        `min_delta` for this many consecutive iterations.
+    min_delta : float, default 1e-10
+        Minimum improvement to reset patience; small positive values help
+        prevent stopping on floating-point noise.
+
+    Attributes
+    ----------
+    coef_ : np.ndarray, shape (n_targets, n_features)
+        Optimized SINDy coefficients.
+    history_ : list of np.ndarray
+        Coefficient history; `history_[k]` is the coefficient matrix after
+        iteration k.
+    ind_ : np.ndarray, shape (n_targets, n_features)
+        Boolean mask of nonzero coefficients (|coef| > 1e-14).
+ Examples + -------- + >>> import numpy as np + >>> from scipy.integrate import odeint + >>> from pysindy import SINDy + >>> from pysindy.optimizers import JaxOptimizer + >>> lorenz = lambda z,t : [10 * (z[1] - z[0]), + >>> z[0] * (28 - z[2]) - z[1], + >>> z[0] * z[1] - 8 / 3 * z[2]] + >>> t = np.arange(0, 2, .002) + >>> x = odeint(lorenz, [-8, 8, 27], t) + >>> opt = JaxOptimizer(threshold=.1, alpha_l1=.01, max_iter=1000) + >>> model = SINDy(optimizer=opt) + >>> model.fit(x, t=t[1] - t[0]) + >>> model.print() + + x0' = -9.973 x0 + 9.973 x1 + x1' = -0.129 1 + 27.739 x0 + -0.949 x1 + -0.993 x0 x2 + x2' = -2.656 x2 + 0.996 x0 x1 + """ + + def __init__( + self, + threshold: float = 1e-1, + alpha_l1: float = 0.0, + step_size: float = 1e-1, + max_iter: int = 1000, + optimizer: str = "adam", + normalize_columns: bool = False, + copy_X: bool = True, + initial_guess: Optional[np.ndarray] = None, + verbose: bool = False, + seed: Optional[int] = None, + sparse_ind: Optional[List[int]] = None, + unbias: bool = True, + early_stopping_patience: int = 100, + min_delta: float = 1e-10, + ): + super().__init__( + max_iter=max_iter, + normalize_columns=normalize_columns, + initial_guess=initial_guess, + copy_X=copy_X, + unbias=unbias, + ) + if threshold < 0: + raise ValueError("threshold cannot be negative") + if alpha_l1 < 0: + raise ValueError("alpha_l1 cannot be negative") + if step_size <= 0: + raise ValueError("step_size must be positive") + if optimizer not in ("sgd", "adam", "adamw"): + raise ValueError("optimizer must be 'sgd', 'adam', or 'adamw'") + if early_stopping_patience < 0: + raise ValueError("early_stopping_patience cannot be negative") + if min_delta < 0: + raise ValueError("min_delta cannot be negative") + self.threshold = float(threshold) + self.alpha_l1 = float(alpha_l1) + self.step_size = float(step_size) + self.verbose = bool(verbose) + self.seed = seed + self.opt_name = optimizer + self.sparse_ind = sparse_ind + self.early_stopping_patience = int(early_stopping_patience) + self.min_delta = float(min_delta) + self.stability_eps = 1e-14 + + if jax is None: + warnings.warn( + "JAX is not installed; " + "JaxOptimizer will not run until jax is available." + ) + + def _reduce(self, x: np.ndarray, y: np.ndarray, **kwargs) -> None: + if jax is None: + raise ImportError("JAX is required for JaxOptimizer. Please install jax.") + + # Seed control + if self.seed is not None: + # jax uses PRNG keys. 
+ # We just seed numpy for determinism in thresholds/history + np.random.seed(self.seed) + + X = jnp.asarray(x, dtype=jnp.float64) + Y = jnp.asarray(y, dtype=jnp.float64) + n_samples, n_features = X.shape + n_targets = Y.shape[1] + + if self.coef_ is None: + W = jnp.zeros((n_targets, n_features), dtype=jnp.float64) + else: + W = jnp.asarray(self.coef_, dtype=jnp.float64) + + # sparse mask + sparse_mask = None + if self.sparse_ind is not None: + sparse_mask = jnp.zeros((n_targets, n_features), dtype=bool) + sparse_mask = sparse_mask.at[:, self.sparse_ind].set(True) + + def loss_fn(W_): + Y_pred = X @ W_.T + residual = Y_pred - Y + mse = jnp.sum(residual**2) / n_samples + l1 = self.alpha_l1 * jnp.sum(jnp.abs(W_)) if self.alpha_l1 > 0 else 0.0 + return mse + l1 + + grad_fn = jax.grad(lambda W_: loss_fn(W_)) + + # Optimizer states for Adam/AdamW + m = jnp.zeros_like(W) + v = jnp.zeros_like(W) + beta1, beta2, eps, wd = 0.9, 0.999, 1e-8, 0.0 + if self.opt_name == "adamw": + wd = 1e-4 # small default weight decay + + last_mask = None + best_obj = None + best_W = None + patience_counter = 0 + + def step_adam(W, m, v, g, t): + m = beta1 * m + (1 - beta1) * g + v = beta2 * v + (1 - beta2) * (g * g) + m_hat = m / (1 - beta1**t) + v_hat = v / (1 - beta2**t) + W = W - self.step_size * m_hat / (jnp.sqrt(v_hat) + eps) + if wd > 0: + W = W - self.step_size * wd * W + return W, m, v + + # Loop + for it in range(self.max_iter): + g = grad_fn(W) + if self.opt_name == "sgd": + W = W - self.step_size * g + else: + W, m, v = step_adam(W, m, v, g, it + 1) + + # proximal L1 + if self.alpha_l1 > 0: + W = _soft_threshold(W, self.alpha_l1 * self.step_size) + # hard threshold + if self.threshold > 0: + kept = _hard_threshold(W, self.threshold) + if sparse_mask is None: + W = kept + else: + W = jnp.where(sparse_mask, kept, W) + + # evaluate objective + obj = float(loss_fn(W)) + if best_obj is None or obj < (best_obj - self.min_delta): + best_obj = obj + best_W = W + patience_counter = 0 + else: + patience_counter += 1 + + # track history and early stop + coef_np = np.array(W) + self.history_.append(coef_np) + mask = np.abs(coef_np) >= max(self.threshold, self.stability_eps) + if last_mask is not None and np.array_equal(mask, last_mask): + break + last_mask = mask + if 0 < self.early_stopping_patience <= patience_counter: + break + + if self.verbose and ( + it % max(1, self.max_iter // 10) == 0 or it == self.max_iter - 1 + ): + mse_val = float(np.sum((np.array(X @ W.T - Y)) ** 2) / n_samples) + l0 = int(np.sum(np.abs(coef_np) >= self.threshold)) + print(f"[JaxSINDy] iter={it} mse={mse_val:.4e} L0={l0} obj={obj:.4e}") + + final_W = np.array(best_W if best_W is not None else W) + self.coef_ = final_W + self.ind_ = np.abs(self.coef_) > self.stability_eps + + @property + def complexity(self): + return np.count_nonzero(self.coef_) + np.count_nonzero(self.intercept_) diff --git a/pysindy/optimizers/torch_solver.py b/pysindy/optimizers/torch_solver.py new file mode 100644 index 00000000..e5177abb --- /dev/null +++ b/pysindy/optimizers/torch_solver.py @@ -0,0 +1,511 @@ +""" +PyTorch-based SINDy optimizer using proximal gradient + iterative hard-thresholding. + +This module provides a high-performance optimizer implemented with PyTorch to solve +Sparse Identification of Nonlinear Dynamics (SINDy) regression problems. It +minimizes a smooth data-fit term and applies proximal/thresholding operations to +promote sparsity in the discovered dynamical system. 
+
+Key features
+- Batched multi-target regression on CPU or GPU (if available).
+- Proximal L1 shrinkage and hard thresholding to encourage sparse models.
+- Optional optimizers: SGD, Adam, AdamW, and custom CAdamW.
+- Best-solution tracking across iterations and early-stopping support.
+- Compatible with PySINDy BaseOptimizer interface and ensembling.
+
+Optional dependencies
+- PyTorch is required by this module. The import in ``pysindy.optimizers`` is
+  wrapped in try/except, so the rest of PySINDy remains importable when torch
+  is not installed.
+- CAdamW is a custom optimizer shipped with the project and used when requested.
+
+Usage
+-----
+Example: Fit a SINDy model with the Torch optimizer.
+
+    >>> import numpy as np
+    >>> from pysindy import SINDy
+    >>> from pysindy.feature_library import PolynomialLibrary
+    >>> from pysindy.optimizers.torch_solver import TorchOptimizer
+    >>> rng = np.random.default_rng(0)
+    >>> X = rng.standard_normal((500, 3))
+    >>> opt = TorchOptimizer()
+    >>> lib = PolynomialLibrary(degree=2)
+    >>> model = SINDy(optimizer=opt, feature_library=lib)
+    >>> model.fit(X, t=0.01)
+    >>> model.equations()  # doctest: +ELLIPSIS
+    [...]
+
+Notes
+-----
+- Thresholding and proximal operations operate on coefficient magnitudes. A small
+  numerical threshold (1e-14) is used to derive support masks for `ind_`.
+- When `sparse_ind` is provided, thresholding affects only the specified columns.
+- Early stopping halts iterations when the objective fails to improve by at least
+  `min_delta` for `early_stopping_patience` consecutive steps.
+- The optimizer tracks and restores the best solution observed across iterations.
+"""
+import math
+import warnings
+from typing import Callable
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+
+from .base import BaseOptimizer
+
+try:
+    import torch  # type: ignore
+    from torch import nn  # type: ignore
+    from torch.optim import Optimizer  # type: ignore
+except Exception:  # pragma: no cover - optional dependency
+    torch = None  # type: ignore
+    nn = None  # type: ignore
+    Optimizer = object  # type: ignore
+
+
+# Taken from https://github.com/kyleliang919/C-Optim/blob/main/c_adamw.py
+class CAdamW(Optimizer):
+    """
+    Implements the AdamW algorithm (Adam with decoupled weight decay, see
+    [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101))
+    with the cautious-update masking from C-Optim.
+
+    Parameters:
+        params (`Iterable[nn.parameter.Parameter]`):
+            Iterable of parameters to optimize or dictionaries
+            defining parameter groups.
+        lr (`float`, *optional*, defaults to 0.001):
+            The learning rate to use.
+        betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
+            Adam's betas parameters (b1, b2).
+        eps (`float`, *optional*, defaults to 1e-06):
+            Adam's epsilon for numerical stability.
+        weight_decay (`float`, *optional*, defaults to 0.0):
+            Decoupled weight decay to apply.
+        correct_bias (`bool`, *optional*, defaults to `True`):
+            Whether or not to correct bias in Adam
+            (for instance, in Bert TF repository they use `False`).
+ """ + + def __init__( + self, + params: Iterable[nn.parameter.Parameter], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + correct_bias: bool = True, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0") + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)" + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)" + ) + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") + defaults = { + "lr": lr, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "correct_bias": correct_bias, + } + super().__init__(params, defaults) + self.init_lr = lr + + @torch.no_grad() + def step(self, closure: Callable = None): + """ + Performs a single optimization step. + + Arguments: + closure (`Callable`, *optional*): + A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for i, p in enumerate(group["params"]): + if p.grad is None: + continue + + grad = p.grad + state = self.state[p] + + if "step" not in state: + state["step"] = 0 + + # State initialization + if "exp_avg" not in state: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(grad) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # apply weight decay + if group["weight_decay"] > 0.0: + p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + if group["correct_bias"]: # No bias correction for Bert + bias_correction1 = 1.0 - beta1 ** state["step"] + bias_correction2 = 1.0 - beta2 ** state["step"] + step_size = ( + step_size * math.sqrt(bias_correction2) / bias_correction1 + ) + + # compute norm gradient + mask = (exp_avg * grad > 0).to(grad.dtype) + # mask = mask * (mask.numel() / (mask.sum() + 1)) + # ## original implementation, leaving it here for record + mask.div_( + mask.mean().clamp_(min=1e-3) + ) # https://huggingface.co/rwightman/timm-optim-caution + # found this implementation is more favourable in many cases + norm_grad = (exp_avg * mask) / denom + p.add_(norm_grad, alpha=-step_size) + return loss + + +def _soft_threshold(t, lam: float): + """Soft-thresholding proximal operator. + + Applies element-wise soft-shrinkage: sign(t) * max(|t| - lam, 0). + + Parameters + - t: torch.Tensor with coefficients. + - lam: float threshold (lambda) controlling the shrinkage amount. + + Returns + - torch.Tensor of same shape as `t` after soft-thresholding. + """ + if lam <= 0: + return t + return torch.sign(t) * torch.clamp(torch.abs(t) - lam, min=0.0) + + +def _hard_threshold(t, thr: float): + """Hard-thresholding operator. + + Zeros out elements whose absolute value is below the given threshold. + + Parameters + - t: torch.Tensor with coefficients. + - thr: float magnitude threshold. 
+
+    Returns
+    - torch.Tensor of same shape as `t` with small entries set to zero.
+    """
+    if thr <= 0:
+        return t
+    return t * (torch.abs(t) >= thr)
+
+
+class TorchOptimizer(BaseOptimizer):
+    """Torch-powered optimizer for sparse SINDy regression.
+
+    This optimizer minimizes the objective
+
+        J(W) = (1/N) * ||Y - X W^T||_F^2 + alpha_l1 * ||W||_1
+
+    using gradient-based updates (SGD/Adam/AdamW/CAdamW), followed by proximal
+    soft-thresholding (L1) and hard-thresholding to encourage sparsity. It supports
+    multi-target regression (rows of W correspond to targets) and optional GPU
+    acceleration.
+
+    Parameters
+    ----------
+    threshold : float, default 1e-1
+        Minimum magnitude for a coefficient. Values with |coef| < threshold
+        are set to zero after each iteration (hard-threshold).
+    alpha_l1 : float, default 0.0
+        L1 penalty weight. If > 0, soft-thresholding is applied after the
+        gradient step to shrink coefficients.
+    step_size : float, default 1e-1
+        Learning rate for the chosen Torch optimizer.
+    max_iter : int, default 1000
+        Maximum number of iterations.
+    optimizer : {"sgd", "adam", "adamw", "cadamw"}, default "cadamw"
+        Which optimizer to use for the smooth part of the objective.
+    normalize_columns : bool, default False
+        See BaseOptimizer; if True, columns of X are normalized before fitting.
+    copy_X : bool, default True
+        See BaseOptimizer; controls whether X is copied or may be overwritten.
+    initial_guess : np.ndarray or None, default None
+        Warm-start coefficients; shape (n_targets, n_features).
+    verbose : bool, default False
+        If True, prints periodic progress including loss and sparsity.
+    device : {"cpu", "cuda"} or None, default None
+        Torch device to use. If None, uses CPU; if "cuda" is requested but not
+        available, falls back to CPU with a warning.
+    seed : int or None, default None
+        Random seed for reproducibility of Torch and NumPy.
+    sparse_ind : list[int] or None, default None
+        If provided, thresholding only applies to these feature indices. Other
+        indices remain unaffected by hard-thresholding.
+    unbias : bool, default True
+        See BaseOptimizer; when True, performs an unbiased refit on the selected
+        support after optimization.
+    early_stopping_patience : int, default 100
+        If > 0, stop early when the objective has not improved by at least
+        `min_delta` for this many consecutive iterations.
+    min_delta : float, default 1e-10
+        Minimum improvement to reset patience; small positive values help
+        prevent stopping on floating-point noise.
+
+    Attributes
+    ----------
+    coef_ : np.ndarray, shape (n_targets, n_features)
+        Final coefficients (best observed if early stopping is enabled).
+    ind_ : np.ndarray[bool], shape (n_targets, n_features)
+        Support mask where True indicates non-zero (above tiny threshold).
+    history_ : list[np.ndarray]
+        Sequence of coefficient matrices recorded at each iteration.
+    intercept_ : float or np.ndarray
+        Intercept term (handled by BaseOptimizer; not fit here).
+
+    Notes
+    -----
+    - Best-solution tracking: the optimizer records the coefficient state with
+      the lowest objective value and restores it at the end.
+    - Early-stopping: controlled by `early_stopping_patience` and `min_delta`.
+    - The objective combines mean-squared error and optional L1 penalty.
+    - For reproducible results, set `seed` and avoid non-deterministic CUDA ops.
+ """ + + def __init__( + self, + threshold: float = 1e-1, + alpha_l1: float = 0.0, + step_size: float = 1e-1, + max_iter: int = 1000, + optimizer: str = "cadamw", + normalize_columns: bool = False, + copy_X: bool = True, + initial_guess: Optional[np.ndarray] = None, + verbose: bool = False, + device: Optional[str] = None, + seed: Optional[int] = None, + sparse_ind: Optional[List[int]] = None, + unbias: bool = True, + early_stopping_patience: int = 100, + min_delta: float = 1e-10, + ): + super().__init__( + max_iter=max_iter, + normalize_columns=normalize_columns, + initial_guess=initial_guess, + copy_X=copy_X, + unbias=unbias, + ) + if threshold < 0: + raise ValueError("threshold cannot be negative") + if alpha_l1 < 0: + raise ValueError("alpha_l1 cannot be negative") + if step_size <= 0: + raise ValueError("step_size must be positive") + if optimizer not in ("sgd", "adam", "adamw", "cadamw"): + raise ValueError("optimizer must be 'sgd', 'adam', 'adamw', or 'cadamw'") + if early_stopping_patience < 0: + raise ValueError("early_stopping_patience cannot be negative") + if min_delta < 0: + raise ValueError("min_delta cannot be negative") + self.threshold = float(threshold) + self.alpha_l1 = float(alpha_l1) + self.step_size = float(step_size) + self.verbose = bool(verbose) + self.torch_device = device or "cpu" + self.seed = seed + self.opt_name = optimizer + self.sparse_ind = sparse_ind + self.early_stopping_patience = int(early_stopping_patience) + self.min_delta = float(min_delta) + self.stability_eps = 1e-14 + + if torch is None: + # Delay hard failure to fit-time to + # allow import of module without torch + warnings.warn( + "PyTorch is not installed; " + "TorchOptimizer will not run until torch is available." + ) + + def _reduce(self, x: np.ndarray, y: np.ndarray) -> None: + """Core optimization loop. + + This method performs up to `max_iter` iterations of gradient-based updates + on the smooth objective, followed by proximal soft-thresholding and hard + thresholding to enforce sparsity. It maintains a history of coefficients, + tracks the best objective value, and optionally stops early. + + Parameters + ---------- + x : np.ndarray, shape (n_samples, n_features) + Feature matrix (SINDy library output). Normalization may be applied + by the BaseOptimizer depending on constructor options. + y : np.ndarray, shape (n_samples, n_targets) + Target matrix (derivatives). + + Side effects + ------------ + - Updates `self.coef_` with the best observed coefficients. + - Updates `self.ind_` as a boolean support mask. + - Appends snapshots to `self.history_` each iteration. + + Raises + ------ + ImportError + If PyTorch is not installed at run time. + """ + if torch is None: + raise ImportError( + "PyTorch is required for TorchOptimizer. Please install torch." 
+ ) + # Select device + if self.torch_device == "cuda" and not torch.cuda.is_available(): + warnings.warn("CUDA not available; falling back to CPU.") + device = torch.device("cpu") + else: + device = torch.device(self.torch_device) + # Seed control + if self.seed is not None: + torch.manual_seed(self.seed) + np.random.seed(self.seed) + # Data to torch + dtype = torch.float64 # match numpy default precision + X = torch.as_tensor(x, dtype=dtype, device=device) + Y = torch.as_tensor(y, dtype=dtype, device=device) + n_samples, n_features = X.shape + n_targets = Y.shape[1] + + # Parameter tensor W: shape (n_targets, n_features) + if self.coef_ is None: + W = torch.zeros((n_targets, n_features), dtype=dtype, device=device) + else: + # Decouple from numpy memory to avoid mutating BaseOptimizer.history_[0] + W = torch.tensor(self.coef_, dtype=dtype, device=device).clone() + W.requires_grad_(True) + + # Optimizer for smooth loss + if self.opt_name == "adam": + opt = torch.optim.Adam([W], lr=self.step_size) + elif self.opt_name == "adamw": + opt = torch.optim.AdamW([W], lr=self.step_size) + elif self.opt_name == "cadamw": + opt = CAdamW([W], lr=self.step_size) + else: + opt = torch.optim.SGD([W], lr=self.step_size, momentum=0.9) + + # Support mask helper: restrict thresholding to specified indices + sparse_mask = None + if self.sparse_ind is not None: + sparse_mask = torch.zeros( + (n_targets, n_features), dtype=torch.bool, device=device + ) + sparse_mask[:, self.sparse_ind] = True + + def loss_fn(W_): + """Compute objective: MSE/N + alpha_l1 * ||W||_1.""" + Y_pred = X @ W_.T # (n_samples, n_targets) + residual = Y_pred - Y + mse = (residual.pow(2)).sum() / n_samples + if self.alpha_l1 > 0: + l1 = self.alpha_l1 * torch.abs(W_).sum() + else: + l1 = torch.zeros((), dtype=dtype, device=device) + return mse + l1 + + last_mask = None + best_obj = None + best_W = None + patience_counter = 0 + + # Simple gradient stepping just as you would do in learning a model on pytorch + for it in range(self.max_iter): + opt.zero_grad(set_to_none=True) + loss = loss_fn(W) + loss.backward() + # Gradient step via optimizer + opt.step() + # Prox for L1 + if self.alpha_l1 > 0: + with torch.no_grad(): + W[:] = _soft_threshold(W, self.alpha_l1 * self.step_size) + # Hard-threshold + with torch.no_grad(): + if self.threshold > 0: + if sparse_mask is None: + W[:] = _hard_threshold(W, self.threshold) + else: + kept = _hard_threshold(W, self.threshold) + W[:] = torch.where(sparse_mask, kept, W) + # Evaluate objective and update best + with torch.no_grad(): + obj = float(loss_fn(W).cpu().numpy()) + if best_obj is None or obj < (best_obj - self.min_delta): + best_obj = obj + best_W = W.detach().clone() + patience_counter = 0 + else: + patience_counter += 1 + # Track history and early stop if support unchanged or patience exceeded + with torch.no_grad(): + coef_np = W.detach().cpu().numpy() + self.history_.append(coef_np) + mask = np.abs(coef_np) >= max(self.threshold, self.stability_eps) + if last_mask is not None and np.array_equal(mask, last_mask): + break + last_mask = mask + if 0 < self.early_stopping_patience <= patience_counter: + break + if self.verbose and ( + it % max(1, self.max_iter // 10) == 0 or it == self.max_iter - 1 + ): + mse_val = ( + float(((X @ W.T - Y).pow(2)).sum().cpu().detach().numpy()) + / n_samples + ) + l0 = int((torch.abs(W) >= self.threshold).sum().item()) + print(f"[TorchSINDy] iter={it} mse={mse_val:.4e} L0={l0} obj={obj:.4e}") + + # Final coefficients back to numpy: use best if available + final_W = 
(best_W if best_W is not None else W).detach().cpu().numpy() + self.coef_ = final_W + # ind_ based on tiny threshold + self.ind_ = np.abs(self.coef_) > self.stability_eps + + @property + def complexity(self): + """Model complexity measure. + + Returns the number of non-zero coefficients plus the number of non-zero + intercepts (if any). This should allow comparing sparsity across optimizers. + + Returns + ------- + int + Complexity score: count_nonzero(coef_) + count_nonzero(intercept_). + """ + return np.count_nonzero(self.coef_) + np.count_nonzero(self.intercept_) diff --git a/test/test_optimizers/test_jax_optimizer.py b/test/test_optimizers/test_jax_optimizer.py new file mode 100644 index 00000000..88ff8edc --- /dev/null +++ b/test/test_optimizers/test_jax_optimizer.py @@ -0,0 +1,131 @@ +import numpy as np +import pytest + +from pysindy import SINDy +from pysindy.feature_library import PolynomialLibrary + +jax = pytest.importorskip("jax") + +jax_available = True +try: + from pysindy.optimizers import JaxOptimizer +except Exception: + jax_available = False + + +def make_synthetic(n_samples=200, noise=0.0, seed=0): + rng = np.random.default_rng(seed) + X = rng.normal(size=(n_samples, 3)) + W = np.array([[1.0, 0.0, -0.5], [0.0, 2.0, 0.0], [0.0, 0.0, 0.0]]) + Y = X @ W.T + noise * rng.normal(size=(n_samples, 3)) + return X, Y + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_basic_fit_shapes(): + X, Y = make_synthetic() + opt = JaxOptimizer(max_iter=50, threshold=1e-2, alpha_l1=1e-3, seed=1) + opt.fit(X, Y) + assert opt.coef_.shape == (Y.shape[1], X.shape[1]) + assert opt.ind_.shape == (Y.shape[1], X.shape[1]) + assert len(opt.history_) >= 1 + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_unbias_and_sparsity(): + X, Y = make_synthetic(noise=0.01) + opt = JaxOptimizer(max_iter=100, threshold=0.05, alpha_l1=1e-3, seed=2, unbias=True) + opt.fit(X, Y) + assert np.count_nonzero(opt.coef_) < opt.coef_.size + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_multi_target_with_sindylib(): + t = np.linspace(0, 1, 200) + x = np.stack( + [np.sin(2 * np.pi * t), np.cos(2 * np.pi * t), 0.5 * np.sin(4 * np.pi * t)], + axis=1, + ) + lib = PolynomialLibrary(degree=2) + opt = JaxOptimizer(max_iter=50, threshold=1e-2, alpha_l1=1e-3, seed=0) + model = SINDy(optimizer=opt, feature_library=lib) + model.fit(x, t=t[1] - t[0]) + score = model.score(x, t=t[1] - t[0]) + assert score > 0.8 + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +@pytest.mark.parametrize("opt_name", ["sgd", "adam", "adamw"]) +def test_optimizer_variants_run(opt_name): + X, Y = make_synthetic(n_samples=100, noise=0.01, seed=3) + opt = JaxOptimizer( + optimizer=opt_name, max_iter=30, threshold=1e-3, alpha_l1=1e-4, seed=1 + ) + opt.fit(X, Y) + assert opt.coef_.shape == (Y.shape[1], X.shape[1]) + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_early_stopping_via_patience_and_min_delta(): + X, Y = make_synthetic(n_samples=150, noise=0.01, seed=4) + patience = 2 + opt = JaxOptimizer( + max_iter=200, + threshold=0.0, + alpha_l1=0.0, + seed=0, + early_stopping_patience=patience, + min_delta=1e9, + ) + opt.fit(X, Y) + assert len(opt.history_) <= 1 + patience + 1 + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_complexity_property_matches_manual(): + X, Y = make_synthetic(n_samples=120, noise=0.02, seed=6) + opt = JaxOptimizer(max_iter=60, threshold=5e-2, alpha_l1=1e-3, 
seed=6) + opt.fit(X, Y) + manual = np.count_nonzero(opt.coef_) + np.count_nonzero(opt.intercept_) + assert opt.complexity == manual + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_verbose_prints_progress(capsys): + X, Y = make_synthetic(n_samples=60, noise=0.0, seed=7) + opt = JaxOptimizer(max_iter=5, threshold=1e-4, alpha_l1=0.0, verbose=True, seed=0) + opt.fit(X, Y) + captured = capsys.readouterr().out + assert "[JaxSINDy]" in captured and "iter=" in captured + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_history_tracking_shapes_and_length(): + X, Y = make_synthetic(n_samples=80, noise=0.0, seed=8) + init = np.zeros((Y.shape[1], X.shape[1])) + opt = JaxOptimizer( + max_iter=8, threshold=1e-4, alpha_l1=1e-4, seed=0, initial_guess=init + ) + opt.fit(X, Y) + assert isinstance(opt.history_, list) + assert len(opt.history_) >= 2 + shapes_ok = all(h.shape == opt.coef_.shape for h in opt.history_) + assert shapes_ok + assert not np.allclose(opt.history_[0], opt.history_[-1]) + + +@pytest.mark.skipif(not jax_available, reason="JAX not available") +def test_sparse_ind_hard_thresholding_effect(): + X, Y = make_synthetic(n_samples=120, noise=0.0, seed=42) + thr = 1e9 + opt = JaxOptimizer( + max_iter=1, + step_size=1e-3, + threshold=thr, + alpha_l1=0.0, + seed=0, + sparse_ind=[0], + ) + opt.fit(X, Y) + assert np.allclose(opt.coef_[:, 0], 0.0) + assert np.any(np.abs(opt.coef_[:, 1:]) > 0.0) diff --git a/test/test_optimizers/test_torch_optimizer.py b/test/test_optimizers/test_torch_optimizer.py new file mode 100644 index 00000000..f8a32c9a --- /dev/null +++ b/test/test_optimizers/test_torch_optimizer.py @@ -0,0 +1,151 @@ +import numpy as np +import pytest + +from pysindy import SINDy +from pysindy.feature_library import PolynomialLibrary +from pysindy.optimizers import TorchOptimizer + +torch = pytest.importorskip("torch") + + +def make_synthetic(n_samples=200, noise=0.0, seed=0): + rng = np.random.default_rng(seed) + X = rng.normal(size=(n_samples, 3)) + # True W + W = np.array([[1.0, 0.0, -0.5], [0.0, 2.0, 0.0], [0.0, 0.0, 0.0]]) + Y = X @ W.T + noise * rng.normal(size=(n_samples, 3)) + return X, Y + + +def test_basic_fit_shapes(): + X, Y = make_synthetic() + opt = TorchOptimizer(max_iter=50, threshold=1e-2, alpha_l1=1e-3, seed=1) + opt.fit(X, Y) + assert opt.coef_.shape == (Y.shape[1], X.shape[1]) + assert opt.ind_.shape == (Y.shape[1], X.shape[1]) + assert len(opt.history_) >= 1 + + +def test_unbias_and_sparsity(): + X, Y = make_synthetic(noise=0.01) + opt = TorchOptimizer( + max_iter=100, threshold=0.05, alpha_l1=1e-3, seed=2, unbias=True + ) + opt.fit(X, Y) + # Check some sparsity present + assert np.count_nonzero(opt.coef_) < opt.coef_.size + + +def test_multi_target_with_sindylib(): + # Integration with SINDy and a small library + t = np.linspace(0, 1, 200) + x = np.stack( + [np.sin(2 * np.pi * t), np.cos(2 * np.pi * t), 0.5 * np.sin(4 * np.pi * t)], + axis=1, + ) + lib = PolynomialLibrary(degree=2) + opt = TorchOptimizer(max_iter=50, threshold=1e-2, alpha_l1=1e-3, seed=0) + model = SINDy(optimizer=opt, feature_library=lib) + model.fit(x, t=t[1] - t[0]) + score = model.score(x, t=t[1] - t[0]) + assert score > 0.8 + + +# New tests for broader coverage + + +def test_sparse_ind_hard_thresholding_effect(): + # With extremely large threshold and one iteration, only columns listed in + # sparse_ind should be hard-thresholded to zero; others should retain small + # non-zero values from the single gradient step. 
+ X, Y = make_synthetic(n_samples=120, noise=0.0, seed=42) + thr = 1e9 + opt = TorchOptimizer( + max_iter=1, + step_size=1e-3, + threshold=thr, + alpha_l1=0.0, + seed=0, + sparse_ind=[0], # only first feature is forced to zero by hard threshold + ) + opt.fit(X, Y) + # Column 0 should be exactly zero (hard-thresholded) + assert np.allclose(opt.coef_[:, 0], 0.0) + # At least one coefficient outside column 0 should remain non-zero + assert np.any(np.abs(opt.coef_[:, 1:]) > 0.0) + + +@pytest.mark.parametrize("opt_name", ["sgd", "adam", "adamw", "cadamw"]) +def test_optimizer_variants_run(opt_name): + X, Y = make_synthetic(n_samples=100, noise=0.01, seed=3) + opt = TorchOptimizer( + optimizer=opt_name, max_iter=30, threshold=1e-3, alpha_l1=1e-4, seed=1 + ) + opt.fit(X, Y) + assert opt.coef_.shape == (Y.shape[1], X.shape[1]) + + +def test_early_stopping_via_patience_and_min_delta(): + # Use a huge min_delta so improvements never count; should stop after + # early_stopping_patience iterations instead of max_iter. + X, Y = make_synthetic(n_samples=150, noise=0.01, seed=4) + patience = 2 + opt = TorchOptimizer( + max_iter=200, + threshold=0.0, + alpha_l1=0.0, + seed=0, + early_stopping_patience=patience, + min_delta=1e9, + ) + opt.fit(X, Y) + # history_ includes the initial entry (from BaseOptimizer) + per-iteration appends + # Ensure we stopped well before max_iter due to patience + assert len(opt.history_) <= 1 + patience + 1 # a few iterations at most + + +def test_complexity_property_matches_manual(): + X, Y = make_synthetic(n_samples=120, noise=0.02, seed=6) + opt = TorchOptimizer(max_iter=60, threshold=5e-2, alpha_l1=1e-3, seed=6) + opt.fit(X, Y) + manual = np.count_nonzero(opt.coef_) + np.count_nonzero(opt.intercept_) + assert opt.complexity == manual + + +def test_verbose_prints_progress(capsys): + X, Y = make_synthetic(n_samples=60, noise=0.0, seed=7) + opt = TorchOptimizer(max_iter=5, threshold=1e-4, alpha_l1=0.0, verbose=True, seed=0) + opt.fit(X, Y) + captured = capsys.readouterr().out + assert "[TorchSINDy]" in captured and "iter=" in captured + + +def test_history_tracking_shapes_and_length(): + X, Y = make_synthetic(n_samples=80, noise=0.0, seed=8) + # Provide a zero initial guess so the first history entry (initial) certainly + # differs from later iterates after optimization updates. + init = np.zeros((Y.shape[1], X.shape[1])) + opt = TorchOptimizer( + max_iter=8, threshold=1e-4, alpha_l1=1e-4, seed=0, initial_guess=init + ) + opt.fit(X, Y) + # History should contain the initial coefficients + at least one update + assert isinstance(opt.history_, list) + assert len(opt.history_) >= 2 + shapes_ok = all(h.shape == opt.coef_.shape for h in opt.history_) + assert shapes_ok + # The first history entry (initial guess) should differ from a later iterate + assert not np.allclose(opt.history_[0], opt.history_[-1]) + + +def test_device_selection_warning_or_success(): + X, Y = make_synthetic(n_samples=50, noise=0.0, seed=5) + if not torch.cuda.is_available(): + with pytest.warns(UserWarning, match="CUDA not available; falling back to CPU"): + opt = TorchOptimizer(device="cuda", max_iter=5, threshold=0.0, seed=0) + opt.fit(X, Y) + assert opt.coef_.shape == (Y.shape[1], X.shape[1]) + else: + opt = TorchOptimizer(device="cuda", max_iter=10, threshold=1e-4, seed=0) + opt.fit(X, Y) + assert opt.coef_.shape == (Y.shape[1], X.shape[1])
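Both new solvers implement the same inner loop: a gradient step on the smooth part of J(W) = (1/N) * ||Y - X W^T||_F^2 + alpha_l1 * ||W||_1, a soft threshold with shrinkage alpha_l1 * step_size (the proximal operator of the L1 term), and a hard threshold at `threshold`. The sketch below is a plain-NumPy illustration of that loop for readers without torch or jax installed; it is not part of the patch. It uses a plain gradient step, whereas the patch drives the smooth step with SGD/Adam/AdamW/CAdamW, and the helper name `prox_gd_iht` is purely illustrative.

```python
import numpy as np


def prox_gd_iht(X, Y, step_size=0.1, alpha_l1=1e-3, threshold=0.05, max_iter=500):
    """Proximal gradient + iterative hard thresholding for Y ~ X W^T."""
    n_samples, n_features = X.shape
    W = np.zeros((Y.shape[1], n_features))
    for _ in range(max_iter):
        # Gradient of (1/N) * ||Y - X W^T||_F^2 with respect to W
        grad = (2.0 / n_samples) * (W @ (X.T @ X) - Y.T @ X)
        W = W - step_size * grad
        # Proximal L1 shrinkage (soft threshold), as in _soft_threshold
        W = np.sign(W) * np.maximum(np.abs(W) - alpha_l1 * step_size, 0.0)
        # Hard threshold on coefficient magnitudes, as in _hard_threshold
        W = W * (np.abs(W) >= threshold)
    return W


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
W_true = np.array([[1.0, 0.0, -0.5], [0.0, 2.0, 0.0], [0.0, 0.0, 0.0]])
Y = X @ W_true.T
print(np.round(prox_gd_iht(X, Y), 2))  # recovers W_true up to a small L1 bias
```

To exercise only the new solvers through the added benchmark, the names passed to `--optimizers` must match the labels used in `opt_defs`, e.g. `python examples/benchmarks/benchmark.py --system all --optimizers TorchOptimizer,JaxOptimizer`.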