arietul/PyTorch-PI_HC_MoE

PyTorch-PI_HC_MoE

PyTorch port of "Scaling Physics-Informed Hard Constraints with Mixture-of-Experts" (ICLR 2024)

Original paper: Chalapathi, Du & Krishnapriyan. Original code: ASK-Berkeley/physics-NNs-hard-constraints


What this is

A faithful, structurally identical PyTorch rewrite of the original JAX/Equinox codebase. Every module, class, and algorithm has a direct counterpart with the same name and purpose.

Original (JAX)               This repo (PyTorch)
---------------------------  ----------------------------------
equinox.Module               torch.nn.Module
jax.numpy                    torch
jax.vmap                     loop / torch.func.vmap
jax.jit                      torch.compile (optional)
optax optimisers             torch.optim
optimistix / jaxopt LM/GN    Theseus (Meta AI)
TensorFlow datasets          torch.utils.data.DataLoader + h5py
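
As an illustration of the jax.vmap mapping above, the sketch below applies a per-sample function to a batch in both ways, with an explicit loop and with torch.func.vmap. The function `energy` and the tensor shapes are made up for this example and do not come from the repo.

```python
import torch
from torch.func import vmap

def energy(u: torch.Tensor) -> torch.Tensor:
    """Per-sample quantity: sum of squares of a 1D field of shape (nx,)."""
    return (u ** 2).sum()

batch = torch.randn(4, 16)  # hypothetical batch of 4 fields with nx = 16

# Option 1: plain Python loop over the batch dimension.
looped = torch.stack([energy(u) for u in batch])

# Option 2: torch.func.vmap maps `energy` over dim 0, like jax.vmap.
vmapped = vmap(energy)(batch)

print(torch.allclose(looped, vmapped))
```

The loop is easier to debug. vmap usually pays off when the per-sample function is cheap and the batch is large.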

Package structure

neptune_pt/
├── geometry/          # Grid, Function, mesh_utils (boundary conditions etc.)
├── pdes/              # AbstractPDE, DiffusionSorption1D, NavierStokes2D
├── models/
│   ├── modules/       # SpectralConv2d/3d, ICMollifier, activations
│   ├── fno2d.py       # 2D Fourier Neural Operator
│   ├── fno3d.py       # 3D Fourier Neural Operator
│   ├── constraints/   # HardConstraintLayer (Theseus LM/GN/L-BFGS)
│   ├── MoE/           # SpatialTemporalMoE router
│   └── sequential.py  # SequentialModel (backbone + constraint)
├── losses/            # MSE, PDELoss
├── datasets/          # DataLoaders for DS-1D and NS-2D (HDF5/npz)
├── trainers/          # Trainer, TrainingConfig, optimisers, checkpointing
└── callbacks/         # W&B heatmap/time-series logging, checkpoint I/O

Setup

pip install -e .
# or
pip install -r requirements.txt

For hard constraints (solved as a differentiable nonlinear least-squares problem), Theseus must be installed:

pip install theseus-ai

For double precision (recommended with hard constraints):

USE_FLOAT64=1 python train.py ...
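
How train.py reads this variable is not shown here. A minimal sketch of one way such a flag could switch PyTorch's default dtype follows. The helper name `configure_precision` and the rule that only the value "1" enables float64 are assumptions.

```python
import os
import torch

def configure_precision(env=os.environ) -> torch.dtype:
    """Set the default floating dtype from USE_FLOAT64 (assumed: "1" enables float64)."""
    use_f64 = env.get("USE_FLOAT64", "0") == "1"
    torch.set_default_dtype(torch.float64 if use_f64 else torch.float32)
    return torch.get_default_dtype()

# Tensors created after this call pick up the configured dtype.
dtype = configure_precision()
print(dtype, torch.zeros(1).dtype)
```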

Data

Download the same datasets as the original paper:

  • Diffusion-Sorption 1D – from PDEBench (1D_diff-sorp_NA_NA.h5)
  • Navier-Stokes 2D – same Google Drive link as the original repo

Update data_root in the corresponding configs/datasets/*.yaml.


Training

Diffusion-Sorption – soft constraint (FNO2D baseline)

python train.py \
  --model_config    configs/models/fno2d-64-8m.yaml \
  --dataset_config  configs/datasets/diffusion-sorption.yaml \
  --training_config configs/training/ds.yaml

Diffusion-Sorption – hard constraint (Theseus LM)

python train.py \
  --model_config    configs/models/fno2d-64-8m-hard-constraint.yaml \
  --dataset_config  configs/datasets/diffusion-sorption.yaml \
  --training_config configs/training/ds.yaml

Diffusion-Sorption – hard constraint + MoE

python train.py \
  --model_config    configs/models/fno2d-64-8m-hard-constraint-moe.yaml \
  --dataset_config  configs/datasets/diffusion-sorption.yaml \
  --training_config configs/training/ds.yaml

Navier-Stokes – hard constraint + MoE (FNO3D)

python train.py \
  --model_config    configs/models/fno3d-64-8m-hard-constraint-moe.yaml \
  --dataset_config  configs/datasets/navier-stokes.yaml \
  --training_config configs/training/ns.yaml

Resume from checkpoint

python resume.py \
  --model_config    configs/models/fno2d-64-8m-hard-constraint.yaml \
  --dataset_config  configs/datasets/diffusion-sorption.yaml \
  --training_config configs/training/ds.yaml \
  --resume          ./checkpoints/best.pt

Testing

python test.py \
  --model_config    configs/models/fno2d-64-8m-hard-constraint.yaml \
  --dataset_config  configs/datasets/diffusion-sorption.yaml \
  --training_config configs/training/ds.yaml \
  --resume          ./checkpoints/best.pt

Key design choices

Hard constraints via Theseus

The hard constraint layer finds basis weights w that minimise:

F(w) = [PDE_residual(u(w));  IC_loss(w);  BC_loss(w);  ridge(w)]
||F(w)||² -> min

The original code solves this with optimistix.LevenbergMarquardt (JAX). This port uses theseus.LevenbergMarquardt, which:

  • Is fully differentiable through the optimisation (implicit differentiation)
  • Runs on GPU via PyTorch
  • Integrates directly with the standard PyTorch autograd graph
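
To make the stacked-residual objective concrete, here is a hand-rolled Levenberg-Marquardt loop in plain PyTorch on a toy problem. It is neither Theseus nor this repo's HardConstraintLayer. The toy ODE u' + u = 0 with u(0) = 1, the polynomial basis, the ridge weight `lam`, and the damping schedule are all assumptions chosen for illustration, and there is no BC term because the toy problem has none.

```python
import torch

dtype = torch.float64
t = torch.linspace(0.0, 1.0, 32, dtype=dtype)   # collocation points
k = torch.arange(6, dtype=dtype)                # polynomial degrees 0..5
phi = t[:, None] ** k                           # basis values, shape (32, 6)
dphi = k * (t[:, None] ** ((k - 1).clamp(min=0)))  # basis time derivatives
lam = 1e-12                                     # ridge weight (assumed)

def residual(w: torch.Tensor) -> torch.Tensor:
    """F(w) = [ODE residual; IC residual; ridge] for u(t) = phi(t) @ w."""
    ode = dphi @ w + phi @ w                    # u' + u at every point
    ic = (phi[0] @ w - 1.0).reshape(1)          # u(0) - 1
    ridge = (lam ** 0.5) * w
    return torch.cat([ode, ic, ridge])

def levenberg_marquardt(fn, w0, mu=1e-3, iters=30):
    """Minimise ||fn(w)||^2 with damped Gauss-Newton steps."""
    w = w0.clone()
    cost = fn(w).square().sum()
    eye = torch.eye(w.numel(), dtype=w.dtype)
    for _ in range(iters):
        F = fn(w)
        J = torch.autograd.functional.jacobian(fn, w)   # shape (m, n)
        delta = torch.linalg.solve(J.T @ J + mu * eye, -(J.T @ F))
        new_cost = fn(w + delta).square().sum()
        if new_cost < cost:                     # accept step, relax damping
            w, cost, mu = w + delta, new_cost, max(mu * 0.1, 1e-12)
        else:                                   # reject step, increase damping
            mu = mu * 10.0
    return w

w = levenberg_marquardt(residual, torch.zeros(6, dtype=dtype))
max_err = (phi @ w - torch.exp(-t)).abs().max()
print(float(max_err))
```

Because this toy residual is linear in w, the loop settles within a few accepted steps. A PDE residual such as the ones in pdes/ would generally be nonlinear, which is the case the damping and accept/reject logic are meant to handle.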

Mixture of Experts

Domain decomposition mirrors the original:

  • dimension_partition splits a tensor of shape (nt, nx[, ny], d) along temporal/spatial axes
  • Each expert runs an independent hard constraint solve on its sub-domain
  • dimension_unpartition reassembles the full solution
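
The partition / solve / reassemble pattern above can be sketched for the 1D case, shape (nt, nx, d), as follows. The helpers `partition_blocks` and `unpartition_blocks` are hypothetical stand-ins written for this example. They are not the repo's dimension_partition and dimension_unpartition, whose signatures are not shown here, and the per-expert "solve" is a placeholder.

```python
import torch

def partition_blocks(u: torch.Tensor, n_t: int, n_x: int) -> torch.Tensor:
    """Split (nt, nx, d) into n_t * n_x blocks of shape (nt // n_t, nx // n_x, d)."""
    nt, nx, d = u.shape
    assert nt % n_t == 0 and nx % n_x == 0, "axes must divide evenly"
    blocks = u.reshape(n_t, nt // n_t, n_x, nx // n_x, d)
    blocks = blocks.permute(0, 2, 1, 3, 4)      # (n_t, n_x, bt, bx, d)
    return blocks.reshape(n_t * n_x, nt // n_t, nx // n_x, d)

def unpartition_blocks(blocks: torch.Tensor, n_t: int, n_x: int) -> torch.Tensor:
    """Inverse of partition_blocks: reassemble the full (nt, nx, d) field."""
    _, bt, bx, d = blocks.shape
    u = blocks.reshape(n_t, n_x, bt, bx, d).permute(0, 2, 1, 3, 4)
    return u.reshape(n_t * bt, n_x * bx, d)

u = torch.arange(8 * 6 * 2, dtype=torch.float32).reshape(8, 6, 2)
blocks = partition_blocks(u, n_t=2, n_x=3)      # 6 sub-domains of shape (4, 2, 2)

# Placeholder for the per-expert hard constraint solve: each block is
# processed independently of the others (here, an identity copy).
solved = torch.stack([block.clone() for block in blocks])

restored = unpartition_blocks(solved, n_t=2, n_x=3)
print(blocks.shape, torch.equal(restored, u))
```

Since the blocks share no state, the per-expert solves can run in a loop, under vmap, or on separate devices.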

Citation

@inproceedings{chalapathi2024scaling,
  title   = {Scaling Physics-Informed Hard Constraints with Mixture-of-Experts},
  author  = {Chalapathi, Nithin and Du, Yiheng and Krishnapriyan, Aditi S.},
  booktitle = {International Conference on Learning Representations (ICLR)},
  year    = {2024}
}
