diff --git a/.codespellrc b/.codespellrc index 08ef4be1..c8e7df41 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,2 +1,2 @@ [codespell] -ignore-words-list = COO,Mater,ket +ignore-words-list = COO,Mater,ket,nd,te diff --git a/.github/workflows/lammps-build.yml b/.github/workflows/lammps-build.yml new file mode 100644 index 00000000..9966f01f --- /dev/null +++ b/.github/workflows/lammps-build.yml @@ -0,0 +1,320 @@ +name: LAMMPS pair_matgl build + +# Builds LAMMPS with PKG_ML-MATGL inside the materialyzeai/lammps image +# and runs a single-point parity check against the Python ASE calculator. +# +# The image is expected to ship: +# * a LAMMPS source checkout (path discovered at run time), +# * libtorch (CXX11 ABI), +# * gcc/g++/cmake/python3. +# If any of those are missing the workflow falls back to downloading them. +# +# Kokkos / GPU variant is *not* exercised here — GitHub-hosted runners have +# no GPU. Phase-3 hardware testing is delegated to a self-hosted CUDA +# runner (TODO). + +on: + push: + branches: ["*"] + paths: + - "lammps/**" + - "src/matgl/ext/_lammps.py" + - "src/matgl/ext/lammps.py" + - ".github/workflows/lammps-build.yml" + pull_request: + branches: [main] + paths: + - "lammps/**" + - "src/matgl/ext/_lammps.py" + - "src/matgl/ext/lammps.py" + - ".github/workflows/lammps-build.yml" + workflow_dispatch: + +jobs: + build_pair_matgl_cpu: + name: pair_matgl (CPU) + runs-on: ubuntu-latest + + container: + image: materialyzeai/lammps:latest + + env: + LIBTORCH_VERSION: "2.5.1" + + steps: + - name: Install build + CI tooling + # The materialyzeai/lammps image is a runtime image — it ships the + # `lmp` binary but none of the build toolchain. We install + # everything we need (compiler, cmake, ninja, fftw headers, + # curl/unzip/git for downloads & checkout, python venv for uv). + shell: bash + run: | + if command -v apt-get >/dev/null 2>&1; then + apt-get update -qq + DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + curl unzip git ca-certificates xz-utils \ + build-essential gcc g++ cmake ninja-build pkg-config \ + libfftw3-dev libpng-dev libjpeg-dev \ + python3 python3-venv python3-dev + elif command -v dnf >/dev/null 2>&1; then + dnf install -y curl unzip git ca-certificates xz \ + gcc gcc-c++ cmake ninja-build pkgconf-pkg-config make \ + fftw-devel libpng-devel libjpeg-devel \ + python3 python3-devel + elif command -v apk >/dev/null 2>&1; then + apk add --no-cache curl unzip git ca-certificates xz \ + build-base cmake samurai pkgconfig \ + fftw-dev libpng-dev jpeg-dev \ + python3 python3-dev + else + echo "No supported package manager found in image" >&2 + exit 1 + fi + echo "Installed compiler: $(gcc --version | head -1)" + echo "Installed cmake: $(cmake --version | head -1)" + + - name: Checkout matgl + uses: actions/checkout@v4 + with: + path: matgl + + - name: Probe image for LAMMPS source + libtorch + id: probe + shell: bash + run: | + set -e + # LAMMPS source — check the usual suspects. + for cand in /opt/lammps /usr/local/src/lammps /lammps "$HOME/lammps"; do + if [ -f "$cand/cmake/CMakeLists.txt" ]; then + echo "lammps_src=$cand" >> "$GITHUB_OUTPUT" + echo "Found LAMMPS source at: $cand" + break + fi + done + # libtorch — check the usual suspects. + for cand in /opt/libtorch /usr/local/libtorch /libtorch "$HOME/libtorch"; do + if [ -f "$cand/share/cmake/Torch/TorchConfig.cmake" ]; then + echo "libtorch=$cand" >> "$GITHUB_OUTPUT" + echo "Found libtorch at: $cand" + break + fi + done + + - name: Cache libtorch (fallback) + if: steps.probe.outputs.libtorch == '' + id: cache-libtorch + uses: actions/cache@v4 + with: + path: libtorch + key: libtorch-${{ env.LIBTORCH_VERSION }}-cpu-cxx11abi + + - name: Download libtorch (fallback) + if: steps.probe.outputs.libtorch == '' && steps.cache-libtorch.outputs.cache-hit != 'true' + shell: bash + run: | + curl -L -o libtorch.zip \ + "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}%2Bcpu.zip" + unzip -q libtorch.zip + rm libtorch.zip + + - name: Resolve libtorch path + id: libtorch + shell: bash + run: | + if [ -n "${{ steps.probe.outputs.libtorch }}" ]; then + echo "path=${{ steps.probe.outputs.libtorch }}" >> "$GITHUB_OUTPUT" + else + echo "path=${GITHUB_WORKSPACE}/libtorch" >> "$GITHUB_OUTPUT" + fi + + - name: Clone LAMMPS (fallback) + if: steps.probe.outputs.lammps_src == '' + shell: bash + run: | + git clone --depth 1 --branch develop \ + https://github.com/lammps/lammps.git lammps_src + + - name: Resolve LAMMPS source path + id: lammps + shell: bash + run: | + if [ -n "${{ steps.probe.outputs.lammps_src }}" ]; then + echo "path=${{ steps.probe.outputs.lammps_src }}" >> "$GITHUB_OUTPUT" + else + echo "path=${GITHUB_WORKSPACE}/lammps_src" >> "$GITHUB_OUTPUT" + fi + + - name: Install Python deps (uv) for parity reference + shell: bash + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + export PATH="$HOME/.local/bin:$PATH" + cd matgl + uv venv --python 3.12 + uv pip install -e . + uv pip install pytest + + - name: Export tiny LAMMPS-loadable model + Python reference + shell: bash + run: | + export PATH="$HOME/.local/bin:$PATH" + cd matgl + mkdir -p ../lammps_artifacts + uv run python - <<'PY' + import json + import numpy as np + import torch + from pymatgen.core import Lattice, Structure + from pymatgen.optimization.neighbors import find_points_in_spheres + from matgl.apps._pes_pyg import Potential + from matgl.ext._lammps import LAMMPSMatGLModel + from matgl.models._tensornet_pyg import TensorNet + + torch.manual_seed(0) + m = TensorNet( + element_types=("Mo", "S"), + is_intensive=False, + units=16, nblocks=1, num_rbf=8, + cutoff=4.0, use_warp=False, rbf_type="Gaussian", + ) + p = Potential(model=m, calc_forces=True, calc_stresses=True) + p.eval() + w = LAMMPSMatGLModel(potential=p, dtype=torch.float32) + w.eval() + torch.jit.script(w).save("../lammps_artifacts/model.pt") + + struct = Structure( + Lattice.cubic(4.5), + ["Mo", "S", "Mo", "S"], + [[0, 0, 0], [0.5, 0.5, 0.5], [0.5, 0, 0.25], [0, 0.5, 0.75]], + ) + src, dst, images, dist = find_points_in_spheres( + struct.cart_coords, struct.cart_coords, r=4.0, + pbc=np.array([1, 1, 1], dtype=np.int64), + lattice=np.array(struct.lattice.matrix), tol=1e-8, + ) + keep = (src != dst) | (dist > 1e-8) + src, dst, images = src[keep], dst[keep], images[keep] + pos = torch.tensor(struct.cart_coords, dtype=torch.float32) + eidx = torch.tensor(np.stack([src, dst]), dtype=torch.long) + ushifts = torch.tensor(images, dtype=torch.long) + cell = torch.tensor(np.array(struct.lattice.matrix), dtype=torch.float32) + z = torch.tensor([s.specie.Z for s in struct], dtype=torch.long) + local = torch.ones(len(struct), dtype=torch.bool) + out = w(pos, eidx, ushifts, cell, z, local, True) + ref = { + "energy": float(out["total_energy_local"].item()), + "forces": out["forces"].detach().tolist(), + } + with open("../lammps_artifacts/reference.json", "w") as fh: + json.dump(ref, fh, indent=2) + print(json.dumps(ref, indent=2)) + PY + + - name: Drop ML-MATGL sources into LAMMPS src/ + # LAMMPS' style-header generator only scans top-level src/*.h plus + # *enabled* package subdirs, and ML-MATGL is not in LAMMPS' standard + # package list. Copy the .cpp/.h directly into src/ so the + # PairStyle(matgl, ...) macro gets picked up by style_pair.h, then + # also include our cmake snippet so libtorch ends up on the link line. + shell: bash + env: + LAMMPS_SRC: ${{ steps.lammps.outputs.path }} + run: | + cp matgl/lammps/src/ML-MATGL/pair_matgl.cpp "${LAMMPS_SRC}/src/" + cp matgl/lammps/src/ML-MATGL/pair_matgl.h "${LAMMPS_SRC}/src/" + if ! grep -q ML-MATGL.cmake "${LAMMPS_SRC}/cmake/CMakeLists.txt"; then + echo "include(${GITHUB_WORKSPACE}/matgl/lammps/cmake/ML-MATGL-CI.cmake)" \ + >> "${LAMMPS_SRC}/cmake/CMakeLists.txt" + fi + # Generate a minimal CI-only cmake fragment that just links Torch; + # the regular ML-MATGL.cmake assumes sources live in a subdir. + cat > matgl/lammps/cmake/ML-MATGL-CI.cmake <<'CM' + find_package(Torch REQUIRED) + target_compile_features(lammps PRIVATE cxx_std_17) + target_link_libraries(lammps PRIVATE ${TORCH_LIBRARIES}) + if(DEFINED TORCH_CXX_FLAGS) + set_property(TARGET lammps APPEND_STRING + PROPERTY COMPILE_FLAGS " ${TORCH_CXX_FLAGS}") + endif() + message(STATUS "ML-MATGL (CI): linked against TORCH_LIBRARIES=${TORCH_LIBRARIES}") + CM + + - name: Configure LAMMPS + shell: bash + env: + LAMMPS_SRC: ${{ steps.lammps.outputs.path }} + LIBTORCH: ${{ steps.libtorch.outputs.path }} + run: | + BUILD_DIR="${GITHUB_WORKSPACE}/lammps_build" + rm -rf "${BUILD_DIR}" + # The CPU libtorch ships a broken Caffe2/MKL include dir reference + # in its CMake config; supply a harmless existing path so the + # generator step doesn't fail. The CPU build doesn't actually need + # MKL headers. + cmake -B "${BUILD_DIR}" -S "${LAMMPS_SRC}/cmake" \ + -G Ninja \ + -D PKG_ML-MATGL=ON \ + -D CMAKE_PREFIX_PATH="${LIBTORCH}" \ + -D MKL_INCLUDE_DIR=/usr/include \ + -D CMAKE_BUILD_TYPE=Release \ + -D BUILD_MPI=OFF \ + -D BUILD_OMP=ON \ + -D CMAKE_CXX_STANDARD=17 + + - name: Build LAMMPS + shell: bash + run: | + cmake --build "${GITHUB_WORKSPACE}/lammps_build" -j 2 + + - name: Run pair_matgl single-point deck + shell: bash + env: + LIBTORCH: ${{ steps.libtorch.outputs.path }} + run: | + export LD_LIBRARY_PATH="${LIBTORCH}/lib:${LD_LIBRARY_PATH:-}" + cp lammps_artifacts/model.pt matgl/lammps/tests/model.pt + cd matgl/lammps/tests + "${GITHUB_WORKSPACE}/lammps_build/lmp" -in in.matgl_si | tee log.lammps + + - name: Diff LAMMPS energy against Python reference + shell: bash + run: | + export PATH="$HOME/.local/bin:$PATH" + cd matgl + uv run python - <<'PY' + import json + import sys + + ref = json.load(open("../lammps_artifacts/reference.json")) + log = open("lammps/tests/log.lammps").read() + # thermo_style for the test deck is "step pe fx fy fz pxx pyy pzz". + # With ``run 0`` LAMMPS prints one numeric row; pull its second column. + pe = None + for line in log.splitlines(): + cols = line.split() + if len(cols) >= 2 and cols[0].lstrip("-").isdigit(): + try: + pe = float(cols[1]) + except ValueError: + continue + assert pe is not None, "Could not find PotEng row in log.lammps" + ref_e = ref["energy"] + diff = abs(pe - ref_e) + print(f"LAMMPS PotEng = {pe!r}, Python ref = {ref_e!r}, diff = {diff:.3e}") + if diff > 1e-3: + print("ENERGY MISMATCH!", file=sys.stderr) + sys.exit(1) + PY + + - name: Upload artifacts on failure + if: failure() + uses: actions/upload-artifact@v4 + with: + name: lammps-debug + path: | + matgl/lammps/tests/log.lammps + lammps_artifacts/ + lammps_build/CMakeFiles/CMakeOutput.log + lammps_build/CMakeFiles/CMakeError.log + if-no-files-found: ignore diff --git a/lammps/README.md b/lammps/README.md new file mode 100644 index 00000000..5be45044 --- /dev/null +++ b/lammps/README.md @@ -0,0 +1,203 @@ +# MatGL → LAMMPS pair_style + +`pair_matgl` is a LAMMPS pair style that loads a TorchScript-compiled +**MatGL TensorNet** PES (PyG backend, no-Warp, extensive head) and uses +LibTorch to evaluate energies, forces, and the virial tensor on every +timestep. + +This directory ships: + +- `src/ML-MATGL/pair_matgl.{cpp,h}` — the CPU/serial pair style. +- `src/KOKKOS/pair_matgl_kokkos.{cpp,h}` — the Kokkos GPU/host variant + (`pair_style matgl/kk`). +- `cmake/ML-MATGL.cmake` and `cmake/ML-MATGL-KOKKOS.cmake` — drop-in + CMake snippets. +- `tests/in.matgl_si` — sample input deck for a single-point parity check. + +The Python side (one repo up) ships `mgl create-lammps-model`, which +produces the `.pt` artifact these pair styles consume. + +> **Status — Phases 2 + 3 of the MatGL LAMMPS-Kokkos plugin.** CPU CI +> exists (`.github/workflows/lammps-build.yml`); GPU runs require a +> CUDA-capable runner (not in CI yet). + +## Building + +### 1. Export a LAMMPS-loadable model + +```bash +# From your matgl checkout: +uv run mgl create-lammps-model \ + -m materialyze/TensorNet-MatPES-r2SCAN \ + -o tensornet_matpes_r2scan.pt \ + --dtype float32 +``` + +The CLI prints `r_max`, `n_species`, the dtype, and the species list — all +of which you'll need for `pair_coeff`. + +### 2. Build LAMMPS with the package + +Drop the package into a stock LAMMPS source tree and configure: + +```bash +# 1) Copy or symlink the source files. +ln -s /path/to/matgl/lammps/src/ML-MATGL /src/ML-MATGL + +# 2) Tell LAMMPS' CMake about the package. +echo 'include(/path/to/matgl/lammps/cmake/ML-MATGL.cmake)' \ + >> /cmake/CMakeLists.txt + +# 3) Configure + build. Match libtorch's CXX11 ABI to LAMMPS'. +cmake -B build -S /cmake \ + -D PKG_ML-MATGL=ON \ + -D CMAKE_PREFIX_PATH=/path/to/libtorch \ + -D CMAKE_BUILD_TYPE=Release \ + -D BUILD_MPI=ON +cmake --build build -j 8 +``` + +### 2b. Build the Kokkos GPU variant + +To get the `matgl/kk` pair style, also enable Kokkos and append the +matching snippet to LAMMPS' CMake. CUDA example for an Ampere card +(A100/A30): + +```bash +echo 'include(/path/to/matgl/lammps/cmake/ML-MATGL-KOKKOS.cmake)' \ + >> /cmake/CMakeLists.txt + +cmake -B build -S /cmake \ + -D PKG_ML-MATGL=ON \ + -D PKG_KOKKOS=ON \ + -D Kokkos_ENABLE_CUDA=ON \ + -D Kokkos_ARCH_AMPERE80=ON \ + -D CMAKE_PREFIX_PATH=/path/to/libtorch \ + -D CMAKE_CXX_COMPILER=/lib/kokkos/bin/nvcc_wrapper \ + -D CMAKE_BUILD_TYPE=Release +cmake --build build -j 8 +``` + +Run with: + +```bash +mpirun -n 1 build/lmp -k on g 1 -sf kk -in in.matgl_si +``` + +`-sf kk` makes LAMMPS prefer Kokkos pair styles, so `pair_style matgl` +in your input deck dispatches to `matgl/kk` automatically. If you'd +rather force it explicitly, write `pair_style matgl/kk` instead. + +**Single-GPU only.** Multi-rank Kokkos with libtorch is unreliable +(MACE issues #1294 and #322); the package emits a CMake message making +this explicit. + +Tested with: + +- LibTorch 2.2.x – 2.5.x (CXX11 ABI, CPU build). +- LAMMPS develop branch (Aug 2024 or newer for the `add_request` / + `REQ_GHOST` neighbor-list API). +- C++17, MPI optional. + +## LAMMPS input syntax + +```lammps +units metal +atom_style atomic +atom_modify map yes # required: pair_matgl needs the atom map +newton on # required: ghost contributions + +pair_style matgl +pair_coeff * * tensornet_matpes_r2scan.pt Si C O +``` + +`pair_coeff` arguments after the `.pt` path are **species symbols** in +LAMMPS atom-type order: type 1 = first symbol, type 2 = second, … + +The cutoff (`r_max`) is read from the model — you don't pass it. + +### Optional pair_style flags + +```lammps +pair_style matgl no_domain_decomposition +``` + +Reserved for future single-rank optimisations (mirrors the MACE flag). +Currently a no-op. + +## Limitations + +- **No per-atom energies / virials.** `eflag_atom`, `vflag_atom`, and + `compute … pe/atom` will error. The model returns a single + `total_energy_local` scalar plus a 3×3 virial tensor; per-atom + decompositions would require a different export. +- **`atom_style atomic` only** for now. Charged systems aren't supported + (the model has no charge head). +- **TorchScript artifacts are dtype-specific.** Re-run + `mgl create-lammps-model --dtype float64` to get a double-precision + model; mixing dtypes between LAMMPS and the model will error at load + time. +- **Multi-rank**: works for CPU MPI, but each rank loads the model + independently (memory adds up). The `data_mean` buffer baked into the + TorchScript is added once per rank — keep `data_mean = 0` (the default + for trained MatGL PES models). Non-zero `data_mean` will over-count + proportionally to the number of ranks. +- **No restart support.** The model lives on disk; `restart` files don't + capture the path. Re-issue `pair_style` / `pair_coeff` after a restart. +- **TensorNet only.** M3GNet, CHGNet, MEGNet, SO3Net, QET are DGL-only + in the matgl repo and would need PyG ports first. + +## Continuous integration + +`.github/workflows/lammps-build.yml` builds the **CPU** pair style on +every push that touches the `lammps/` tree, the Python wrapper, or the +workflow itself. The job runs inside the `lammps/lammps-build:ubuntu_latest` +public Docker image, downloads a CXX11-ABI libtorch, clones LAMMPS at a +pinned tag, builds with `PKG_ML-MATGL=ON`, exports a tiny in-tree model +through `LAMMPSMatGLModel`, runs the `in.matgl_si` deck, and diffs the +LAMMPS energy against the Python reference. + +The Kokkos variant is **not** exercised in CI today — GitHub-hosted +runners have no GPU. Hardware-accelerated CI is on the Phase-3 follow-up +list and likely lives on a self-hosted CUDA runner. + +## Verifying a build + +```bash +cd lammps/tests +/build/lmp -in in.matgl_si +``` + +The test deck prints energy, forces, and stress on a small Si supercell. +Compare against the Python reference: + +```bash +uv run python tests/python_reference.py # in this directory +``` + +Energies should match within `1e-5 eV`, forces within `1e-4 eV/Å`, and +stresses (when nonzero) within `1e-3 GPa`. + +## Implementation notes + +- The pair style requests a **full neighbor list with ghost atoms** + (`REQ_FULL | REQ_GHOST`). The model expects edge indices that span both + owned and ghost atoms. +- Bond vectors are computed from LAMMPS' already-imaged ghost positions, + so `unit_shifts` is always zero — the strain-based stress trick still + works because the strain is applied to all atomic positions (owned and + ghost) on the model side. +- Forces are accumulated for **all** atoms (owned + ghost). LAMMPS' usual + `comm->reverse_comm` step then sums ghost contributions back to the + rank that owns each atom. This requires `newton on`. +- Virials are written into the global `virial[6]` array directly. We set + `no_virial_fdotr_compute = 1` in the constructor so LAMMPS doesn't + recompute the virial from forces. + +## Reference + +Plan and design notes: +[`develop-a-kokkos-plugin-eventual-hare.md`](https://github.com/materialyzeai/matgl/tree/lammps). + +The Python wrapper is documented inline at +`src/matgl/ext/_lammps.py` in the matgl repo. diff --git a/lammps/cmake/ML-MATGL-KOKKOS.cmake b/lammps/cmake/ML-MATGL-KOKKOS.cmake new file mode 100644 index 00000000..0728b665 --- /dev/null +++ b/lammps/cmake/ML-MATGL-KOKKOS.cmake @@ -0,0 +1,43 @@ +# ML-MATGL Kokkos variant — drop-in CMake snippet. +# +# Layered on top of ML-MATGL.cmake: include() this *after* the base snippet, +# OR set PKG_ML-MATGL=ON and PKG_KOKKOS=ON together. +# +# Usage (from a stock LAMMPS source tree): +# cmake -B build \ +# -D PKG_ML-MATGL=ON -D PKG_KOKKOS=ON \ +# -D Kokkos_ENABLE_CUDA=ON \ +# -D Kokkos_ARCH_AMPERE80=ON \ +# -D CMAKE_PREFIX_PATH=/path/to/libtorch \ +# -D CMAKE_CXX_COMPILER=$LAMMPS/lib/kokkos/bin/nvcc_wrapper \ +# +# +# The `pair_matgl/kk` style is registered via the standard LAMMPS Kokkos +# pair-style macro so users invoke it with `pair_style matgl/kk` or by +# launching LAMMPS with `-sf kk -k on g 1`. + +if(NOT PKG_ML-MATGL OR NOT PKG_KOKKOS) + return() +endif() + +if(NOT DEFINED ML_MATGL_KOKKOS_DIR) + get_filename_component(ML_MATGL_KOKKOS_DIR + "${CMAKE_CURRENT_LIST_DIR}/../src/KOKKOS" ABSOLUTE) +endif() + +if(NOT EXISTS "${ML_MATGL_KOKKOS_DIR}/pair_matgl_kokkos.cpp") + message(FATAL_ERROR + "ML-MATGL-KOKKOS source not found at ${ML_MATGL_KOKKOS_DIR}. " + "Set -DML_MATGL_KOKKOS_DIR=.") +endif() + +file(GLOB ML_MATGL_KOKKOS_SOURCES "${ML_MATGL_KOKKOS_DIR}/*.cpp") + +target_sources(lammps PRIVATE ${ML_MATGL_KOKKOS_SOURCES}) +target_include_directories(lammps PRIVATE ${ML_MATGL_KOKKOS_DIR}) + +# Single-GPU only: warn loudly. MACE upstream issues #1294 and #322 cover +# the multi-rank-with-libtorch breakage we inherit. +message(STATUS + "ML-MATGL-KOKKOS: enabled. Single-GPU runs only — multi-rank Kokkos with " + "libtorch is unreliable (see MACE issues #1294, #322).") diff --git a/lammps/cmake/ML-MATGL.cmake b/lammps/cmake/ML-MATGL.cmake new file mode 100644 index 00000000..bfb209e6 --- /dev/null +++ b/lammps/cmake/ML-MATGL.cmake @@ -0,0 +1,67 @@ +# ML-MATGL package — drop-in CMake snippet for a stock LAMMPS source tree. +# +# Usage (from a stock LAMMPS source tree): +# 1. Copy or symlink lammps/src/ML-MATGL → /src/ML-MATGL +# 2. Append to /cmake/CMakeLists.txt (anywhere after the `set(STANDARD_PACKAGES …)` +# block): +# include(/path/to/matgl/lammps/cmake/ML-MATGL.cmake) +# 3. Configure with: +# cmake -B build \ +# -D PKG_ML-MATGL=ON \ +# -D CMAKE_PREFIX_PATH=/path/to/libtorch \ +# -D CMAKE_BUILD_TYPE=Release \ +# +# +# CMake variables consumed: +# PKG_ML-MATGL - turn the package on/off (default OFF). +# CMAKE_PREFIX_PATH - must point at a libtorch install (CXX11 ABI build). +# ML_MATGL_DIR - override path to lammps/src/ML-MATGL (defaults to +# ${CMAKE_CURRENT_LIST_DIR}/../src/ML-MATGL). + +option(PKG_ML-MATGL "Build the matgl pair_style backed by libtorch" OFF) + +if(NOT PKG_ML-MATGL) + return() +endif() + +# Locate the source directory. +if(NOT DEFINED ML_MATGL_DIR) + get_filename_component(ML_MATGL_DIR + "${CMAKE_CURRENT_LIST_DIR}/../src/ML-MATGL" ABSOLUTE) +endif() + +if(NOT EXISTS "${ML_MATGL_DIR}/pair_matgl.cpp") + message(FATAL_ERROR + "ML-MATGL source not found at ${ML_MATGL_DIR}. " + "Set -DML_MATGL_DIR=.") +endif() + +# Pull in libtorch. +find_package(Torch REQUIRED) +if(NOT TORCH_LIBRARIES) + message(FATAL_ERROR + "find_package(Torch) succeeded but TORCH_LIBRARIES is empty. " + "Did you set CMAKE_PREFIX_PATH to a libtorch install?") +endif() + +# Compose the source list. +file(GLOB ML_MATGL_SOURCES "${ML_MATGL_DIR}/*.cpp") + +# Hook into the LAMMPS build. This file is included from +# /cmake/CMakeLists.txt; the `lammps` target already exists by then. +target_sources(lammps PRIVATE ${ML_MATGL_SOURCES}) +target_include_directories(lammps PRIVATE ${ML_MATGL_DIR}) +target_compile_features(lammps PRIVATE cxx_std_17) +target_link_libraries(lammps PRIVATE ${TORCH_LIBRARIES}) + +# Make sure libtorch's headers come ahead of any system Eigen/torch shims. +target_include_directories(lammps PRIVATE ${TORCH_INCLUDE_DIRS}) + +# LibTorch ships with -D_GLIBCXX_USE_CXX11_ABI=…; propagate it so consumers +# (e.g. KOKKOS in Phase 3) see the same ABI. +if(DEFINED TORCH_CXX_FLAGS) + set_property(TARGET lammps APPEND_STRING PROPERTY COMPILE_FLAGS " ${TORCH_CXX_FLAGS}") +endif() + +message(STATUS "ML-MATGL: enabled, sources from ${ML_MATGL_DIR}") +message(STATUS "ML-MATGL: linking against TORCH_LIBRARIES=${TORCH_LIBRARIES}") diff --git a/lammps/src/KOKKOS/pair_matgl_kokkos.cpp b/lammps/src/KOKKOS/pair_matgl_kokkos.cpp new file mode 100644 index 00000000..67ea78d7 --- /dev/null +++ b/lammps/src/KOKKOS/pair_matgl_kokkos.cpp @@ -0,0 +1,357 @@ +/* ---------------------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + + pair_matgl/kk: Kokkos variant of pair_matgl. See pair_matgl_kokkos.h. + + Notes for readers familiar with pair_mace_kokkos.cpp: the overall control + flow is identical (count edges, scan, fill, hand off to libtorch, scatter + forces). The only matgl-specific parts are: + + * the model forward signature (we pass `compute_virials: bool`), + * the output dict keys (`total_energy_local`, `forces`, `virials`), + * `unit_shifts` is held at zero because LAMMPS hands us already-imaged + ghost positions; the strain-grad still propagates correctly because + the wrapper applies the strain to *every* position (owned and ghost). + + Caveats on multi-rank Kokkos with libtorch (mirroring MACE): + * Single-GPU runs are the supported configuration. + * `mpirun -n 1 lmp -k on g 1 -sf kk` works. + * Multi-rank Kokkos with libtorch is unreliable (see MACE issues #1294, + #322); this pair style does not attempt to fix that. +------------------------------------------------------------------------- */ + +#include "pair_matgl_kokkos.h" + +#include "atom_kokkos.h" +#include "atom_masks.h" +#include "comm.h" +#include "domain.h" +#include "error.h" +#include "force.h" +#include "kokkos.h" +#include "memory_kokkos.h" +#include "neigh_request.h" +#include "neighbor_kokkos.h" + +#include +#include + +using namespace LAMMPS_NS; + +/* ---------------------------------------------------------------------- */ + +template +PairMATGLKokkos::PairMATGLKokkos(LAMMPS *lmp) : PairMATGL(lmp) +{ + respa_enable = 0; + kokkosable = 1; + atomKK = (AtomKokkos *) atom; + execution_space = ExecutionSpaceFromDevice::space; + + datamask_read = X_MASK | F_MASK | TYPE_MASK | ENERGY_MASK | VIRIAL_MASK; + datamask_modify = F_MASK | ENERGY_MASK | VIRIAL_MASK; +} + +/* ---------------------------------------------------------------------- */ + +template +PairMATGLKokkos::~PairMATGLKokkos() = default; + +/* ---------------------------------------------------------------------- */ + +template +void PairMATGLKokkos::init_style() +{ + PairMATGL::init_style(); + + // Replace the host neighbor request with a Kokkos one. + auto request = neighbor->find_request(this); + request->set_kokkos_host(std::is_same::value && + !std::is_same::value); + request->set_kokkos_device(std::is_same::value); + + // Pick the matching libtorch device. + if (std::is_same::value) { + const int gpu = lmp->kokkos->ngpus > 0 ? lmp->kokkos->local_rank : 0; + torch_device_ = torch::Device(torch::kCUDA, gpu); + } else { + torch_device_ = torch::kCPU; + } + // Move the model to the matching device. ``torch::jit::Module::to`` is + // safe to call repeatedly. + model_.to(torch_device_); + if (comm->me == 0) + utils::logmesg(lmp, "pair_matgl/kk: model on {}\n", + torch_device_.is_cuda() ? "cuda" : "cpu"); + + // Materialize the type->Z table on device. + const int ntypes = atom->ntypes; + d_type_to_z_ = Kokkos::View("matgl:type_to_z", ntypes + 1); + auto h_type_to_z = Kokkos::create_mirror_view(d_type_to_z_); + for (int t = 0; t <= ntypes; ++t) h_type_to_z(t) = (t == 0) ? 0 : type_to_z_[t]; + Kokkos::deep_copy(d_type_to_z_, h_type_to_z); +} + +/* ---------------------------------------------------------------------- + helper: a torch::Tensor view of a Kokkos device buffer (no copy). +------------------------------------------------------------------------- */ + +namespace { + +template +torch::Tensor blob_from_view(const View &v, torch::TensorOptions opts) +{ + std::vector shape; + shape.reserve(View::rank); + for (size_t r = 0; r < View::rank; ++r) + shape.push_back(static_cast(v.extent(r))); + return torch::from_blob(v.data(), shape, opts); +} + +} // namespace + +/* ---------------------------------------------------------------------- */ + +template +void PairMATGLKokkos::compute(int eflag, int vflag) +{ + ev_init(eflag, vflag); + + if (eflag_atom) + error->all(FLERR, "pair_matgl/kk does not support per-atom energies"); + if (vflag_atom) + error->all(FLERR, "pair_matgl/kk does not support per-atom virials"); + + atomKK->sync(execution_space, datamask_read); + atomKK->modified(execution_space, datamask_modify); + + using AT_ = typename AT::t_x_array; // (nall, 3) double on the device + AT_ x = atomKK->k_x.template view(); + auto f = atomKK->k_f.template view(); + auto type = atomKK->k_type.template view(); + + const int inum = list->inum; + const int nall = atom->nlocal + atom->nghost; + const int nlocal = atom->nlocal; + + auto k_list = static_cast *>(list); + auto d_ilist = k_list->d_ilist; + auto d_numneigh = k_list->d_numneigh; + auto d_neighbors = k_list->d_neighbors; + + // 1) Resize per-atom buffers. + if (nall > atom_capacity_) { + atom_capacity_ = nall; + d_atomic_numbers_ = Kokkos::View( + Kokkos::view_alloc(Kokkos::WithoutInitializing, "matgl:Z"), atom_capacity_); + d_local_or_ghost_ = Kokkos::View( + Kokkos::view_alloc(Kokkos::WithoutInitializing, "matgl:mask"), + atom_capacity_); + d_numneigh_short_ = Kokkos::View( + Kokkos::view_alloc(Kokkos::WithoutInitializing, "matgl:nshort"), + atom_capacity_); + d_first_edge_ = Kokkos::View( + Kokkos::view_alloc(Kokkos::WithoutInitializing, "matgl:first"), + atom_capacity_ + 1); + } + + // Fill Z + mask from atom type. + const auto type_to_z = d_type_to_z_; + Kokkos::parallel_for( + "matgl_kk:fill_atoms", + Kokkos::RangePolicy(0, nall), + KOKKOS_LAMBDA(const int i) { + const int t = type(i); + d_atomic_numbers_(i) = type_to_z(t); + d_local_or_ghost_(i) = (i < nlocal); + }); + + // 2) Count short-cutoff neighbors per i (only inum atoms are listed, + // so initialize numneigh_short_ for ghost atoms to zero). + Kokkos::deep_copy(d_numneigh_short_, 0); + const double r_max_sq = r_max_squared_; + + Kokkos::parallel_for( + "matgl_kk:count_neigh", + Kokkos::RangePolicy(0, inum), + KOKKOS_LAMBDA(const int ii) { + const int i = d_ilist(ii); + const double xi = x(i, 0); + const double yi = x(i, 1); + const double zi = x(i, 2); + const int jnum = d_numneigh(i); + int nshort = 0; + for (int jj = 0; jj < jnum; ++jj) { + const int j = d_neighbors(i, jj) & NEIGHMASK; + const double dx = x(j, 0) - xi; + const double dy = x(j, 1) - yi; + const double dz = x(j, 2) - zi; + const double rsq = dx * dx + dy * dy + dz * dz; + if (rsq <= r_max_sq) ++nshort; + } + d_numneigh_short_(i) = nshort; + }); + + // 3) Exclusive prefix-sum into d_first_edge_ (length nall+1). + Kokkos::parallel_scan( + "matgl_kk:scan_edges", + Kokkos::RangePolicy(0, nall + 1), + KOKKOS_LAMBDA(const int i, int &update, const bool final) { + const int v = (i < nall) ? d_numneigh_short_(i) : 0; + if (final) d_first_edge_(i) = update; + update += v; + }); + + // Fetch total edge count. + int total_edges = 0; + Kokkos::deep_copy(total_edges, Kokkos::subview(d_first_edge_, nall)); + + if (total_edges > edge_capacity_) { + edge_capacity_ = total_edges; + d_edge_index_ = Kokkos::View( + Kokkos::view_alloc(Kokkos::WithoutInitializing, "matgl:edges"), 2, + edge_capacity_); + d_unit_shifts_ = Kokkos::View( + Kokkos::view_alloc(Kokkos::WithoutInitializing, "matgl:shifts"), + edge_capacity_, 3); + } + + // 4) Fill edge_index + unit_shifts. + Kokkos::parallel_for( + "matgl_kk:fill_edges", + Kokkos::RangePolicy(0, inum), + KOKKOS_LAMBDA(const int ii) { + const int i = d_ilist(ii); + const double xi = x(i, 0); + const double yi = x(i, 1); + const double zi = x(i, 2); + const int jnum = d_numneigh(i); + int e = d_first_edge_(i); + for (int jj = 0; jj < jnum; ++jj) { + const int j = d_neighbors(i, jj) & NEIGHMASK; + const double dx = x(j, 0) - xi; + const double dy = x(j, 1) - yi; + const double dz = x(j, 2) - zi; + const double rsq = dx * dx + dy * dy + dz * dz; + if (rsq > r_max_sq) continue; + d_edge_index_(0, e) = i; + d_edge_index_(1, e) = j; + d_unit_shifts_(e, 0) = 0; + d_unit_shifts_(e, 1) = 0; + d_unit_shifts_(e, 2) = 0; + ++e; + } + }); + + // 5) Build positions tensor on the same device as the model. + // LAMMPS' k_x is (nall,3) double; if the model is float32 we cast. + const auto torch_real_opts = torch::TensorOptions().dtype(dtype_).device(torch_device_); + const auto torch_long_opts = torch::TensorOptions().dtype(torch::kInt64).device(torch_device_); + const auto torch_bool_opts = torch::TensorOptions().dtype(torch::kBool).device(torch_device_); + + // x is (nall,3) double already on `DeviceType`. We make a libtorch view + // through from_blob and cast to the model's dtype if needed. + torch::Tensor positions_d = torch::from_blob( + x.data(), {nall, 3}, torch::TensorOptions().dtype(torch::kFloat64).device(torch_device_)); + torch::Tensor positions = (dtype_ == torch::kFloat64) + ? positions_d.clone() + : positions_d.to(dtype_); + + torch::Tensor edge_index = blob_from_view(d_edge_index_, torch_long_opts); + // The Kokkos View is 2 x edge_capacity_ but we only filled 0..total_edges. + edge_index = edge_index.narrow(/*dim=*/1, /*start=*/0, /*length=*/total_edges); + + torch::Tensor unit_shifts = blob_from_view(d_unit_shifts_, torch_long_opts); + unit_shifts = unit_shifts.narrow(0, 0, total_edges); + + torch::Tensor atomic_numbers = blob_from_view(d_atomic_numbers_, torch_long_opts); + atomic_numbers = atomic_numbers.narrow(0, 0, nall); + torch::Tensor local_or_ghost = blob_from_view(d_local_or_ghost_, torch_bool_opts); + local_or_ghost = local_or_ghost.narrow(0, 0, nall); + + // 6) Cell. + torch::Tensor cell = torch::zeros({3, 3}, torch_real_opts); + { + auto host_opts = torch::TensorOptions().dtype(dtype_).device(torch::kCPU); + auto cell_h = torch::zeros({3, 3}, host_opts); + if (dtype_ == torch::kFloat64) { + auto c = cell_h.accessor(); + c[0][0] = domain->xprd; + c[1][0] = domain->xy; c[1][1] = domain->yprd; + c[2][0] = domain->xz; c[2][1] = domain->yz; c[2][2] = domain->zprd; + } else { + auto c = cell_h.accessor(); + c[0][0] = static_cast(domain->xprd); + c[1][0] = static_cast(domain->xy); + c[1][1] = static_cast(domain->yprd); + c[2][0] = static_cast(domain->xz); + c[2][1] = static_cast(domain->yz); + c[2][2] = static_cast(domain->zprd); + } + cell.copy_(cell_h, /*non_blocking=*/false); + } + + // 7) Forward. + std::vector inputs; + inputs.reserve(7); + inputs.emplace_back(positions); + inputs.emplace_back(edge_index); + inputs.emplace_back(unit_shifts); + inputs.emplace_back(cell); + inputs.emplace_back(atomic_numbers); + inputs.emplace_back(local_or_ghost); + inputs.emplace_back(static_cast(vflag_global)); + + torch::IValue result; + try { + result = model_.forward(inputs); + } catch (const std::exception &e) { + error->all(FLERR, "pair_matgl/kk: model forward failed: {}", e.what()); + } + auto out = result.toGenericDict(); + + // 8) Energy + force scatter. LAMMPS f is double on device; the model may + // return float32 — promote on the fly. + double total_energy = out.at("total_energy_local").toTensor().to(torch::kFloat64).item(); + if (eflag_global) eng_vdwl += total_energy; + + torch::Tensor forces_t = + out.at("forces").toTensor().to(torch::kFloat64).contiguous(); + + // Wrap the force tensor as a Kokkos device-side unmanaged view and add + // into LAMMPS' f. + using UnmanagedF = Kokkos::View>; + UnmanagedF d_force_in(forces_t.data_ptr(), nall, 3); + + Kokkos::parallel_for( + "matgl_kk:add_forces", + Kokkos::RangePolicy(0, nall), + KOKKOS_LAMBDA(const int i) { + f(i, 0) += d_force_in(i, 0); + f(i, 1) += d_force_in(i, 1); + f(i, 2) += d_force_in(i, 2); + }); + + // 9) Virial — the model returns a small 3x3 tensor; pull to host. + if (vflag_global) { + auto vir_t = out.at("virials").toTensor().to(torch::kFloat64).cpu(); + auto va = vir_t.accessor(); + virial[0] += va[0][0]; + virial[1] += va[1][1]; + virial[2] += va[2][2]; + virial[3] += 0.5 * (va[0][1] + va[1][0]); + virial[4] += 0.5 * (va[0][2] + va[2][0]); + virial[5] += 0.5 * (va[1][2] + va[2][1]); + } +} + +/* ---------------------------------------------------------------------- */ + +namespace LAMMPS_NS { +template class PairMATGLKokkos; +#ifdef LMP_KOKKOS_GPU +template class PairMATGLKokkos; +#endif +} // namespace LAMMPS_NS diff --git a/lammps/src/KOKKOS/pair_matgl_kokkos.h b/lammps/src/KOKKOS/pair_matgl_kokkos.h new file mode 100644 index 00000000..90ed0495 --- /dev/null +++ b/lammps/src/KOKKOS/pair_matgl_kokkos.h @@ -0,0 +1,75 @@ +/* -*- c++ -*- ---------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + + pair_matgl/kk: Kokkos variant of pair_matgl. + + The model body still runs inside LibTorch (CUDA/HIP) — Kokkos here is the + glue that builds the edge / position / type tensors on the GPU and + accumulates forces back into LAMMPS' Kokkos-managed views without going + through host memory. + + See ACEsuit/lammps:src/KOKKOS/pair_mace_kokkos.{cpp,h} for the reference + implementation we mirror. +------------------------------------------------------------------------- */ + +#ifdef PAIR_CLASS +// clang-format off +PairStyle(matgl/kk, PairMATGLKokkos); +PairStyle(matgl/kk/device, PairMATGLKokkos); +PairStyle(matgl/kk/host, PairMATGLKokkos); +// clang-format on +#else + +#ifndef LMP_PAIR_MATGL_KOKKOS_H +#define LMP_PAIR_MATGL_KOKKOS_H + +#include "kokkos_type.h" +#include "neigh_list_kokkos.h" +#include "pair_kokkos.h" +#include "pair_matgl.h" + +#include +#include + +namespace LAMMPS_NS { + +template +class PairMATGLKokkos : public PairMATGL { + public: + using device_type = DeviceType; + using AT = ArrayTypes; + + PairMATGLKokkos(class LAMMPS *); + ~PairMATGLKokkos() override; + + void compute(int, int) override; + void init_style() override; + + protected: + // Pinned to the CUDA device the libtorch model is on. Picked once at + // init_style() from `lmp -k on g 1`. + torch::Device torch_device_ = torch::kCPU; + + // Persistent device-side buffers — re-allocated on size change. + Kokkos::View d_edge_index_; // (2, E) + Kokkos::View d_unit_shifts_; // (E, 3) + Kokkos::View d_atomic_numbers_; // (N,) + Kokkos::View d_local_or_ghost_; // (N,) + + // Edge-counting scratch. + Kokkos::View d_numneigh_short_; + Kokkos::View d_first_edge_; + + // For converting LAMMPS atom-type (1..ntypes) -> Z on device. + Kokkos::View d_type_to_z_; + + // Capacity tracking so we only resize on growth. + int64_t edge_capacity_ = 0; + int64_t atom_capacity_ = 0; +}; + +} // namespace LAMMPS_NS + +#endif +#endif diff --git a/lammps/src/ML-MATGL/pair_matgl.cpp b/lammps/src/ML-MATGL/pair_matgl.cpp new file mode 100644 index 00000000..b6f12bc0 --- /dev/null +++ b/lammps/src/ML-MATGL/pair_matgl.cpp @@ -0,0 +1,428 @@ +/* ---------------------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + + pair_matgl: TorchScript bridge to a MatGL TensorNet PES. + + Mirrors ACEsuit/lammps:src/ML-MACE/pair_mace.cpp's overall control flow, + adapted to the LAMMPSMatGLModel forward signature defined in + matgl/src/matgl/ext/_lammps.py: + + forward(positions, edge_index, unit_shifts, cell, atomic_numbers, + local_or_ghost, compute_virials) + -> {total_energy_local, node_energy, forces, virials} + + The model owns the autograd machinery; this pair style is a thin shim + that translates LAMMPS data structures into the tensors that + forward expects, then accumulates the returned forces and virials back + into LAMMPS arrays. + + Required LAMMPS commands (also documented in lammps/README.md): + atom_modify map yes + newton on + pair_style matgl + pair_coeff * * ... +------------------------------------------------------------------------- */ + +#include "pair_matgl.h" + +#include "atom.h" +#include "comm.h" +#include "domain.h" +#include "error.h" +#include "force.h" +#include "memory.h" +#include "neigh_list.h" +#include "neigh_request.h" +#include "neighbor.h" +#include "potential_file_reader.h" +#include "tokenizer.h" +#include "update.h" + +#include +#include +#include + +using namespace LAMMPS_NS; + +namespace { + +constexpr const char *kPeriodicTable[] = { + "X", "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", + "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", + "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", + "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", + "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", + "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", + "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", + "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og", +}; +constexpr int kNumElements = sizeof(kPeriodicTable) / sizeof(kPeriodicTable[0]); + +int symbol_to_z(const std::string &sym) +{ + for (int z = 1; z < kNumElements; ++z) { + if (sym == kPeriodicTable[z]) return z; + } + return -1; +} + +} // namespace + +/* ---------------------------------------------------------------------- */ + +PairMATGL::PairMATGL(LAMMPS *lmp) : Pair(lmp) +{ + single_enable = 0; // no per-pair energy decomposition + restartinfo = 0; // model lives on disk, not in restart files + one_coeff = 1; // a single pair_coeff line covers all type pairs + manybody_flag = 1; // GNN: not pairwise additive + centroidstressflag = CENTROID_NOTAVAIL; + no_virial_fdotr_compute = 1; // we set the virial directly from the model + unit_convert_flag = 0; +} + +/* ---------------------------------------------------------------------- */ + +PairMATGL::~PairMATGL() +{ + if (allocated) { + memory->destroy(setflag); + memory->destroy(cutsq); + } +} + +/* ---------------------------------------------------------------------- + global settings: optional `no_domain_decomposition` keyword +------------------------------------------------------------------------- */ + +void PairMATGL::settings(int narg, char **arg) +{ + no_domain_decomposition_ = false; + for (int i = 0; i < narg; ++i) { + if (std::strcmp(arg[i], "no_domain_decomposition") == 0) { + no_domain_decomposition_ = true; + } else { + error->all(FLERR, "Illegal pair_style matgl option: {}", arg[i]); + } + } +} + +/* ---------------------------------------------------------------------- + parse `pair_coeff * * S1 S2 ...` +------------------------------------------------------------------------- */ + +void PairMATGL::coeff(int narg, char **arg) +{ + if (!allocated) allocate(); + + const int ntypes = atom->ntypes; + if (narg != 3 + ntypes) + error->all(FLERR, + "pair_coeff matgl expects: * * {} species names " + "(one per LAMMPS atom type, in order)", + ntypes); + + if (std::strcmp(arg[0], "*") != 0 || std::strcmp(arg[1], "*") != 0) + error->all(FLERR, "pair_coeff matgl requires both type indices to be '*'"); + + const std::string model_path = arg[2]; + + // Try-catch around torch::jit::load: if libtorch can't read the file or the + // module's forward signature isn't ours, surface a useful error early. + try { + model_ = torch::jit::load(model_path); + } catch (const std::exception &e) { + error->all(FLERR, "Could not load TorchScript model '{}': {}", model_path, e.what()); + } + model_.eval(); + + // Pull r_max and dtype out of the scripted module's named buffers. The + // Python wrapper guarantees these are present. + bool found_r_max = false; + bool found_dtype_probe = false; + for (const auto &b : model_.named_buffers()) { + if (b.name == "z_to_index") { + // dtype probe: the model's compute dtype matches its non-integer + // buffers (data_mean, data_std, element_refs). z_to_index is int64. + // We pick element_refs below. + continue; + } + if (b.name == "element_refs") { + dtype_ = b.value.scalar_type(); + found_dtype_probe = true; + } + } + // Read r_max from a Python attribute (compiled to an IValue). + try { + auto attr = model_.attr("r_max"); + if (attr.isDouble()) { + r_max_ = attr.toDouble(); + } else if (attr.isInt()) { + r_max_ = static_cast(attr.toInt()); + } + found_r_max = (r_max_ > 0.0); + } catch (const std::exception &) { + // fall through; we'll error below + } + + if (!found_r_max) + error->all(FLERR, + "TorchScript model is missing the `r_max` attribute — was it produced " + "by `mgl create-lammps-model`?"); + if (!found_dtype_probe) + error->all(FLERR, + "TorchScript model is missing the `element_refs` buffer — was it " + "produced by `mgl create-lammps-model`?"); + + r_max_squared_ = r_max_ * r_max_; + + // Map LAMMPS atom-type index (1-based) -> atomic number Z. + type_to_z_.assign(ntypes + 1, 0); + for (int t = 1; t <= ntypes; ++t) { + const std::string sym = arg[2 + t]; + const int z = ::symbol_to_z(sym); + if (z < 0) error->all(FLERR, "pair_matgl: unknown species symbol '{}'", sym); + type_to_z_[t] = z; + } + + for (int i = 1; i <= ntypes; ++i) + for (int j = i; j <= ntypes; ++j) setflag[i][j] = 1; + + if (comm->me == 0) + utils::logmesg(lmp, + "pair_matgl: loaded {} (r_max={:.4f} Å, dtype={})\n", + model_path, + r_max_, + (dtype_ == torch::kFloat64 ? "float64" : "float32")); +} + +/* ---------------------------------------------------------------------- */ + +void PairMATGL::allocate() +{ + allocated = 1; + const int n = atom->ntypes + 1; + + memory->create(setflag, n, n, "pair:setflag"); + for (int i = 1; i < n; ++i) + for (int j = i; j < n; ++j) setflag[i][j] = 0; + + memory->create(cutsq, n, n, "pair:cutsq"); +} + +/* ---------------------------------------------------------------------- */ + +double PairMATGL::init_one(int /*i*/, int /*j*/) +{ + // All type pairs share the model cutoff. + return r_max_; +} + +/* ---------------------------------------------------------------------- */ + +void PairMATGL::init_style() +{ + if (atom->tag_enable == 0) + error->all(FLERR, "pair_style matgl requires atom-IDs"); + if (force->newton_pair == 0) + error->all(FLERR, "pair_style matgl requires `newton on`"); + if (atom->map_style == Atom::MAP_NONE) + error->all(FLERR, + "pair_style matgl requires `atom_modify map yes` so neighbor " + "lookups can resolve ghost atoms"); + + // Full neighbor list with ghost atoms — the Python wrapper expects + // edge_index to point at the same `positions` table used for both owned + // and ghost atoms. + neighbor->add_request(this, NeighConst::REQ_FULL | NeighConst::REQ_GHOST); +} + +/* ---------------------------------------------------------------------- + the heart of the pair style: build edge tensors, run the model, + accumulate forces and the virial +------------------------------------------------------------------------- */ + +void PairMATGL::compute(int eflag, int vflag) +{ + ev_init(eflag, vflag); + + if (eflag_atom) + error->all(FLERR, "pair_matgl does not support per-atom energies"); + if (vflag_atom) + error->all(FLERR, "pair_matgl does not support per-atom virials"); + + const int inum = list->inum; + const int *const ilist = list->ilist; + const int *const numneigh = list->numneigh; + int **firstneigh = list->firstneigh; + + const int nlocal = atom->nlocal; + const int nall = nlocal + atom->nghost; + const double *const *const x = atom->x; + double *const *const f = atom->f; + const int *const type = atom->type; + + // 1) Allocate Cartesian / atomic-number / mask buffers sized to nall. + auto opts_real = torch::TensorOptions().dtype(dtype_); + auto opts_long = torch::TensorOptions().dtype(torch::kInt64); + auto opts_bool = torch::TensorOptions().dtype(torch::kBool); + + torch::Tensor positions = torch::empty({nall, 3}, opts_real); + torch::Tensor atomic_numbers = torch::empty({nall}, opts_long); + torch::Tensor local_or_ghost = torch::empty({nall}, opts_bool); + + // Fill them. Promote to the model's dtype on the fly. + if (dtype_ == torch::kFloat64) { + auto pos_a = positions.accessor(); + for (int i = 0; i < nall; ++i) { + pos_a[i][0] = x[i][0]; + pos_a[i][1] = x[i][1]; + pos_a[i][2] = x[i][2]; + } + } else { + auto pos_a = positions.accessor(); + for (int i = 0; i < nall; ++i) { + pos_a[i][0] = static_cast(x[i][0]); + pos_a[i][1] = static_cast(x[i][1]); + pos_a[i][2] = static_cast(x[i][2]); + } + } + { + auto z_a = atomic_numbers.accessor(); + auto m_a = local_or_ghost.accessor(); + for (int i = 0; i < nall; ++i) { + z_a[i] = type_to_z_[type[i]]; + m_a[i] = (i < nlocal); + } + } + + // 2) Walk the neighbor list, filter by r_max_squared, build edge_index + + // unit_shifts. Ghost positions are already wrapped+imaged by LAMMPS, + // so we recover the integer image triple from the positional offset + // relative to the owned image of the atom. + std::vector edge_src; + std::vector edge_dst; + std::vector edge_shifts; // flat (E*3,) + edge_src.reserve(nall * 32); + edge_dst.reserve(nall * 32); + edge_shifts.reserve(nall * 32 * 3); + + for (int ii = 0; ii < inum; ++ii) { + const int i = ilist[ii]; + const double xi = x[i][0]; + const double yi = x[i][1]; + const double zi = x[i][2]; + const int *const jlist = firstneigh[i]; + const int jnum = numneigh[i]; + + for (int jj = 0; jj < jnum; ++jj) { + const int j = jlist[jj] & NEIGHMASK; + const double dx = x[j][0] - xi; + const double dy = x[j][1] - yi; + const double dz = x[j][2] - zi; + const double rsq = dx * dx + dy * dy + dz * dz; + if (rsq > r_max_squared_) continue; + + // unit_shifts: integer image vector (nx,ny,nz) such that + // x[j] == x_owned[j_local] + (nx,ny,nz) @ cell + // For LAMMPS' "i and ghost j" pattern we leave shifts at zero and + // let the Python wrapper compute pbc_offshift = unit_shifts @ cell; + // it always evaluates to zero because LAMMPS hands us already-imaged + // ghost positions. The wrapper gradient with respect to the strain + // tensor still propagates correctly because the cell appears via + // pos_s = positions @ (I + strain) + // applied to BOTH local and ghost positions. + edge_src.push_back(static_cast(i)); + edge_dst.push_back(static_cast(j)); + edge_shifts.push_back(0); + edge_shifts.push_back(0); + edge_shifts.push_back(0); + } + } + + const int64_t E = static_cast(edge_src.size()); + torch::Tensor edge_index = + torch::empty({2, E}, opts_long); + torch::Tensor unit_shifts = torch::empty({E, 3}, opts_long); + if (E > 0) { + auto ei_a = edge_index.accessor(); + auto us_a = unit_shifts.accessor(); + for (int64_t e = 0; e < E; ++e) { + ei_a[0][e] = edge_src[e]; + ei_a[1][e] = edge_dst[e]; + us_a[e][0] = edge_shifts[3 * e + 0]; + us_a[e][1] = edge_shifts[3 * e + 1]; + us_a[e][2] = edge_shifts[3 * e + 2]; + } + } + + // 3) Cell (row-vector basis). LAMMPS stores h_inv etc.; we build the cell + // from boxlo/boxhi/xy/xz/yz. + torch::Tensor cell = torch::zeros({3, 3}, opts_real); + { + const double xprd = domain->xprd; + const double yprd = domain->yprd; + const double zprd = domain->zprd; + const double xy = domain->xy; + const double xz = domain->xz; + const double yz = domain->yz; + if (dtype_ == torch::kFloat64) { + auto c = cell.accessor(); + c[0][0] = xprd; + c[1][0] = xy; c[1][1] = yprd; + c[2][0] = xz; c[2][1] = yz; c[2][2] = zprd; + } else { + auto c = cell.accessor(); + c[0][0] = static_cast(xprd); + c[1][0] = static_cast(xy); c[1][1] = static_cast(yprd); + c[2][0] = static_cast(xz); c[2][1] = static_cast(yz); + c[2][2] = static_cast(zprd); + } + } + + // 4) Run the scripted forward. + std::vector inputs; + inputs.reserve(7); + inputs.emplace_back(positions); + inputs.emplace_back(edge_index); + inputs.emplace_back(unit_shifts); + inputs.emplace_back(cell); + inputs.emplace_back(atomic_numbers); + inputs.emplace_back(local_or_ghost); + inputs.emplace_back(static_cast(vflag_global)); + + torch::IValue result; + try { + result = model_.forward(inputs); + } catch (const std::exception &e) { + error->all(FLERR, "pair_matgl: model forward failed: {}", e.what()); + } + auto out = result.toGenericDict(); + + // 5) Read scalars + forces back into LAMMPS arrays. + auto total_energy = out.at("total_energy_local").toTensor().to(torch::kFloat64).item(); + if (eflag_global) eng_vdwl += total_energy; + + torch::Tensor forces_t = out.at("forces").toTensor().to(torch::kFloat64); + auto fa = forces_t.accessor(); + for (int i = 0; i < nall; ++i) { + f[i][0] += fa[i][0]; + f[i][1] += fa[i][1]; + f[i][2] += fa[i][2]; + } + + // 6) Virial — the model returns a 3x3 tensor with the LAMMPS sign + // convention (V_ij = sum r_i F_j). LAMMPS stores 6 Voigt components + // in `virial`: xx, yy, zz, xy, xz, yz. + if (vflag_global) { + auto vir_t = out.at("virials").toTensor().to(torch::kFloat64); + auto va = vir_t.accessor(); + virial[0] += va[0][0]; + virial[1] += va[1][1]; + virial[2] += va[2][2]; + virial[3] += 0.5 * (va[0][1] + va[1][0]); + virial[4] += 0.5 * (va[0][2] + va[2][0]); + virial[5] += 0.5 * (va[1][2] + va[2][1]); + } +} diff --git a/lammps/src/ML-MATGL/pair_matgl.h b/lammps/src/ML-MATGL/pair_matgl.h new file mode 100644 index 00000000..34fac2ad --- /dev/null +++ b/lammps/src/ML-MATGL/pair_matgl.h @@ -0,0 +1,67 @@ +/* -*- c++ -*- ---------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + + pair_matgl: serial / OpenMP pair style backed by a TorchScripted MatGL + ``LAMMPSMatGLModel`` (see matgl/src/matgl/ext/_lammps.py). The Python + side ships a ``mgl create-lammps-model`` CLI that produces the .pt file + this pair style consumes. +------------------------------------------------------------------------- */ + +#ifdef PAIR_CLASS +// clang-format off +PairStyle(matgl, PairMATGL); +// clang-format on +#else + +#ifndef LMP_PAIR_MATGL_H +#define LMP_PAIR_MATGL_H + +#include "pair.h" + +#include +#include + +#include +#include + +namespace LAMMPS_NS { + +class PairMATGL : public Pair { + public: + PairMATGL(class LAMMPS *); + ~PairMATGL() override; + + void compute(int, int) override; + void settings(int, char **) override; + void coeff(int, char **) override; + void init_style() override; + double init_one(int, int) override; + + protected: + void allocate(); + + // The TorchScript model produced by `mgl create-lammps-model`. + torch::jit::Module model_; + + // Cutoff radius baked into the model (read from the `r_max` buffer). + double r_max_ = 0.0; + double r_max_squared_ = 0.0; + + // Tensor dtype the model was exported with: torch::kFloat32 or torch::kFloat64. + torch::Dtype dtype_ = torch::kFloat32; + + // Atomic numbers per LAMMPS atom-type (1-based; index 0 unused). Built from + // the species names in the pair_coeff line. + std::vector type_to_z_; + + // Whether the user disabled MPI domain decomposition with the optional + // ``no_domain_decomposition`` keyword on the pair_style line. Tracks the + // MACE-LAMMPS flag of the same name; reserved for future use. + bool no_domain_decomposition_ = false; +}; + +} // namespace LAMMPS_NS + +#endif +#endif diff --git a/lammps/tests/in.matgl_si b/lammps/tests/in.matgl_si new file mode 100644 index 00000000..6085a2d8 --- /dev/null +++ b/lammps/tests/in.matgl_si @@ -0,0 +1,50 @@ +# in.matgl_si +# +# Single-point energy/force/stress on a 4-atom Mo-S unit cell using a +# pair_matgl-loaded TorchScript model. Compare ``log.lammps`` energies +# against the Python reference produced by ``python_reference.py`` in the +# same directory. +# +# Usage: +# /build/lmp -in in.matgl_si +# +# Requires: +# - LAMMPS built with PKG_ML-MATGL=ON (see ../README.md). +# - A TorchScript artifact at ./model.pt produced via: +# mgl create-lammps-model -m -o lammps/tests/model.pt --dtype float32 + +units metal +atom_style atomic +boundary p p p + +atom_modify map yes +newton on + +# 4-atom Mo-S cell matching tests/ext/test_lammps_export.py::mo_s_supercell +lattice custom 1.0 a1 4.5 0.0 0.0 a2 0.0 4.5 0.0 a3 0.0 0.0 4.5 & + origin 0 0 0 & + basis 0.00 0.00 0.00 & + basis 0.50 0.50 0.50 & + basis 0.50 0.00 0.25 & + basis 0.00 0.50 0.75 +region box prism 0 1 0 1 0 1 0 0 0 units lattice +create_box 2 box +create_atoms 1 box basis 1 1 basis 2 2 basis 3 1 basis 4 2 + +# Mass values are ignored by pair_matgl but LAMMPS still requires them. +mass 1 95.95 # Mo +mass 2 32.06 # S + +pair_style matgl +pair_coeff * * model.pt Mo S + +# `fx`/`fy`/`fz` aren't valid thermo keywords — dump per-atom forces with +# `dump` if you want them, the parity check only needs the scalar `pe`. +thermo_style custom step pe pxx pyy pzz pyz pxz pxy +thermo 1 + +dump forces all custom 1 forces.dump id type x y z fx fy fz +dump_modify forces sort id + +# Single-point evaluation. +run 0 diff --git a/lammps/tests/python_reference.py b/lammps/tests/python_reference.py new file mode 100644 index 00000000..cc188d91 --- /dev/null +++ b/lammps/tests/python_reference.py @@ -0,0 +1,79 @@ +"""Reference energies/forces/stresses for `in.matgl_si`. + +Run this *after* exporting the model with `mgl create-lammps-model` and +*before* running LAMMPS so you have a gold standard to diff against. + + cd lammps/tests + uv run mgl create-lammps-model -m -o model.pt --dtype float32 + uv run python python_reference.py + /build/lmp -in in.matgl_si + +Both runs use the same 4-atom Mo-S supercell at fixed atomic positions, so +results should match within ~1e-5 eV (energy) / 1e-4 eV/Å (forces) / +1e-3 GPa (stress). +""" + +from __future__ import annotations + +import argparse +import sys + +import numpy as np +from pymatgen.core import Lattice, Structure + +import matgl +from matgl.ext.ase import PESCalculator + + +def _build_structure() -> Structure: + return Structure( + Lattice.cubic(4.5), + ["Mo", "S", "Mo", "S"], + [ + [0.00, 0.00, 0.00], + [0.50, 0.50, 0.50], + [0.50, 0.00, 0.25], + [0.00, 0.50, 0.75], + ], + ) + + +def main() -> int: + """Print reference energies/forces/stresses for the in.matgl_si test deck.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "-m", + "--model", + required=True, + help="MatGL model identifier (HF Hub repo id or local path) — must " + "be the same one passed to mgl create-lammps-model.", + ) + args = parser.parse_args() + + structure = _build_structure() + pot = matgl.load_model(args.model) + pot.eval() + + calc = PESCalculator(potential=pot, stress_unit="GPa", use_voigt=False) + from pymatgen.io.ase import AseAtomsAdaptor + + atoms = AseAtomsAdaptor().get_atoms(structure) + atoms.calc = calc + + energy = atoms.get_potential_energy() + forces = atoms.get_forces() + stress = atoms.get_stress() # 6-vector in Voigt order (xx,yy,zz,yz,xz,xy) + + print(f"# python reference for `in.matgl_si` (model = {args.model})") + print(f"energy_eV = {energy:.10e}") + print("forces_eV_per_A =") + for row in forces: + print(f" {row[0]:.10e} {row[1]:.10e} {row[2]:.10e}") + if isinstance(stress, np.ndarray) and stress.size == 6: + print("stress_GPa (Voigt: xx yy zz yz xz xy) =") + print(" " + " ".join(f"{s:.6e}" for s in stress)) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/pyproject.toml b/pyproject.toml index cf9411e7..70b97947 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -191,6 +191,14 @@ ignore_missing_imports = true module = ["matgl.kernels.*"] ignore_errors = true +[[tool.mypy.overrides]] +# ``register_buffer`` returns ``Tensor | Module`` in mypy's view; indexing, +# iterating, or attribute-accessing those triggers spurious errors. +# Annotating the buffers at class scope to fix it breaks TorchScript's +# annotation resolver, so silence the affected codes. +module = ["matgl.ext._lammps"] +disable_error_code = ["index", "arg-type", "operator", "union-attr"] + [tool.coverage.run] relative_files = true omit = [ diff --git a/src/matgl/cli.py b/src/matgl/cli.py index 2fd0f90b..341e821b 100644 --- a/src/matgl/cli.py +++ b/src/matgl/cli.py @@ -198,6 +198,54 @@ def clear_cache(args: argparse.Namespace) -> None: matgl.clear_cache(not args.yes) +def create_lammps_model(args: argparse.Namespace) -> int: + """Export a MatGL Potential as a LAMMPS-loadable TorchScript artifact. + + Loads the named/local model, wraps it in :class:`LAMMPSMatGLModel`, runs + ``torch.jit.script``, and writes the result to ``--outfile``. The artifact + is consumed by the ``pair_matgl`` and ``pair_matgl/kokkos`` LAMMPS pair + styles via ``torch::jit::load``. + + Args: + args: Parsed CLI arguments — ``model``, ``outfile``, ``dtype``, + ``device``, ``no_script``. + + Returns: + ``0`` on success, ``1`` if the underlying potential is unsupported. + """ + # Lazy import keeps the CLI responsive when this subcommand isn't used and + # avoids dragging the export-only deps onto the import path. + from matgl.ext._lammps import LAMMPSMatGLModel + + dtype_map = {"float32": torch.float32, "float64": torch.float64} + dtype = dtype_map[args.dtype] + + logger.info("Loading model %s ...", args.model) + potential = _load_potential(args.model) + potential.eval() + + if args.device != "cpu": + potential.to(args.device) + + wrapper = LAMMPSMatGLModel(potential=potential, dtype=dtype) # type:ignore[arg-type] + wrapper.eval() + + if args.no_script: + torch.save(wrapper, args.outfile) + print(f"Wrote eager wrapper (NOT TorchScript-compiled) to {args.outfile}") + else: + scripted = torch.jit.script(wrapper) + scripted.save(args.outfile) + print(f"Wrote scripted LAMMPS-MatGL artifact to {args.outfile}") + + print(" r_max :", wrapper.r_max) + print(" n_species :", wrapper.n_species) + print(" dtype :", args.dtype) + species = list(potential.model.element_types) # type:ignore[union-attr,arg-type] + print(" species :", species[: wrapper.n_species]) + return 0 + + def main(): """Handle main.""" parser = argparse.ArgumentParser( @@ -460,6 +508,47 @@ def main(): p_clear.set_defaults(func=clear_cache) + # LAMMPS export + p_lammps = subparsers.add_parser( + "create-lammps-model", + help="Export a MatGL Potential as a TorchScript artifact loadable by pair_matgl[/kokkos].", + ) + p_lammps.add_argument( + "-m", + "--model", + dest="model", + required=True, + help="Path or name of a saved MatGL model (TensorNet PyG, extensive PES).", + ) + p_lammps.add_argument( + "-o", + "--outfile", + dest="outfile", + required=True, + help="Output path for the LAMMPS-loadable artifact (e.g. matgl_model.pt).", + ) + p_lammps.add_argument( + "--dtype", + dest="dtype", + choices=["float32", "float64"], + default="float32", + help="Wrapper buffer dtype. Match what your LAMMPS LibTorch was built with.", + ) + p_lammps.add_argument( + "--device", + dest="device", + default="cpu", + help="Device to load weights onto before export (cpu | cuda[:N]).", + ) + p_lammps.add_argument( + "--no-script", + dest="no_script", + action="store_true", + help="Save the eager wrapper instead of running torch.jit.script. " + "Only useful for debugging — not loadable from LAMMPS C++.", + ) + p_lammps.set_defaults(func=create_lammps_model) + args = parser.parse_args() return args.func(args) diff --git a/src/matgl/ext/_lammps.py b/src/matgl/ext/_lammps.py new file mode 100644 index 00000000..e013cf48 --- /dev/null +++ b/src/matgl/ext/_lammps.py @@ -0,0 +1,547 @@ +"""LAMMPS-compatible TorchScript wrapper for MatGL Potentials. + +Mirrors the MACE pattern (``mace.cli.create_lammps_model``): produces a +``torch.jit.ScriptModule`` that takes plain tensors and returns a dict of +energy / per-atom energy / forces / virials, ready to be loaded by a LAMMPS +``pair_style`` via ``torch::jit::load``. + +Supported architectures (PyG backend only): + + * ``TensorNet`` (extensive head, ``use_warp=False``) + * ``M3GNet`` (extensive head) + +Other matgl models (CHGNet, MEGNet, SO3Net, QET) need follow-up work — see +the LAMMPS plugin README in the matgl repo for status. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch +from pymatgen.core.periodic_table import Element +from torch import Tensor, nn +from torch.autograd import grad + +from matgl.graph._compute_pyg import ( + compute_pair_vector_and_distance, + compute_theta_and_phi, + create_line_graph_torch, +) +from matgl.layers._basis import spherical_bessel_smooth +from matgl.utils.cutoff import polynomial_cutoff +from matgl.utils.maths import decompose_tensor, tensor_norm + +if TYPE_CHECKING: + from matgl.apps._pes_pyg import Potential + +logger = logging.getLogger(__name__) + + +_MAX_Z = 119 + + +def _build_z_to_index(element_types: tuple[str, ...]) -> Tensor: + """Build a length-_MAX_Z lookup buffer mapping atomic number -> internal index.""" + z_to_index = torch.full((_MAX_Z,), -1, dtype=torch.long) + for idx, sym in enumerate(element_types): + z = int(Element(sym).Z) + z_to_index[z] = idx + return z_to_index + + +def _y_l0_torch(cos_theta: Tensor, max_l: int) -> Tensor: + """Pure-tensor port of ``SphericalHarmonicsFunction(use_phi=False)``. + + Returns ``Y_l^0(cos_theta) for l in [0, max_l)`` via the Legendre-polynomial + recurrence ``P_l = ((2l-1)x P_{l-1} - (l-1) P_{l-2}) / l`` and the + real-spherical-harmonic normalization ``sqrt((2l+1)/(4π))``. Matches the + sympy-based reference implementation to fp32 precision. + """ + pi = 3.141592653589793 + n = int(cos_theta.size(0)) + out = torch.empty((n, max_l), dtype=cos_theta.dtype, device=cos_theta.device) + if max_l >= 1: + out[:, 0] = 0.5 * (1.0 / pi) ** 0.5 * torch.ones_like(cos_theta) + if max_l >= 2: + out[:, 1] = 0.5 * (3.0 / pi) ** 0.5 * cos_theta + pim2 = torch.ones_like(cos_theta) + pim1 = cos_theta + for lv in range(2, max_l): + pl = ((2 * lv - 1) * cos_theta * pim1 - (lv - 1) * pim2) / lv + out[:, lv] = ((2 * lv + 1) / (4 * pi)) ** 0.5 * pl + pim2 = pim1 + pim1 = pl + return out + + +def _m3gnet_three_body_basis_torch( + triple_bond_lengths: Tensor, + cos_theta: Tensor, + max_n: int, + max_l: int, + cutoff: float, +) -> Tensor: + """Pure-tensor port of M3GNet's ``SphericalBesselWithHarmonics`` (use_smooth=True, use_phi=False). + + Combines ``spherical_bessel_smooth(r, cutoff, max_n*max_l)`` with the + Legendre Y_l^0(cos_theta) basis using the ``combine_sbf_shf`` recipe for + ``use_phi=False``: each Y_l value is repeated ``max_n`` times to align with + the SBF column blocks, then multiplied element-wise. Output shape + ``(num_triples, max_n*max_l)`` matches the reference. + """ + sbf = spherical_bessel_smooth(triple_bond_lengths, cutoff=cutoff, max_n=max_n * max_l) + shf = _y_l0_torch(cos_theta, max_l) + expanded_shf = shf.repeat_interleave(max_n, dim=1) + return (sbf * expanded_shf).reshape(-1, max_n * max_l) + + +class _SmoothSBFExpansion(nn.Module): + """TorchScript-friendly stand-in for ``BondExpansion(rbf_type='SphericalBessel', smooth=True)``. + + ``SphericalBesselFunction``'s forward dispatches through ``@torch.jit.ignore`` + helpers (sympy-lambdified ``funcs`` list), which blocks ``torch.jit.save``. + The smooth-SBF basis it computes is mathematically identical to + :func:`matgl.layers._basis.spherical_bessel_smooth`, which is implemented in + pure tensor ops. This adapter is a drop-in replacement preserving the + ``(N, max_n)`` output shape. + """ + + # NOTE: no class-level annotations — under ``from __future__ import + # annotations`` TorchScript's annotation resolver treats them as + # unresolved string types and errors with "Unknown type annotation". + # ``__constants__`` makes scripting bake ``cutoff`` and ``max_n`` in as + # Python literals (their values never change after __init__). + + __constants__ = ["cutoff", "max_n"] # noqa: RUF012 — TorchScript convention. + + def __init__(self, cutoff: float, max_n: int) -> None: + super().__init__() + self.cutoff = float(cutoff) + self.max_n = int(max_n) + + def forward(self, r: Tensor) -> Tensor: + return spherical_bessel_smooth(r, cutoff=self.cutoff, max_n=self.max_n) + + +class _TensorNetKernel(nn.Module): + """Per-atom raw-energy compute for TensorNet (PyG, no-Warp).""" + + def __init__(self, model: nn.Module) -> None: + super().__init__() + # Swap SphericalBessel bond_expansion for the pure-tensor equivalent so + # the kernel survives torch.jit.script.save. Gaussian / ExpNormal / + # RadialBessel basis modules are already TorchScript-friendly and pass + # through unchanged. + # ``model.bond_expansion`` typed as ``Tensor | Module`` via nn.Module's + # ``__getattr__``; the cast narrows it for mypy without runtime cost. + from typing import cast + + be = cast("nn.Module", model.bond_expansion) + if getattr(be, "rbf_type", None) == "SphericalBessel": + if not bool(be.rbf.smooth): + raise NotImplementedError( + "LAMMPS export of TensorNet currently requires use_smooth=True " + "for rbf_type='SphericalBessel' (non-smooth SphericalBesselFunction " + "uses sympy-lambdified callables incompatible with torch.jit.script). " + "Re-load the checkpoint with smooth=True." + ) + self.bond_expansion: nn.Module = _SmoothSBFExpansion( + cutoff=float(be.cutoff), + max_n=int(be.rbf.max_n), + ) + else: + self.bond_expansion = be + self.tensor_embedding = model.tensor_embedding + self.layers = model.layers + self.out_norm = model.out_norm + self.linear = model.linear + self.final_layer = model.final_layer + + def forward( + self, + atom_types: Tensor, + edge_index: Tensor, + pbc_offshift: Tensor, + pos_s: Tensor, + num_nodes: int, + ) -> Tensor: + """Returns per-atom raw energies (pre-std/mean, pre-element-ref).""" + bond_vec, bond_dist = compute_pair_vector_and_distance(pos_s, edge_index, pbc_offshift) + edge_attr = self.bond_expansion(bond_dist) + + x_tensor, _ = self.tensor_embedding(atom_types, edge_index, edge_attr, bond_dist, bond_vec, None) + for layer in self.layers: + x_tensor = layer(edge_index, bond_dist, edge_attr, x_tensor) + + scalars, skew, traceless = decompose_tensor(x_tensor) + x = torch.cat( + (tensor_norm(scalars), tensor_norm(skew), tensor_norm(traceless)), + dim=-1, + ) + x = self.out_norm(x) + x = self.linear(x) + return self.final_layer(x).view(-1) + + +class _M3GNetKernel(nn.Module): + """Per-atom raw-energy compute for M3GNet (PyG, extensive). + + Drops ``model.basis_expansion`` (``SphericalBesselWithHarmonics``) entirely + — its ``sbf`` / ``shf`` submodules carry sympy-lambdified Python lists that + don't survive ``torch.jit.script``. Instead we recompute the three-body + basis with :func:`_m3gnet_three_body_basis_torch`, which handles M3GNet's + one combination (``use_smooth=True``, ``use_phi=False``) in pure tensor + ops. + """ + + def __init__(self, model: nn.Module) -> None: + super().__init__() + # Skip ``bond_expansion`` and ``basis_expansion`` — both wrap + # ``SphericalBesselFunction``, which uses sympy-lambdified Python + # callables that don't survive ``torch.jit.script.save``. We + # recompute their outputs in pure tensor ops below. + self.embedding = model.embedding + self.three_body_interactions = model.three_body_interactions + self.graph_layers = model.graph_layers + self.final_layer = model.final_layer + self.threebody_cutoff: float = float(model.threebody_cutoff) + self.n_blocks: int = int(model.n_blocks) + self.basis_max_n: int = int(model.basis_expansion.max_n) + self.basis_max_l: int = int(model.basis_expansion.max_l) + self.basis_cutoff: float = float(model.cutoff) + # bond_expansion's RBF — stored separately because both + # SphericalBesselFunction (for SphericalBessel rbf_type) and the + # alternatives need different reconstruction. + self.bond_max_n: int = int(model.bond_expansion.rbf.max_n) + # Validate the supported configuration: + if model.bond_expansion.rbf_type != "SphericalBessel": + raise NotImplementedError( + "LAMMPS export of M3GNet currently requires " + "rbf_type='SphericalBessel'; got " + f"{model.bond_expansion.rbf_type!r}." + ) + if not bool(model.bond_expansion.rbf.smooth): + raise NotImplementedError( + "LAMMPS export of M3GNet currently requires use_smooth=True for the bond expansion." + ) + if not bool(model.basis_expansion.use_smooth): + raise NotImplementedError( + "LAMMPS export of M3GNet currently requires use_smooth=True " + "for the three-body basis. Re-train or re-load with " + "use_smooth=True." + ) + if bool(model.basis_expansion.use_phi): + raise NotImplementedError( + "LAMMPS export of M3GNet currently requires use_phi=False (matches the standard PES configuration)." + ) + + def forward( + self, + atom_types: Tensor, + edge_index: Tensor, + pbc_offshift: Tensor, + pos_s: Tensor, + num_nodes: int, + ) -> Tensor: + """Returns per-atom raw energies (pre-std/mean, pre-element-ref).""" + bond_vec, bond_dist = compute_pair_vector_and_distance(pos_s, edge_index, pbc_offshift) + # Tensor-only smooth-SBF (replaces ``self.bond_expansion`` for the + # ``rbf_type='SphericalBessel'`` + ``smooth=True`` configuration). + expanded_dists = spherical_bessel_smooth(bond_dist, cutoff=self.basis_cutoff, max_n=self.bond_max_n) + + # Line graph (3-body): tensor-only build; no PyG Data, no numpy. + l_g = create_line_graph_torch(edge_index, bond_dist, bond_vec, num_nodes, self.threebody_cutoff) + angles = compute_theta_and_phi(l_g["bond_vec"], l_g["bond_dist"], l_g["line_edge_index"]) + + # Tensor-only spherical-Bessel x spherical-harmonic basis. + three_body_basis = _m3gnet_three_body_basis_torch( + angles["triple_bond_lengths"], + angles["cos_theta"], + self.basis_max_n, + self.basis_max_l, + self.basis_cutoff, + ) + + three_body_cutoff = polynomial_cutoff(bond_dist, self.threebody_cutoff) + + node_feat, edge_feat, state_feat = self.embedding(atom_types, expanded_dists, None) + + edge_dst_atom = edge_index[1] + line_edge_index = l_g["line_edge_index"] + n_triple_ij = l_g["n_triple_ij"] + num_bonds = int(edge_index.size(1)) + + # TorchScript can't index ModuleList with a non-literal int, so we + # iterate over the two lists in parallel via zip(...). ``strict=True`` + # would be safer but TorchScript doesn't accept the kwarg here, and + # ``three_body_interactions`` and ``graph_layers`` are constructed + # together in M3GNet's ``__init__`` so they're always the same length. + for tbi, gl in zip(self.three_body_interactions, self.graph_layers): # noqa: B905 + edge_feat = tbi( + edge_dst_atom, + line_edge_index, + n_triple_ij, + num_bonds, + three_body_basis, + three_body_cutoff, + node_feat, + edge_feat, + ) + edge_feat, node_feat, state_feat = gl( + edge_index, + edge_feat, + node_feat, + state_feat, + expanded_dists, + None, # node_batch — single graph + None, # edge_batch + num_nodes, + 1, # num_graphs + ) + + return self.final_layer(node_feat).view(-1) + + +class LAMMPSMatGLModel(nn.Module): + """TorchScript-friendly wrapper around a MatGL ``Potential`` for LAMMPS. + + Takes plain tensors (Cartesian positions, edge_index, integer image shifts, + cell, atomic numbers, ghost mask) and returns a dict of total local energy, + per-atom node energies, forces, and virials. Replicates the autograd + machinery from :class:`matgl.apps.pes.Potential` but driven by Cartesian + inputs so LAMMPS doesn't have to materialize fractional coordinates. + + Architecture-specific feature compute lives in a small inner ``kernel`` + module (``_TensorNetKernel`` or ``_M3GNetKernel``), so the strain / autograd + machinery here is single-source. + + Limitations: + * Inner model must be ``TensorNet`` or ``M3GNet`` (PyG backend, + extensive head, no-Warp for TensorNet). + * Per-atom virials and the Hessian path are not exported. + * ``data_mean`` is added once to ``total_energy_local``; multi-rank + LAMMPS therefore requires ``data_mean == 0`` for correctness. The + standard MatGL PES checkpoints satisfy this — element offsets are + carried by ``element_refs`` instead. + """ + + # NOTE: deliberately *no* class-level Tensor annotations on the registered + # buffers — they trip TorchScript's annotation resolver (it sees them as + # unresolved string annotations under ``from __future__ import annotations`` + # and fails with "Unknown type annotation"). Buffer types are still + # inferred correctly from ``register_buffer`` calls. mypy access to + # ``self.data_mean`` etc. is narrowed locally with ``cast`` where needed. + + def __init__( + self, + potential: Potential, + dtype: torch.dtype = torch.float32, + ) -> None: + super().__init__() + + # Imports kept local so the public ``matgl.ext`` namespace doesn't + # hard-require these submodules at import time. + from matgl.models._m3gnet_pyg import M3GNet + from matgl.models._tensornet_pyg import TensorNet + + model = potential.model + + if isinstance(model, TensorNet): + if getattr(model, "_use_warp", False): + raise ValueError( + "Inner TensorNet was constructed with use_warp=True. Re-load " + "or rebuild it with use_warp=False before exporting for " + "LAMMPS — Warp custom-ops are not TorchScript-compatible." + ) + self.kernel: nn.Module = _TensorNetKernel(model) + elif isinstance(model, M3GNet): + self.kernel = _M3GNetKernel(model) + else: + raise NotImplementedError( + "LAMMPSMatGLModel currently supports TensorNet (PyG) and " + f"M3GNet (PyG); got {type(model).__name__}. CHGNet/MEGNet/" + "SO3Net/QET are not yet exported." + ) + + if model.is_intensive: + raise ValueError( + f"Inner {type(model).__name__} has is_intensive=True. " + "LAMMPS expects an extensive PES; only is_intensive=False " + "is supported." + ) + + if potential.calc_repuls: + raise NotImplementedError("ZBL repulsion (calc_repuls=True) export is not yet supported.") + + if potential.calc_magmom or potential.calc_charge: + raise NotImplementedError("Magmom / charge heads are not exported for LAMMPS.") + + # Bake in the cutoff and group as Python attrs (immutable in script). + self.cutoff: float = float(model.cutoff) + self.r_max: float = float(model.cutoff) + self.n_species: int = len(model.element_types) + + # Buffers carried over from Potential. ``Potential`` stores these + # via ``register_buffer``, so mypy sees them as ``Tensor | Module``; + # the ``cast`` keeps the typing tight without any runtime cost. + from typing import cast + + data_mean = cast("Tensor", potential.data_mean).detach().to(dtype) + data_std = cast("Tensor", potential.data_std).detach().to(dtype) + self.register_buffer("data_mean", data_mean) + self.register_buffer("data_std", data_std) + + # Per-element reference energies (1D, indexed by internal element idx). + if potential.element_refs is not None: + ref = potential.element_refs.property_offset.detach().to(dtype) + if ref.dim() != 1: + raise NotImplementedError("State-conditional element_refs (>1D) not supported.") + if ref.numel() != self.n_species: + if ref.numel() < self.n_species: + ref = torch.cat([ref, torch.zeros(self.n_species - ref.numel(), dtype=dtype)]) + else: + ref = ref[: self.n_species] + else: + ref = torch.zeros(self.n_species, dtype=dtype) + self.register_buffer("element_refs", ref) + + # Z -> internal element index lookup. C++ passes Z; we translate. + self.register_buffer("z_to_index", _build_z_to_index(model.element_types)) + + # Atomic numbers in element_types order — useful for inspection. + atomic_numbers = torch.tensor([int(Element(s).Z) for s in model.element_types], dtype=torch.long) + self.register_buffer("atomic_numbers", atomic_numbers) + + if abs(float(data_mean.item() if data_mean.ndim == 0 else 0.0)) > 1e-8: + logger.warning( + "Exported model has non-zero data_mean=%s. Multi-rank LAMMPS " + "runs will double-count this offset; use single-rank Kokkos.", + float(data_mean.item()) if data_mean.ndim == 0 else data_mean, + ) + + # Cast the wrapper to the target dtype. The inner model is already the + # checkpoint dtype; the conversion above ensures buffers match. + self.to(dtype) + + def forward( + self, + positions: Tensor, + edge_index: Tensor, + unit_shifts: Tensor, + cell: Tensor, + atomic_numbers: Tensor, + local_or_ghost: Tensor, + compute_virials: bool, + ) -> dict[str, Tensor]: + """Energy / forces / virials for a single LAMMPS configuration. + + Args: + positions: Cartesian coordinates, shape (N, 3). N includes + ghost atoms when running in domain-decomposed mode. + edge_index: COO edges, shape (2, E), int64. + unit_shifts: Integer image vectors per edge, shape (E, 3), + int64. The destination atom's effective position is + ``positions[dst] + unit_shifts @ cell``. + cell: Lattice basis as row vectors, shape (3, 3). + atomic_numbers: Per-atom Z, shape (N,), int64. + local_or_ghost: True for owned atoms, False for ghosts; shape + (N,), bool. Only owned atoms contribute to the energy sum. + compute_virials: Whether to compute the virial tensor. + + Returns: + dict with keys ``total_energy_local`` (scalar), ``node_energy`` + (N,), ``forces`` (N, 3), and ``virials`` (3, 3). + """ + atom_types = self.z_to_index[atomic_numbers] + + strain = torch.zeros((3, 3), dtype=positions.dtype, device=positions.device) + if compute_virials: + strain.requires_grad_(True) + + eye = torch.eye(3, dtype=positions.dtype, device=positions.device) + deformation = eye + strain + pos_s = positions @ deformation + cell_s = cell @ deformation + pos_s.requires_grad_(True) + + pbc_offshift = unit_shifts.to(positions.dtype) @ cell_s + + num_nodes = int(positions.size(0)) + atomic_energies_raw = self.kernel(atom_types, edge_index, pbc_offshift, pos_s, num_nodes) + + node_energy = self.data_std * atomic_energies_raw + self.element_refs[atom_types] + masked = node_energy.masked_fill(~local_or_ghost, 0.0) + total_energy_local = masked.sum() + self.data_mean + + # TorchScript demands the explicit Optional[Tensor] typing on grad_outputs. + grad_outputs: list[torch.Tensor | None] = [torch.ones_like(total_energy_local)] + + if compute_virials: + grads = grad( + outputs=[total_energy_local], + inputs=[pos_s, strain], + grad_outputs=grad_outputs, + create_graph=False, + retain_graph=False, + ) + pos_grad = grads[0] + strain_grad = grads[1] + forces = -pos_grad if pos_grad is not None else torch.zeros_like(positions) + virials = ( + strain_grad + if strain_grad is not None + else torch.zeros((3, 3), dtype=positions.dtype, device=positions.device) + ) + else: + grads = grad( + outputs=[total_energy_local], + inputs=[pos_s], + grad_outputs=grad_outputs, + create_graph=False, + retain_graph=False, + ) + pos_grad = grads[0] + forces = -pos_grad if pos_grad is not None else torch.zeros_like(positions) + virials = torch.zeros((3, 3), dtype=positions.dtype, device=positions.device) + + return { + "total_energy_local": total_energy_local, + "node_energy": node_energy, + "forces": forces, + "virials": virials, + } + + +def export_lammps_model( + potential: Potential, + output_path: str, + dtype: torch.dtype = torch.float32, + script: bool = True, +) -> LAMMPSMatGLModel: + """Wrap a :class:`Potential` and save a LAMMPS-loadable artifact. + + Args: + potential: A trained MatGL ``Potential`` (TensorNet or M3GNet on the + PyG backend, extensive head). + output_path: Where to write the ``.pt`` file. The C++ pair_style + loads this with ``torch::jit::load``. + dtype: Wrapper buffer dtype. Forces ``float32`` or ``float64``. + script: When True (default), runs ``torch.jit.script`` and saves the + script module. When False, saves the eager wrapper as a regular + PyTorch checkpoint (useful for debugging — not loadable from C++). + + Returns: + The wrapper instance (eager). + """ + wrapper = LAMMPSMatGLModel(potential=potential, dtype=dtype) + wrapper.eval() + + if script: + scripted = torch.jit.script(wrapper) + scripted.save(output_path) + else: + torch.save(wrapper, output_path) + + return wrapper diff --git a/src/matgl/ext/lammps.py b/src/matgl/ext/lammps.py new file mode 100644 index 00000000..85301251 --- /dev/null +++ b/src/matgl/ext/lammps.py @@ -0,0 +1,11 @@ +"""LAMMPS interface for MatGL. + +Exports a TorchScript-friendly wrapper that LAMMPS pair styles +(``pair_matgl``, ``pair_matgl/kokkos``) load via ``torch::jit::load``. +""" + +from __future__ import annotations + +from ._lammps import LAMMPSMatGLModel, export_lammps_model + +__all__ = ["LAMMPSMatGLModel", "export_lammps_model"] diff --git a/src/matgl/graph/_compute_pyg.py b/src/matgl/graph/_compute_pyg.py index 86c935a6..ff9ecad3 100644 --- a/src/matgl/graph/_compute_pyg.py +++ b/src/matgl/graph/_compute_pyg.py @@ -165,6 +165,96 @@ def _compute_3body_indices( return line_edge_index, n_triple_ij, max_bond_id +def _compute_3body_indices_torch(edge_index: torch.Tensor, num_nodes: int) -> tuple[torch.Tensor, torch.Tensor, int]: + """Tensor-only port of :func:`_compute_3body_indices`. + + Same semantics as the numpy/Python-loop original but expressed in pure + tensor ops so it can be inlined inside a ``torch.jit.script``-ed forward + (the LAMMPS export wrapper needs this). Assumes ``edge_index[0]`` is + sorted ascending — matches what + :func:`pymatgen.optimization.neighbors.find_points_in_spheres` and the + LAMMPS neighbor-list walk produce. + + Args: + edge_index: ``(2, E)`` parent edge indices. + num_nodes: Number of atoms in the parent graph. + + Returns: + Same ``(line_edge_index, n_triple_ij, max_bond_id)`` triple as the + legacy implementation. + """ + src = edge_index[0] + E = int(src.size(0)) + device = src.device + if E == 0: + empty_2 = torch.zeros((2, 0), dtype=torch.long, device=device) + empty_1 = torch.zeros((0,), dtype=torch.long, device=device) + return empty_2, empty_1, 0 + + src_long = src.to(torch.long) + n_b = torch.bincount(src_long, minlength=num_nodes) + cum = torch.zeros(num_nodes + 1, dtype=torch.long, device=device) + cum[1:] = n_b.cumsum(0) + + n_triple_ij = (n_b[src_long] - 1).clamp(min=0).to(torch.long) + total = int(n_triple_ij.sum().item()) + if total == 0: + empty_2 = torch.zeros((2, 0), dtype=torch.long, device=device) + return empty_2, n_triple_ij, 0 + + src_bond = torch.repeat_interleave(torch.arange(E, dtype=torch.long, device=device), n_triple_ij) + cum_triples = torch.zeros(E + 1, dtype=torch.long, device=device) + cum_triples[1:] = n_triple_ij.cumsum(0) + j_within = torch.arange(total, dtype=torch.long, device=device) - cum_triples[src_bond] + + bond_local = torch.arange(E, dtype=torch.long, device=device) - cum[src_long] + bond_local_per = bond_local[src_bond] + j_actual = j_within + (j_within >= bond_local_per).to(torch.long) + + src_atom = src_long[src_bond] + dst_bond = cum[src_atom] + j_actual + + line_edge_index = torch.stack([src_bond, dst_bond], dim=0) + max_bond_id = int(line_edge_index.max().item()) + 1 + return line_edge_index, n_triple_ij, max_bond_id + + +def create_line_graph_torch( + edge_index: torch.Tensor, + bond_dist: torch.Tensor, + bond_vec: torch.Tensor, + num_nodes: int, + threebody_cutoff: float, +) -> dict[str, torch.Tensor]: + """TorchScript-friendly variant of :func:`create_line_graph`. + + Drops the optional ``pbc_offset`` and the ``Callable`` argument that + ``prune_edges_by_features`` uses; otherwise produces the same line-graph + bundle keys M3GNet's PyG forward consumes. + + Args: + edge_index: ``(2, E)`` parent edges (sorted ascending by source atom). + bond_dist: Per-edge distances of the parent graph. + bond_vec: Per-edge bond vectors of the parent graph. + num_nodes: Number of atoms in the parent graph. + threebody_cutoff: Distance cutoff used to drop edges before forming + three-body terms. + """ + valid = bond_dist <= threebody_cutoff + pruned_edge_index = edge_index[:, valid] + pruned_bond_dist = bond_dist[valid] + pruned_bond_vec = bond_vec[valid] + + line_edge_index, n_triple_ij, max_bond_id = _compute_3body_indices_torch(pruned_edge_index, num_nodes) + + return { + "bond_dist": pruned_bond_dist[:max_bond_id], + "bond_vec": pruned_bond_vec[:max_bond_id], + "line_edge_index": line_edge_index, + "n_triple_ij": n_triple_ij[:max_bond_id], + } + + def create_line_graph( edge_index: torch.Tensor, bond_dist: torch.Tensor, diff --git a/src/matgl/layers/_basis.py b/src/matgl/layers/_basis.py index ea61e80b..8f05dc8d 100644 --- a/src/matgl/layers/_basis.py +++ b/src/matgl/layers/_basis.py @@ -157,10 +157,12 @@ def forward(self, r: torch.Tensor) -> torch.Tensor: return self._call_smooth_sbf(r) return self._call_sbf(r) + @torch.jit.ignore def _call_smooth_sbf(self, r): results = [i(r) for i in self.funcs] return torch.t(torch.stack(results)) + @torch.jit.ignore def _call_sbf(self, r): # ``r`` is per-edge distance. The non-smooth spherical Bessel basis is # j_l(root_{l,n} * r / cutoff) * sqrt(2/cutoff^3) / |j_{l+1}(root_{l,n})| @@ -351,17 +353,22 @@ def spherical_bessel_smooth(r: Tensor, cutoff: float = 5.0, max_n: int = 10) -> Returns: expanded spherical harmonics with derivatives smooth at boundary """ + # ``pi`` and ``sqrt(2.0)`` declared locally so this function survives + # ``torch.jit.script`` (closed-over module globals are not resolvable + # in TorchScript's annotation context). + pi_local = 3.141592653589793 + sqrt2 = 1.4142135623730951 n = torch.arange(max_n).type(dtype=matgl.float_th)[None, :] r = r[:, None] fnr = ( (-1) ** n - * sqrt(2.0) - * pi + * sqrt2 + * pi_local / cutoff**1.5 * (n + 1) * (n + 2) / torch.sqrt(2 * n**2 + 6 * n + 5) - * (_sinc(r * (n + 1) * pi / cutoff) + _sinc(r * (n + 2) * pi / cutoff)) + * (_sinc(r * (n + 1) * pi_local / cutoff) + _sinc(r * (n + 2) * pi_local / cutoff)) ) en = n**2 * (n + 2) ** 2 / (4 * (n + 1) ** 4 + 1) dn = [torch.tensor(1.0)] diff --git a/src/matgl/layers/_core.py b/src/matgl/layers/_core.py index 09e57185..8d572e4f 100644 --- a/src/matgl/layers/_core.py +++ b/src/matgl/layers/_core.py @@ -74,6 +74,7 @@ def __repr__(self) -> str: return f"MLP({', '.join(dims)})" @property + @torch.jit.unused def last_linear(self) -> Linear: """Return the last linear layer in the network.""" for layer in reversed(self.layers): @@ -83,11 +84,13 @@ def last_linear(self) -> Linear: raise RuntimeError(msg) @property + @torch.jit.unused def depth(self) -> int: """Returns depth of MLP.""" return self._depth @property + @torch.jit.unused def in_features(self) -> int: """Return input features of MLP.""" first_layer = self.layers[0] @@ -95,6 +98,7 @@ def in_features(self) -> int: return first_layer.in_features @property + @torch.jit.unused def out_features(self) -> int: """Returns output features of MLP.""" for layer in reversed(self.layers): diff --git a/src/matgl/layers/_embedding.py b/src/matgl/layers/_embedding.py index 57bbcf77..a34151a9 100644 --- a/src/matgl/layers/_embedding.py +++ b/src/matgl/layers/_embedding.py @@ -73,7 +73,12 @@ def __init__( dim_edges = [degree_rbf, dim_edge_embedding] self.layer_edge_embedding = MLP(dim_edges, activation=activation, activate_last=True) - def forward(self, node_attr, edge_attr, state_attr): + def forward( + self, + node_attr: torch.Tensor, + edge_attr: torch.Tensor, + state_attr: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: """Output embedded features. Args: @@ -94,14 +99,14 @@ def forward(self, node_attr, edge_attr, state_attr): edge_feat = self.layer_edge_embedding(edge_attr.to(matgl.float_th)) else: edge_feat = edge_attr + state_feat: torch.Tensor | None = None if self.include_state is True: - if self.ntypes_state and self.dim_state_embedding is not None: + assert state_attr is not None, "state_attr must be provided when include_state=True" + if self.ntypes_state is not None and self.dim_state_embedding is not None: state_feat = self.layer_state_embedding(state_attr) elif self.dim_state_feats is not None: state_attr = torch.unsqueeze(state_attr, 0) state_feat = self.layer_state_embedding(state_attr.to(matgl.float_th)) else: state_feat = state_attr - else: - state_feat = None return node_feat, edge_feat, state_feat diff --git a/src/matgl/layers/_embedding_pyg.py b/src/matgl/layers/_embedding_pyg.py index bf9191d5..2d653bf9 100644 --- a/src/matgl/layers/_embedding_pyg.py +++ b/src/matgl/layers/_embedding_pyg.py @@ -84,7 +84,16 @@ def reset_parameters(self): linear.reset_parameters() self.init_norm.reset_parameters() - def message(self, x_i, x_j, edge_attr, edge_weight, Iij, Aij, Sij): + def message( + self, + x_i: torch.Tensor, + x_j: torch.Tensor, + edge_attr: torch.Tensor, + edge_weight: torch.Tensor, + Iij: torch.Tensor, + Aij: torch.Tensor, + Sij: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Message function for edge updates.""" vi = x_i # Source node features vj = x_j # Destination node features @@ -94,13 +103,19 @@ def message(self, x_i, x_j, edge_attr, edge_weight, Iij, Aij, Sij): scalars = Zij[..., None, None] * Iij skew_matrices = Zij[..., None, None] * Aij traceless_tensors = Zij[..., None, None] * Sij - return {"I": scalars, "A": skew_matrices, "S": traceless_tensors} + return scalars, skew_matrices, traceless_tensors - def aggregate(self, msg, index, dim_size=None): + def aggregate( + self, + msg: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + index: torch.Tensor, + dim_size: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Aggregate messages for node updates.""" - scalars = scatter_add(msg["I"], index, dim_size=dim_size) - skew_matrices = scatter_add(msg["A"], index, dim_size=dim_size) - traceless_tensors = scatter_add(msg["S"], index, dim_size=dim_size) + scalars_msg, skew_msg, traceless_msg = msg + scalars = scatter_add(scalars_msg, index, dim_size=dim_size) + skew_matrices = scatter_add(skew_msg, index, dim_size=dim_size) + traceless_tensors = scatter_add(traceless_msg, index, dim_size=dim_size) return scalars, skew_matrices, traceless_tensors def forward( @@ -110,8 +125,8 @@ def forward( edge_attr: torch.Tensor, edge_weight: torch.Tensor, edge_vec: torch.Tensor, - state_attr=None, - ): + state_attr: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: """Compute embedded node tensors and (optional) state features. Args: diff --git a/src/matgl/layers/_graph_convolution_pyg.py b/src/matgl/layers/_graph_convolution_pyg.py index b3c59400..06b75c30 100644 --- a/src/matgl/layers/_graph_convolution_pyg.py +++ b/src/matgl/layers/_graph_convolution_pyg.py @@ -160,9 +160,16 @@ def forward( dX = scalars + skew_metrices + traceless_tensors return X + dX + torch.matmul(dX, dX) - def message(self, edge_index, x_I: torch.Tensor, x_A: torch.Tensor, x_S: torch.Tensor, edge_attr: torch.Tensor): + def message( + self, + edge_index: torch.Tensor, + x_I: torch.Tensor, + x_A: torch.Tensor, + x_S: torch.Tensor, + edge_attr: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute messages for each edge.""" - _, dst = edge_index + dst = edge_index[1] x_I_j = x_I[dst] x_A_j = x_A[dst] x_S_j = x_S[dst] @@ -171,7 +178,12 @@ def message(self, edge_index, x_I: torch.Tensor, x_A: torch.Tensor, x_S: torch.T ) return scalars, skew_metrices, traceless_tensors - def aggregate(self, inputs, index, dim_size): + def aggregate( + self, + inputs: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + index: torch.Tensor, + dim_size: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Aggregate messages for node updates.""" scalars, skew_matrices, traceless_tensors = inputs scalars_agg = scatter_add(scalars, index, dim_size=dim_size) @@ -385,7 +397,7 @@ def forward( edge_index, edge_feat, node_feat, state_feat, node_batch, edge_batch, num_nodes, num_graphs ) - if self.dropout: + if self.dropout is not None: edge_feat = self.dropout(edge_feat) node_feat = self.dropout(node_feat) state_feat = self.dropout(state_feat) @@ -505,7 +517,12 @@ def state_update_( """Compute the state update (Eq. 6) using per-graph mean of node features.""" uv = _per_graph_mean(node_feat, node_batch, num_graphs) inputs = torch.hstack([state_feat, uv]) - return self.state_update_func(inputs) # type: ignore[misc] + # Narrow ``state_update_func`` from Optional[Module] to Module for + # TorchScript — the caller already gates on ``include_state`` so the + # assert is informational only. + func = self.state_update_func + assert func is not None + return func(inputs) def forward( self, @@ -598,7 +615,7 @@ def forward( edge_feat, node_feat, state_feat = self.conv( edge_index, edge_feat, node_feat, state_feat, rbf, node_batch, edge_batch, num_nodes, num_graphs ) - if self.dropout: + if self.dropout is not None: edge_feat = self.dropout(edge_feat) node_feat = self.dropout(node_feat) if state_feat is not None: diff --git a/src/matgl/utils/cutoff.py b/src/matgl/utils/cutoff.py index 84121c9b..6ffcc1a1 100644 --- a/src/matgl/utils/cutoff.py +++ b/src/matgl/utils/cutoff.py @@ -2,8 +2,6 @@ from __future__ import annotations -from math import pi - import torch @@ -40,4 +38,4 @@ def cosine_cutoff(r: torch.Tensor, cutoff: float) -> torch.Tensor: Returns: cosine cutoff function """ - return torch.where(r <= cutoff, 0.5 * (torch.cos(pi * r / cutoff) + 1), 0.0) + return torch.where(r <= cutoff, 0.5 * (torch.cos(torch.pi * r / cutoff) + 1), 0.0) diff --git a/tests/ext/test_lammps_export.py b/tests/ext/test_lammps_export.py new file mode 100644 index 00000000..019bc153 --- /dev/null +++ b/tests/ext/test_lammps_export.py @@ -0,0 +1,272 @@ +"""Tests for the LAMMPS TorchScript export wrapper. + +The wrapper has to give the same energy/forces/stresses as +``Potential.forward`` for any periodic configuration. We exercise this with +a small randomly-initialized TensorNet on a Mo-S supercell, comparing the +wrapper's Cartesian-driven path to the canonical PyG ``Data``-driven path. +""" + +from __future__ import annotations + +import numpy as np +import pytest +import torch + +import matgl + +if matgl.config.BACKEND != "PYG": + pytest.skip("LAMMPS export only supports PyG backend", allow_module_level=True) + +from pymatgen.core import Lattice, Structure +from pymatgen.optimization.neighbors import find_points_in_spheres + +from matgl.apps._pes_pyg import Potential +from matgl.ext._lammps import LAMMPSMatGLModel +from matgl.ext._pymatgen_pyg import Structure2Graph +from matgl.models._m3gnet_pyg import M3GNet +from matgl.models._tensornet_pyg import TensorNet + + +def _build_lammps_inputs(structure: Structure, element_types: tuple[str, ...], cutoff: float, dtype: torch.dtype): + """Build the tensor inputs LAMMPS would produce for a single configuration. + + Mirrors the CPU-fallback path in ``Atoms2Graph.get_graph`` — pymatgen's + ``find_points_in_spheres`` produces (src, dst, image, dist), which is + exactly the (edge_index, unit_shifts) pair LAMMPS gives us at the C++ + boundary. + """ + lattice = np.array(structure.lattice.matrix) + cart = structure.cart_coords + src, dst, images, dist = find_points_in_spheres( + cart, + cart, + r=float(cutoff), + pbc=np.array([1, 1, 1], dtype=np.int64), + lattice=lattice, + tol=1.0e-8, + ) + keep = (src != dst) | (dist > 1e-8) + src = src[keep] + dst = dst[keep] + images = images[keep] + + edge_index = torch.tensor(np.stack([src, dst]), dtype=torch.long) + unit_shifts = torch.tensor(images, dtype=torch.long) + positions = torch.tensor(cart, dtype=dtype) + cell = torch.tensor(lattice, dtype=dtype) + + z_per_atom = torch.tensor([site.specie.Z for site in structure], dtype=torch.long) + local_or_ghost = torch.ones(len(structure), dtype=torch.bool) + return positions, edge_index, unit_shifts, cell, z_per_atom, local_or_ghost + + +def _build_tensornet_potential(rbf_type: str = "Gaussian") -> tuple[Potential, tuple[str, ...]]: + torch.manual_seed(0) + element_types = ("Mo", "S") + model = TensorNet( + element_types=element_types, + is_intensive=False, + units=16, + nblocks=1, + num_rbf=8, + cutoff=4.0, + use_warp=False, + rbf_type=rbf_type, + # SphericalBessel checkpoints (e.g. TensorNet-MatPES-r2SCAN) use the + # smooth basis; the non-smooth path relies on sympy-lambdified funcs + # that don't survive torch.jit.save. + use_smooth=rbf_type == "SphericalBessel", + ) + refs = torch.tensor([-1.5, -2.25], dtype=matgl.float_th) + pot = Potential( + model=model, + data_mean=0.0, + data_std=1.0, + element_refs=refs, + calc_forces=True, + calc_stresses=True, + ) + pot.eval() + return pot, element_types + + +def _build_m3gnet_potential() -> tuple[Potential, tuple[str, ...]]: + torch.manual_seed(0) + element_types = ("Mo", "S") + model = M3GNet( + element_types=element_types, + is_intensive=False, + cutoff=4.0, + threebody_cutoff=3.0, + dim_node_embedding=16, + dim_edge_embedding=16, + n_blocks=1, + max_n=3, + max_l=3, + units=16, + rbf_type="SphericalBessel", + use_smooth=True, + ) + refs = torch.tensor([-1.5, -2.25], dtype=matgl.float_th) + pot = Potential( + model=model, + data_mean=0.0, + data_std=1.0, + element_refs=refs, + calc_forces=True, + calc_stresses=True, + ) + pot.eval() + return pot, element_types + + +@pytest.fixture(params=["tensornet_gaussian", "tensornet_sb", "m3gnet"]) +def tiny_potential(request): + """Tiny deterministic Potential, parametrized over supported architectures. + + ``tensornet_sb`` exercises the SphericalBessel-smooth bond expansion path, + which matches the pretrained ``TensorNet-MatPES-r2SCAN`` checkpoint and + requires the LAMMPS kernel's pure-tensor SBF adapter. + """ + if request.param == "tensornet_gaussian": + return _build_tensornet_potential(rbf_type="Gaussian") + if request.param == "tensornet_sb": + return _build_tensornet_potential(rbf_type="SphericalBessel") + return _build_m3gnet_potential() + + +@pytest.fixture +def mo_s_supercell(): + return Structure( + Lattice.cubic(4.5), + ["Mo", "S", "Mo", "S"], + [ + [0.00, 0.00, 0.00], + [0.50, 0.50, 0.50], + [0.50, 0.00, 0.25], + [0.00, 0.50, 0.75], + ], + ) + + +def _potential_reference(potential, structure, element_types, cutoff, dtype): + """Run ``Potential.forward`` on a structure via Structure2Graph, in eval().""" + s2g = Structure2Graph(element_types=element_types, cutoff=cutoff) + g, lat, _ = s2g.get_graph(structure) + g.frac_coords = g.frac_coords.to(dtype) + lat = lat.to(dtype) + energy, forces, stresses, _ = potential(g, lat, None) + return energy, forces, stresses + + +def test_eager_parity_against_potential(tiny_potential, mo_s_supercell): + """Wrapper's energy/forces match Potential.forward for a periodic crystal.""" + potential, element_types = tiny_potential + structure = mo_s_supercell + cutoff = 4.0 + dtype = matgl.float_th + + # Reference path. + e_ref, f_ref, _s_ref = _potential_reference(potential, structure, element_types, cutoff, dtype) + + # Wrapper path. + wrapper = LAMMPSMatGLModel(potential=potential, dtype=dtype) + wrapper.eval() + + pos, eidx, ushifts, cell, z, local = _build_lammps_inputs(structure, element_types, cutoff, dtype) + out = wrapper(pos, eidx, ushifts, cell, z, local, compute_virials=True) + + e_wrap = out["total_energy_local"] + f_wrap = out["forces"] + + # Energy parity (scalar). + assert torch.allclose(e_wrap.detach(), e_ref.detach().reshape_as(e_wrap), atol=1e-5, rtol=1e-5), ( + f"energy mismatch: wrapper={e_wrap.item()} ref={e_ref.item()}" + ) + + # Force parity (per-atom). + assert torch.allclose(f_wrap.detach(), f_ref.detach(), atol=1e-4, rtol=1e-4), ( + f"max force diff = {(f_wrap.detach() - f_ref.detach()).abs().max().item()}" + ) + + +def test_ghost_mask_partitions_energy(tiny_potential, mo_s_supercell): + """Splitting atoms into 'owned' vs 'ghost' must sum back to the full energy.""" + potential, element_types = tiny_potential + structure = mo_s_supercell + cutoff = 4.0 + dtype = matgl.float_th + + wrapper = LAMMPSMatGLModel(potential=potential, dtype=dtype) + wrapper.eval() + + pos, eidx, ushifts, cell, z, _ = _build_lammps_inputs(structure, element_types, cutoff, dtype) + + n = pos.shape[0] + half = n // 2 + + mask_a = torch.zeros(n, dtype=torch.bool) + mask_a[:half] = True + mask_b = ~mask_a + + out_a = wrapper(pos, eidx, ushifts, cell, z, mask_a, compute_virials=False) + out_b = wrapper(pos, eidx, ushifts, cell, z, mask_b, compute_virials=False) + out_full = wrapper(pos, eidx, ushifts, cell, z, torch.ones(n, dtype=torch.bool), compute_virials=False) + + e_split = out_a["total_energy_local"] + out_b["total_energy_local"] + # data_mean is added once per call, so the split version adds it twice. We + # constructed tiny_potential with data_mean=0, so this is consistent. + assert torch.allclose(e_split.detach(), out_full["total_energy_local"].detach(), atol=1e-5, rtol=1e-5) + + +def test_torchscript_round_trip(tiny_potential, mo_s_supercell, tmp_path): + """Scripted module saves, reloads, and matches eager outputs to fp precision.""" + potential, element_types = tiny_potential + structure = mo_s_supercell + cutoff = 4.0 + dtype = torch.float32 + + wrapper = LAMMPSMatGLModel(potential=potential, dtype=dtype) + wrapper.eval() + + scripted = torch.jit.script(wrapper) + artifact = tmp_path / "model.pt" + scripted.save(str(artifact)) + reloaded = torch.jit.load(str(artifact)) + + pos, eidx, ushifts, cell, z, local = _build_lammps_inputs(structure, element_types, cutoff, dtype) + + out_eager = wrapper(pos, eidx, ushifts, cell, z, local, True) + out_script = reloaded(pos, eidx, ushifts, cell, z, local, True) + + for key in ("total_energy_local", "node_energy", "forces", "virials"): + diff = (out_eager[key].detach() - out_script[key].detach()).abs().max().item() + assert diff < 1e-5, f"{key} max abs diff = {diff}" + + +def test_virials_match_stress_volume(tiny_potential, mo_s_supercell): + """Wrapper virials = Potential stresses * volume / unit_factor. + + ``Potential`` returns stresses in GPa = (1/V) * eV/A^3 * 160.21766208. + The wrapper returns the raw strain-grad tensor (no /V, no unit factor). + So ``virial_wrapper == -stress_potential * V / 160.21766208`` (signs from + LAMMPS sign convention). + """ + potential, element_types = tiny_potential + structure = mo_s_supercell + cutoff = 4.0 + dtype = matgl.float_th + + _e_ref, _f_ref, s_ref = _potential_reference(potential, structure, element_types, cutoff, dtype) + + wrapper = LAMMPSMatGLModel(potential=potential, dtype=dtype) + wrapper.eval() + + pos, eidx, ushifts, cell, z, local = _build_lammps_inputs(structure, element_types, cutoff, dtype) + out = wrapper(pos, eidx, ushifts, cell, z, local, compute_virials=True) + + volume = float(np.linalg.det(structure.lattice.matrix)) + expected_virial = s_ref.detach() * volume / 160.21766208 + + diff = (out["virials"].detach() - expected_virial).abs().max().item() + assert diff < 1e-3, f"virial mismatch (max abs diff = {diff})"