
Conversation

@hvoss-techfak

Dear Dynamicslab,

Love the project you are doing.
I was learning a bit more about ODEs and PDEs and thought that a solver based on proximal gradients and iterative hard thresholding could be useful for more complex problems.
This PR introduces a solver based on a torch gradient-descent algorithm, along with a benchmark file.

The solver uses an adapted version of the cAdamW optimizer, as it performed slightly better than Adam or AdamW in my experiments.
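To make the idea concrete, here is a minimal sketch of the proximal-gradient / iterative-hard-thresholding loop (purely illustrative; the names and defaults below are not the ones used in the actual PR code):

import torch

def sparse_fit_sketch(theta, x_dot, threshold=0.1, lr=1e-2, n_iter=2000):
    # theta: (n_samples, n_features) candidate library, x_dot: (n_samples, n_targets) derivatives
    theta = torch.as_tensor(theta, dtype=torch.float64)
    x_dot = torch.as_tensor(x_dot, dtype=torch.float64)
    coef = torch.zeros(theta.shape[1], x_dot.shape[1], dtype=torch.float64, requires_grad=True)
    opt = torch.optim.AdamW([coef], lr=lr)
    for _ in range(n_iter):
        opt.zero_grad()
        loss = torch.mean((theta @ coef - x_dot) ** 2)  # least-squares residual
        loss.backward()
        opt.step()
        with torch.no_grad():
            coef[coef.abs() < threshold] = 0.0  # iterative hard thresholding
    return coef.detach().numpy()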

The benchmark file runs multiple problems on all available solvers to check which one performs best given the problem statement. The output looks something like this:

System: lorenz
Optimizer        Score    MSE         Time (s)  Complexity
STLSQ            1.0000   3.1003e-02  0.0526     9
SR3-L0           0.9959   7.5510e+00  0.0137     7
FROLS            1.0000   6.5393e-02  0.0889    30
SSR              0.9993   1.3134e+00  0.0480     6
TorchOptimizer   1.0000   3.0785e-02  1.3960     8

Best optimizer: TorchOptimizer | Score=1.0000 | MSE=3.0785e-02 | Time=1.3960s | Complexity=8
Discovered equations:
-9.973 x0 + 9.973 x1
-0.129 1 + 27.739 x0 + -0.949 x1 + -0.993 x0 x2
-2.656 x2 + 0.996 x0 x1

Copilot AI review requested due to automatic review settings December 5, 2025 11:06

Copilot AI left a comment


Pull request overview

This PR adds a PyTorch-based optimizer (TorchOptimizer) for sparse system identification using proximal gradient descent and iterative hard thresholding, along with a comprehensive benchmark script to compare optimizers across multiple nonlinear dynamical systems.

Key changes:

  • New TorchOptimizer class implementing gradient-based sparse regression with support for SGD, Adam, AdamW, and custom CAdamW optimizers
  • Benchmark script evaluating optimizers on 12 different ODE systems (Lorenz, Rössler, Van der Pol, etc.)
  • Test suite for the new optimizer with basic integration tests

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 19 comments.

Summary per file:
pysindy/optimizers/torch_solver.py: Core implementation of the PyTorch-based optimizer with proximal gradient methods and iterative thresholding
test/test_optimizers/test_torch_optimizer.py: Basic test coverage for TorchOptimizer, including shape validation, sparsity, and SINDy integration
pysindy/optimizers/__init__.py: Registers TorchOptimizer as an optional dependency with a conditional import
pyproject.toml: Adds torch to development dependencies
examples/benchmarks/benchmark.py: Comprehensive benchmark runner comparing multiple optimizers across various dynamical systems
Comments suppressed due to low confidence (1)

examples/benchmarks/benchmark.py:346

  • 'except' clause does nothing but pass and there is no explanatory comment.
        except Exception:


@Jacob-Stevens-Haas
Member

Whoa, I don't know why copilot auto reviewed the PR. This is the first time that's happened. Let me look into that. I'm sorry you spun your gears on copilot's comments before I had a chance to talk about larger aspects.

The SR3 optimizer already solves using a hard threshold (or rather, converts between a hard threshold and an L0-regularized problem). However, it uses cvxpy. Another point: the subclasses of BaseOptimizer conflate the abstract approach to sparse problem setup with the actual iterative solution to an optimization problem. This is the first PR that seeks to provide a similar class but with a different array package (well, there's another in development, but TBD).

My initial thought, as we're trying to also support jax arrays, is that it makes more sense to split the sparse problem setup from the actual iterative solution (subclasses of BaseOptimizer do both), and use torch/jax/cvxpy/numpy depending on whether the array type is jax.Array/torch.Tensor/cvxpy.Expression/numpy.ndarray. Since no feature libraries yet support anything other than numpy, I was planning to put work there first. As a very light approach, SR3._reduce() would check the array type and then dispatch to either its existing code or yours, depending upon the type of arrays. What are your thoughts on that?

…Also added new tests for the jax implementation.
@hvoss-techfak
Author

hvoss-techfak commented Dec 6, 2025

Sounds good, although the SR3 implementations and the torch implementation tend to give somewhat different results for the same problem. Maybe a verbose print when switching the optimizer would help users understand what is happening.

I know some JAX, so I have now also added the same implementation in JAX. Everything is the same, except for the cAdamW optimizer, which I had to remove because it is a bit of a pain to rewrite in JAX.

@Jacob-Stevens-Haas
Member

I was not asking you to make a jax implementation; I was wondering whether it made sense to split how the problem is regularized from how the regularized problem is minimized. Since jax and torch (and numpy) all have very similar APIs, a single minimization approach may be able to serve multiple array types.

E.g.

def foo(x: np.ndarray):
    return np.do_thing(x)

Being replaced with

def foo(x: np.ndarray | torch.Tensor):
    array_pkg = np if isinstance(x, np.ndarray) else torch
    return array_pkg.do_thing(x)

A good example of this is your _soft_threshold and the existing _prox_l1. To that end, I would ask that you see how your code can integrate with SR3 first, so that we can limit the amount of redundant code.
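Concretely, those two helpers could collapse into a single array-type-aware function along the same lines (a rough sketch only; the helper name and exact call sites would need to match whatever SR3 ends up using):

import numpy as np
import torch

def soft_threshold(x, alpha):
    # Proximal operator of the L1 norm, usable on numpy arrays or torch tensors.
    xp = np if isinstance(x, np.ndarray) else torch
    return xp.sign(x) * xp.maximum(xp.abs(x) - alpha, xp.zeros_like(x))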

There's also the question of cAdamW. I don't know what it is or why it does better. It feels like one step further than adding a torch optimizer. I'd say let's discuss that and the benchmarks after we see where the best place to plug this in is.

@Jacob-Stevens-Haas
Member

PS can you clarify if/how you used AI in writing this code? It's fine if it was written with AI, but I want to know how much to ask of myself and how much to ask of you.

@Jacob-Stevens-Haas
Member

PPS, we're changing the default branch from master to main, so you'll need to change the target of the PR to main

@hvoss-techfak
Author

Hey, I simply added a jax version because I thought it would make things easier going forward, and I wanted to write something in JAX again since it has been a while.

I mainly did the pytorch implementation (and now the jax implementation) as a learning experience, to better understand how to find dynamical systems programmatically, so I generally don't write my own code with AI tools. The only place where I use ChatGPT is to add documentation for my code and to write tests if I actually open a pull request (like in this case).

I'm actually not quite sure about the best way to approach combining the different versions. Sure, the functions could simply be datatype-aware and change depending on the input array type, but the actual solving also changes somewhat. The SR3 and torch/jax implementations generally give different results with the same input configurations:

System: rossler
Optimizer         Score    MSE         Time (s)  Complexity
SR3-L0            0.9909   3.6573e-04  0.0063    44
SR3-constrained   0.9909   3.6573e-04  0.0064    44
SR3-stable        0.5873   5.1498e-04  1.6874    60
TorchOptimizer    1.0000   9.0054e-10  0.0086    44
JaxOptimizer      1.0000   9.0054e-10  0.0139    44

Depending on whether someone uses numpy/jax/torch arrays, they could therefore get very different results, and I don't know whether that is desirable.

The cAdamW algorithm is a somewhat modified version from this repository: https://github.com/kyleliang919/C-Optim. I found that in some cases in the benchmark it gave better results than plain AdamW, and I successfully used it in one of my other projects to reduce the number of iterations needed by roughly 20%. Sadly, they do not have a PyPI package that can simply be installed and used, so I added my current version from my other project in this pull request. For clarity and scope, I could also simply remove it for now.

@hvoss-techfak
Author

I also just checked, and I don't think I can switch from the master to the main branch without forking the project again. The only branch GitHub allows me to track is the master branch, as all the others are not forked from the original project.

@Jacob-Stevens-Haas
Member

I believe you can PR into any branch in the repo regardless of which branch you forked, but I don't think you can move the target of an existing PR.

I'm actually not quite sure about the best way to approach combining the different versions

I admit that I'm not either. It might make more sense to leave that to a future refactoring.

Enough people have asked about benchmarking that I've added benchmarking via ASV to the repo. But before I ask you to refactor your benchmark into an ASV one, I admit I haven't thought about the acceptance criteria for methods like this. On one hand, in the past, all new approaches have had a published paper on system identification to back them up beyond a single benchmark. By that metric, this PR doesn't add a substantially novel approach. On the other hand, academic publishing has too high a bar for research software, and I'm coming around to the idea (especially with benchmarking) that runtime matters for some users (e.g. #653).

I know scikit-learn/scipy each have written algorithm notability/acceptance criteria, but this is the first time we've had to consider it in pysindy. I want to think about the way this should work in general, whether it should happen via benchmarks or what. I'd love to hear your thoughts on this, though.
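For reference, an ASV benchmark is essentially a class with a setup method and time_* (or peakmem_*) methods; a rough sketch of what a refactored benchmark could look like (class, method, and parameter names here are only illustrative, not existing code in the repo):

import numpy as np
from scipy.integrate import solve_ivp
import pysindy as ps

class OptimizerSuite:
    """Illustrative ASV benchmark: time fitting SINDy on Lorenz data."""

    def setup(self):
        def lorenz(t, x, sigma=10, rho=28, beta=8 / 3):
            return [sigma * (x[1] - x[0]),
                    x[0] * (rho - x[2]) - x[1],
                    x[0] * x[1] - beta * x[2]]
        t = np.arange(0, 10, 0.002)
        sol = solve_ivp(lorenz, (t[0], t[-1]), [-8, 8, 27], t_eval=t)
        self.t, self.x = t, sol.y.T

    def time_fit_stlsq(self):
        model = ps.SINDy(optimizer=ps.STLSQ(threshold=0.1))
        model.fit(self.x, t=self.t)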

@hvoss-techfak
Author

I believe you can PR into any branch in the repo regardless of which branch you forked, but I don't think you can move the target of an existing PR.

Yeah, that makes sense. I'll probably split this PR into two then: a benchmark PR and a torch/jax PR, as soon as it is clear what should be added.

I admit I haven't thought about the acceptance criteria for methods like this. On one hand, in the past, all new approaches have had a published paper on system identification to back them up beyond a single benchmark. By that metric, this PR doesn't add a substantially novel approach.

For me at least, this PR was more a learning experience and a byproduct of a color-mixing differential equation search, where I had some plastic filament and wanted to find the underlying equation system for when multiple layers of different colors are stacked on top of each other. So it would not really be a problem for me if you decided not to add the optimizer. That said, I do think the SR3 algorithm could benefit from a more powerful optimizer, as right now the torch/jax implementation finds better results with lower complexity in many cases.

I know scikit-learn/scipy each have written algorithm notability/acceptance criteria, but this is the first time we've had to consider it in pysindy. I want to think about the way this should work in general, whether it should happen via benchmarks or what. I'd love to hear your thoughts on this, though.

I would probably propose adding a large, comprehensive benchmark that automatically tests all available algorithms against a large number of systems of varying complexity. In this way, the library would not only be usable for application purposes but also be a good first step for future researchers to try out their new algorithm without having to set up all the other algorithms and benchmark them individually. They could simply add their version to a private version of pysindy, publish the paper, and then later open a PR to add it directly into the library. With this, PRs with new algorithms could be "judged" on their performance, either in terms of quality or speed, and accepted/rejected based on clear quality criteria.

@Jacob-Stevens-Haas
Member

Jacob-Stevens-Haas commented Dec 9, 2025

color-mixing differential equation search, where I had some plastic filament and wanted to find the underlying equation system for when multiple layers of different colors are stacked on top of each other.

Could you share more about this example? This is really interesting. I looked at the AutoForge link, but I didn't see that there was a use case of differential equations.

I do think that the SR3 algorithm could benefit from a more powerful optimizer as right now the torch/jax implementation finds better results with less complexity in many cases.

I'd agree with that.

They could simply add their version to a private version of pysindy, publish the paper, and then later open a PR to add it directly into the library

There are several options here: (a) a 3rd-party package that inherits from pysindy ABCs, (b) a private fork of pysindy, (c) a 3rd-party package, but pysindy provides plugin registration so users don't need to directly import from the 3rd-party package.


In order to cut through the decision-making process, let's aim for just the torch or the jax optimizer at first. We'll try to separate the math from the implementation as much as possible, with the idea that this could support future work to fully pull the backend (torch vs jax vs others) into a separate object later.

I noticed that jax is a lot shorter - is that because it has less capability (other than CAdamW), or because it's just naturally less verbose? Also, although jax itself doesn't have optimizers, the ecosystem (optax) does. You might be able to borrow from that.

@hvoss-techfak
Author

Could you share more about this example? This is really interesting. I looked at the AutoForge link, but I didn't see that there was a use case of differential equations.

Sure, essentially I printed a lot of different 3D model color setups with different PLA filament layers on top of each other and measured the resulting colors after each layer. With this I now have a large Excel file with a lot of different colors. The problem setup is essentially that we have layers of different thickness and a "Transmission Distance" (how far light penetrates the filament until it is completely opaque), and from there we have to calculate a color mixing for the R+G+B channels to get a good approximation of the actual color behavior. Currently, I do this rather simply by using this formula I found with a hyperparameter optimization library:
o, A, k, b = -1.2416557e-02, 9.6407950e-01, 3.4103447e01, -4.1554203e00  # fitted constants
opac = o + (A * torch.log1p(k * thick_ratio) + b * thick_ratio)  # logarithmic opacity vs. layer thickness
opac = torch.clamp(opac, 0.0, 1.0)  # [L, H, W]
https://github.com/hvoss-techfak/AutoForge/blob/513d56a38460355f43d5d6a0ca7dd395cd84d441/src/autoforge/Helper/OptimizerHelper.py#L175C1-L178C50

This essentially only takes the opacity into account and not the actual color, which is not ideal. Therefore, I'm currently using pysindy to find the actual system to model the RGB colors for each layer, dependent on transmission distance, current layer height, and overall layer height. I haven't pushed anything yet, as I still need to collect some more samples and don't want to change too much of the working system.

I noticed that jax is a lot shorter - is that because it has less capability (other than CAdamW), or because it's just naturally less verbose?

It is mainly less verbose in the initialization part. The algorithm is pretty much the same length, but in JAX you don't have to define as much tensor information up front.

Also, although jax itself doesn't have optimizers, the ecosystem (optax) does. You might be able to to borrow from that.

Yeah good point. I'll use the optax optimizers when refactoring this.
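For instance, the inner update step could lean on optax roughly like this (a sketch only; it assumes the same least-squares loss as the torch version, and the names here are illustrative):

import jax
import jax.numpy as jnp
import optax

def fit_step_sketch(theta, x_dot, coef, opt_state, optimizer):
    # One gradient step of the least-squares objective using an optax optimizer.
    loss_fn = lambda c: jnp.mean((theta @ c - x_dot) ** 2)
    grads = jax.grad(loss_fn)(coef)
    updates, opt_state = optimizer.update(grads, opt_state, coef)
    coef = optax.apply_updates(coef, updates)
    # hard thresholding would follow here, e.g. jnp.where(jnp.abs(coef) < tau, 0.0, coef)
    return coef, opt_state

# usage sketch:
# optimizer = optax.adamw(learning_rate=1e-2)
# opt_state = optimizer.init(coef)
# coef, opt_state = fit_step_sketch(theta, x_dot, coef, opt_state, optimizer)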

