PyTorch port of "Scaling Physics-Informed Hard Constraints with Mixture-of-Experts" (ICLR 2024)
Original paper by Chalapathi, Du & Krishnapriyan; original JAX code: ASK-Berkeley/physics-NNs-hard-constraints
A faithful, structurally identical PyTorch rewrite of the original JAX/Equinox codebase. Every module, class, and algorithm has a direct counterpart with the same name and purpose.
| Original (JAX) | This repo (PyTorch) |
|---|---|
| equinox.Module | torch.nn.Module |
| jax.numpy | torch |
| jax.vmap | loop / torch.func.vmap |
| jax.jit | torch.compile (optional) |
| optax optimisers | torch.optim |
| optimistix / jaxopt LM/GN | Theseus (Meta AI) |
| TensorFlow datasets | torch.utils.data.DataLoader + h5py |
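To illustrate the `jax.vmap` row, here is a small sketch of the two PyTorch options the table lists: `torch.func.vmap` and an explicit loop. The per-sample function `central_diff` is hypothetical and is not taken from this repo.

```python
import torch
from torch.func import vmap


def central_diff(u: torch.Tensor, dx: float) -> torch.Tensor:
    # Hypothetical per-sample function: du/dx of one 1D field (interior points only)
    return (u[2:] - u[:-2]) / (2 * dx)


batch = torch.randn(4, 16, dtype=torch.float64)
dx = 0.1

# JAX equivalent: jax.vmap(central_diff, in_axes=(0, None))(batch, dx)
batched = vmap(central_diff, in_dims=(0, None))(batch, dx)

# Explicit-loop alternative from the table
looped = torch.stack([central_diff(u, dx) for u in batch])
```

`vmap` avoids the Python loop. The loop is simpler to debug and also works for code that `vmap` cannot handle, such as data-dependent control flow.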
```text
neptune_pt/
├── geometry/ # Grid, Function, mesh_utils (boundary conditions etc.)
├── pdes/ # AbstractPDE, DiffusionSorption1D, NavierStokes2D
├── models/
│ ├── modules/ # SpectralConv2d/3d, ICMollifier, activations
│ ├── fno2d.py # 2D Fourier Neural Operator
│ ├── fno3d.py # 3D Fourier Neural Operator
│ ├── constraints/ # HardConstraintLayer (Theseus LM/GN/L-BFGS)
│ ├── MoE/ # SpatialTemporalMoE router
│ └── sequential.py # SequentialModel (backbone + constraint)
├── losses/ # MSE, PDELoss
├── datasets/ # DataLoaders for DS-1D and NS-2D (HDF5/npz)
├── trainers/ # Trainer, TrainingConfig, optimisers, checkpointing
└── callbacks/ # W&B heatmap/time-series logging, checkpoint I/O
```
```bash
pip install -e .
# or
pip install -r requirements.txt
```

For hard constraints (differentiable nonlinear least squares, NLLS), Theseus must be installed:

```bash
pip install theseus-ai
```

For double precision (recommended with hard constraints):

```bash
USE_FLOAT64=1 python train.py ...
```

Download the same datasets as the original paper:
- Diffusion-Sorption 1D – from PDEBench (1D_diff-sorp_NA_NA.h5)
- Navier-Stokes 2D – same Google Drive link as the original repo

Update data_root in the corresponding configs/datasets/*.yaml.
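The repo's loaders read HDF5 files with h5py and wrap them in `torch.utils.data.DataLoader`. The sketch below shows only the general pattern. It assumes trajectories are already in memory with shape `(n_samples, nt, nx)` and that each sample is returned as an (initial condition, full trajectory) pair. The class name `TrajectoryDataset` and this layout are hypothetical and do not reflect the repo's actual loader.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TrajectoryDataset(Dataset):
    """Hypothetical dataset wrapping trajectories of shape (n_samples, nt, nx)."""

    def __init__(self, trajectories, dtype=torch.float32):
        # Accepts a tensor or a NumPy array (e.g. one read from an HDF5 file with h5py)
        self.u = torch.as_tensor(trajectories, dtype=dtype)

    def __len__(self) -> int:
        return self.u.shape[0]

    def __getitem__(self, idx: int):
        traj = self.u[idx]          # (nt, nx)
        return traj[0], traj        # (initial condition at t = 0, full trajectory)


dataset = TrajectoryDataset(torch.rand(10, 21, 32))
loader = DataLoader(dataset, batch_size=4, shuffle=True)
ic, target = next(iter(loader))     # ic: (4, 32), target: (4, 21, 32)
```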
```bash
# FNO-2D baseline on Diffusion-Sorption 1D
python train.py \
--model_config configs/models/fno2d-64-8m.yaml \
--dataset_config configs/datasets/diffusion-sorption.yaml \
--training_config configs/training/ds.yaml

# FNO-2D + hard constraint
python train.py \
--model_config configs/models/fno2d-64-8m-hard-constraint.yaml \
--dataset_config configs/datasets/diffusion-sorption.yaml \
--training_config configs/training/ds.yaml

# FNO-2D + hard constraint + MoE
python train.py \
--model_config configs/models/fno2d-64-8m-hard-constraint-moe.yaml \
--dataset_config configs/datasets/diffusion-sorption.yaml \
--training_config configs/training/ds.yaml

# FNO-3D + hard constraint + MoE on Navier-Stokes 2D
python train.py \
--model_config configs/models/fno3d-64-8m-hard-constraint-moe.yaml \
--dataset_config configs/datasets/navier-stokes.yaml \
--training_config configs/training/ns.yaml

# Resume training from a checkpoint
python resume.py \
--model_config configs/models/fno2d-64-8m-hard-constraint.yaml \
--dataset_config configs/datasets/diffusion-sorption.yaml \
--training_config configs/training/ds.yaml \
--resume ./checkpoints/best.pt

# Evaluate a checkpoint
python test.py \
--model_config configs/models/fno2d-64-8m-hard-constraint.yaml \
--dataset_config configs/datasets/diffusion-sorption.yaml \
--training_config configs/training/ds.yaml \
--resume ./checkpoints/best.pt
```

The hard constraint layer finds basis weights w that minimise:
$$
F(w) = \big[\,\mathrm{PDE\_residual}(u(w));\ \mathrm{IC\_loss}(w);\ \mathrm{BC\_loss}(w);\ \mathrm{ridge}(w)\,\big],
\qquad
w^\star = \arg\min_w \lVert F(w) \rVert_2^2
$$
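The objective above is a stacked-residual least-squares problem, which Levenberg-Marquardt (LM) solves by taking damped Gauss-Newton steps. Below is a framework-free sketch of that iteration in plain PyTorch. It does not use Theseus's API. The halve/double damping rule is a common textbook choice and is not necessarily what Theseus or optimistix use. The toy problem is hypothetical: a steady 1D Poisson equation with a fixed monomial basis, so it has no IC term.

```python
import math

import torch
from torch.func import jacrev


def levenberg_marquardt(residual_fn, w0, damping=1e-3, max_iters=50, gtol=1e-10):
    """Minimise ||residual_fn(w)||^2 with a basic Levenberg-Marquardt loop (sketch)."""
    w, lam = w0.clone(), damping
    eye = torch.eye(w.numel(), dtype=w.dtype)
    for _ in range(max_iters):
        r = residual_fn(w)
        J = jacrev(residual_fn)(w)      # (m, n) Jacobian of the stacked residual
        g = J.T @ r                     # gradient of 0.5 * ||F||^2
        if g.norm() < gtol:
            break
        # Damped Gauss-Newton step: (J^T J + lam I) dw = -J^T r
        dw = torch.linalg.solve(J.T @ J + lam * eye, -g)
        w_trial = w + dw
        if residual_fn(w_trial).pow(2).sum() < r.pow(2).sum():
            w, lam = w_trial, lam * 0.5  # accept: move towards Gauss-Newton
        else:
            lam = lam * 2.0              # reject: move towards gradient descent
    return w


# Hypothetical toy problem: u'' = -pi^2 sin(pi x) on [0, 1], u(0) = u(1) = 0,
# with u(x) = Phi(x) @ w built from a fixed monomial basis.
nx, n_basis, ridge = 32, 6, 1e-6
x = torch.linspace(0.0, 1.0, nx, dtype=torch.float64)
h = float(x[1] - x[0])
Phi = torch.stack([x**k for k in range(n_basis)], dim=1)  # (nx, n_basis)
f = -(math.pi**2) * torch.sin(math.pi * x)


def residual_fn(w):
    u = Phi @ w
    pde = (u[:-2] - 2 * u[1:-1] + u[2:]) / h**2 - f[1:-1]  # PDE_residual(u(w))
    bc = torch.stack([u[0], u[-1]])                         # BC_loss(w)
    reg = math.sqrt(ridge) * w                              # ridge(w)
    return torch.cat([pde, bc, reg])


w_star = levenberg_marquardt(residual_fn, torch.zeros(n_basis, dtype=torch.float64))
```

Running the solve in float64 here matches the README's recommendation to use double precision with hard constraints, since `J^T J` can be poorly conditioned.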
In the original code this uses optimistix.LevenbergMarquardt (JAX).
Here we use theseus.LevenbergMarquardt which:
- Is differentiable through the optimisation (Theseus supports implicit differentiation as one of its backward modes)
- Runs on GPU via PyTorch
- Integrates directly with the standard PyTorch autograd graph
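Implicit differentiation avoids backpropagating through every solver iteration. It differentiates the optimality condition at the solution instead. The sketch below shows the idea for the linear special case `F(w) = A w - b`, written as a custom `torch.autograd.Function`. It is not Theseus code. With `(A^T A) v = dL/dw` and the optimal residual `r = b - A w*`, the implicit-function theorem gives `dL/db = A v` and `dL/dA = r v^T - (A v) w*^T`.

```python
import torch


class LeastSquaresSolve(torch.autograd.Function):
    """w* = argmin_w ||A w - b||^2, differentiated implicitly (no unrolling)."""

    @staticmethod
    def forward(ctx, A, b):
        # Solve the normal equations (A^T A) w = A^T b
        w = torch.linalg.solve(A.T @ A, A.T @ b)
        ctx.save_for_backward(A, b, w)
        return w

    @staticmethod
    def backward(ctx, grad_w):
        A, b, w = ctx.saved_tensors
        # Adjoint solve with the same matrix as the forward pass: (A^T A) v = dL/dw
        v = torch.linalg.solve(A.T @ A, grad_w)
        r = b - A @ w                                      # optimal residual
        grad_A = torch.outer(r, v) - torch.outer(A @ v, w)
        grad_b = A @ v
        return grad_A, grad_b


torch.manual_seed(0)
A = torch.randn(8, 3, dtype=torch.float64, requires_grad=True)
b = torch.randn(8, dtype=torch.float64, requires_grad=True)
w = LeastSquaresSolve.apply(A, b)
w.sum().backward()  # A.grad and b.grad now hold the implicit gradients
```

The backward pass costs one extra linear solve, however many iterations the forward solver took. For a nonlinear `F`, the same principle applies with `A` replaced by the Jacobian at the solution.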
Domain decomposition mirrors the original:
- dimension_partition splits (nt, nx[, ny], d) along temporal/spatial axes
- Each expert runs an independent hard constraint solve on its sub-domain
- dimension_unpartition reassembles the full solution
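The partition/solve/unpartition round trip can be sketched for a 2D `(nt, nx, d)` grid. The helper names `split_domain` and `merge_domain` are hypothetical, and so is the layout of equal-sized, non-overlapping blocks. This sketch does not reproduce the repo's `dimension_partition` / `dimension_unpartition`. The 3D Navier-Stokes case would add an `ny` axis with one more block dimension in the same reshape/permute pattern.

```python
import torch


def split_domain(u: torch.Tensor, parts: tuple[int, int]) -> torch.Tensor:
    """Split a (nt, nx, d) field into pt * px equal (nt/pt, nx/px, d) sub-domains."""
    nt, nx, d = u.shape
    pt, px = parts
    assert nt % pt == 0 and nx % px == 0, "grid must divide evenly"
    bt, bx = nt // pt, nx // px
    # (pt, bt, px, bx, d) -> (pt, px, bt, bx, d); expert index = i * px + j
    return u.reshape(pt, bt, px, bx, d).permute(0, 2, 1, 3, 4).reshape(pt * px, bt, bx, d)


def merge_domain(blocks: torch.Tensor, parts: tuple[int, int]) -> torch.Tensor:
    """Inverse of split_domain: reassemble (pt * px, bt, bx, d) blocks into (nt, nx, d)."""
    pt, px = parts
    _, bt, bx, d = blocks.shape
    return blocks.reshape(pt, px, bt, bx, d).permute(0, 2, 1, 3, 4).reshape(pt * bt, px * bx, d)


u = torch.arange(8 * 6 * 2, dtype=torch.float64).reshape(8, 6, 2)
blocks = split_domain(u, (2, 3))  # 6 experts, each with a (4, 2, 2) sub-domain
# Each expert would run its own hard-constraint solve here; identity as a placeholder
solved = torch.stack([blk.clone() for blk in blocks])
u_full = merge_domain(solved, (2, 3))
```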
```bibtex
@inproceedings{chalapathi2024scaling,
title = {Scaling Physics-Informed Hard Constraints with Mixture-of-Experts},
author = {Chalapathi, Nithin and Du, Yiheng and Krishnapriyan, Aditi S.},
booktitle = {International Conference on Learning Representations (ICLR)},
year = {2024}
}
```