Draft
Changes from all commits (33 commits)
c3e9ce9
checkpoint
njzjz May 25, 2025
ac5a7a3
checkpoint
njzjz May 26, 2025
b99d66d
Merge remote-tracking branch 'origin/devel' into jax_training
njzjz May 26, 2025
ad20de9
fix(jax): make display_if_exist jit-able
njzjz May 26, 2025
c62c356
fix(jax): workaround for "xxTracer is not a valid JAX type"
njzjz May 27, 2025
d0a1ce7
checkpoint
njzjz May 27, 2025
e88838b
fix scale of initial parameters
njzjz May 27, 2025
29922d7
set up lr
njzjz May 27, 2025
d5d5f06
clean up tqdm
njzjz May 27, 2025
6947ee6
freeze
njzjz May 27, 2025
9218157
improve checkpoint
njzjz May 28, 2025
e3dca7a
fix unreference variable
njzjz May 28, 2025
1fdb40c
hessian loss
njzjz May 28, 2025
1474327
valid_more_loss
njzjz May 28, 2025
d7f06d6
Merge branch 'devel' into jax_training
njzjz Jun 1, 2025
68b3727
print summary
njzjz Jun 1, 2025
7930827
Merge branch 'jax_training' of https://github.com/njzjz/deepmd-kit in…
njzjz Jun 1, 2025
a0cd67a
seed
njzjz Jun 1, 2025
d21c39c
restart. bug to be fixed
njzjz Jun 3, 2025
15bb506
fix lr
njzjz Jun 3, 2025
b9111c8
fix(dpmodel/pt/pd/jax): pass trainable to layer & support JAX trainable
njzjz Jun 10, 2025
7d7e043
bump the version of Layer data
njzjz Jun 10, 2025
77cc091
fix pd trainable
njzjz Jun 10, 2025
4c635c5
Merge branch 'trainable' into jax_training
njzjz Jun 10, 2025
6bd237d
fix(jax): fix DPA3 force NaN with edge_init_use_dist
njzjz Jun 10, 2025
5426789
Merge branch 'fix-jax-dpa3-grad-nan' into jax_training
njzjz Jun 10, 2025
2433566
fix(jax): use more safe_for_vector_norm
njzjz Jun 17, 2025
dfaf6fb
fix nopbc behavior
njzjz Jun 20, 2025
b76fb83
should be >1
njzjz Jun 20, 2025
4123ac2
freeze with hessian
njzjz Jun 25, 2025
1cd66ee
pass mixed_type
njzjz Jul 10, 2025
cb7f341
Merge remote-tracking branch 'origin/devel' into jax_training
njzjz Jul 27, 2025
208b648
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2025
6 changes: 5 additions & 1 deletion deepmd/backend/jax.py
@@ -60,7 +60,11 @@ def entry_point_hook(self) -> Callable[["Namespace"], None]:
Callable[[Namespace], None]
The entry point hook of the backend.
"""
raise NotImplementedError
from deepmd.jax.entrypoints.main import (
main,
)

return main

@property
def deep_eval(self) -> type["DeepEvalBackend"]:
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
@@ -787,6 +787,7 @@ def __init__(
self.mean = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision])
self.stddev = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision])
self.orig_sel = self.sel
self.ndescrpt = self.nnei * 4

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
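Note: the new ndescrpt attribute follows the se-style environment-matrix layout, in which each of the nnei neighbors contributes 4 components (one radial term plus three direction-weighted terms). A minimal illustration with a hypothetical neighbor count:

# each neighbor row of the full environment matrix: (s(r), s(r)x/r, s(r)y/r, s(r)z/r)
nnei = 120            # hypothetical: sum over sel
ndescrpt = nnei * 4   # 480 descriptor entries per atom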
72 changes: 72 additions & 0 deletions deepmd/dpmodel/fitting/ener_fitting.py
@@ -6,6 +6,8 @@
Union,
)

import numpy as np

from deepmd.dpmodel.common import (
DEFAULT_PRECISION,
)
@@ -17,6 +19,10 @@
from deepmd.dpmodel.fitting.general_fitting import (
GeneralFitting,
)

from deepmd.utils.out_stat import (
compute_stats_from_redu,
)
from deepmd.utils.version import (
check_version_compatibility,
)
@@ -86,3 +92,69 @@ def serialize(self) -> dict:
**super().serialize(),
"type": "ener",
}

def compute_output_stats(self, all_stat: dict, mixed_type: bool = False) -> None:
"""Compute the output statistics.

Parameters
----------
all_stat
must have the following components:
all_stat['energy'] of shape n_sys x n_batch x n_frame
can be prepared by model.make_stat_input
mixed_type
Whether to perform the mixed_type mode.
If True, the input data has the mixed_type format (see doc/model/train_se_atten.md),
in which frames in a system may have different natoms_vec(s), with the same nloc.
"""
self.bias_atom_e = self._compute_output_stats(
all_stat, rcond=self.rcond, mixed_type=mixed_type
)

def _compute_output_stats(self, all_stat, rcond=1e-3, mixed_type=False):
data = all_stat["energy"]
# data[sys_idx][batch_idx][frame_idx]
sys_ener = []
for ss in range(len(data)):
sys_data = []
for ii in range(len(data[ss])):
for jj in range(len(data[ss][ii])):
sys_data.append(data[ss][ii][jj])
sys_data = np.concatenate(sys_data)
sys_ener.append(np.average(sys_data))
sys_ener = np.array(sys_ener)
sys_tynatom = []
if mixed_type:
data = all_stat["real_natoms_vec"]
nsys = len(data)
for ss in range(len(data)):
tmp_tynatom = []
for ii in range(len(data[ss])):
for jj in range(len(data[ss][ii])):
tmp_tynatom.append(data[ss][ii][jj].astype(np.float64))
tmp_tynatom = np.average(np.array(tmp_tynatom), axis=0)
sys_tynatom.append(tmp_tynatom)
else:
data = all_stat["natoms_vec"]
nsys = len(data)
for ss in range(len(data)):
sys_tynatom.append(data[ss][0].astype(np.float64))
sys_tynatom = np.array(sys_tynatom)
sys_tynatom = np.reshape(sys_tynatom, [nsys, -1])
sys_tynatom = sys_tynatom[:, 2:]
if len(self.atom_ener) > 0:
# Atomic energies stats are incorrect if atomic energies are assigned.
# In this situation, we directly use these assigned energies instead of computing stats.
# This will make the loss decrease quickly
assigned_atom_ener = np.array(
[ee if ee is not None else np.nan for ee in self.atom_ener_v]
⚠️ Potential issue

Fix undefined attribute reference.

The code references self.atom_ener_v which is not defined in the class. This will cause an AttributeError at runtime.

Based on the similar implementation in deepmd/tf/fit/ener.py, this should likely be self.atom_ener:

-                [ee if ee is not None else np.nan for ee in self.atom_ener_v]
+                [ee if ee is not None else np.nan for ee in self.atom_ener]
🤖 Prompt for AI Agents
In deepmd/dpmodel/fitting/ener_fitting.py at line 150, the code references an
undefined attribute self.atom_ener_v, causing an AttributeError. Replace
self.atom_ener_v with self.atom_ener to match the correct attribute name used in
the class, following the pattern from deepmd/tf/fit/ener.py.

)
else:
assigned_atom_ener = None
energy_shift, _ = compute_stats_from_redu(
sys_ener.reshape(-1, 1),
sys_tynatom,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)
return energy_shift.ravel()
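Note: compute_stats_from_redu fits the per-system mean energies as a linear combination of the per-type atom counts assembled above. A minimal sketch of the same least-squares idea, with hypothetical numbers (not the actual implementation):

import numpy as np

# solve sys_ener ≈ sys_tynatom @ energy_shift in the least-squares sense;
# sys_tynatom[s, t] is the count of type-t atoms in system s
sys_ener = np.array([-10.2, -20.5, -15.1])                    # per-system mean energies
sys_tynatom = np.array([[2.0, 1.0], [4.0, 2.0], [3.0, 1.5]])  # per-type atom counts
energy_shift, *_ = np.linalg.lstsq(sys_tynatom, sys_ener, rcond=1e-3)
print(energy_shift)  # per-type energy bias, analogous to bias_atom_e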
84 changes: 82 additions & 2 deletions deepmd/dpmodel/loss/ener.py
@@ -177,7 +177,9 @@ def call(
delta=self.huber_delta,
)
loss += pref_e * l_huber_loss
more_loss["rmse_e"] = self.display_if_exist(l2_ener_loss, find_energy)
more_loss["rmse_e"] = self.display_if_exist(
xp.sqrt(l2_ener_loss) * atom_norm_ener, find_energy
)
if self.has_f:
l2_force_loss = xp.mean(xp.square(diff_f))
if not self.use_huber:
@@ -189,7 +191,9 @@
delta=self.huber_delta,
)
loss += pref_f * l_huber_loss
more_loss["rmse_f"] = self.display_if_exist(l2_force_loss, find_force)
more_loss["rmse_f"] = self.display_if_exist(
xp.sqrt(l2_force_loss), find_force
)
if self.has_v:
virial_reshape = xp.reshape(virial, (-1,))
virial_hat_reshape = xp.reshape(virial_hat, (-1,))
@@ -381,3 +385,79 @@ def deserialize(cls, data: dict) -> "Loss":
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
return cls(**data)


class EnergyHessianLoss(EnergyLoss):
def __init__(
self,
start_pref_h=0.0,
limit_pref_h=0.0,
**kwargs,
):
r"""Enable the layer to compute loss on hessian.

Parameters
----------
start_pref_h : float
The prefactor of hessian loss at the start of the training.
limit_pref_h : float
The prefactor of hessian loss at the end of the training.
**kwargs
Other keyword arguments.
"""
EnergyLoss.__init__(self, **kwargs)
self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
⚠️ Potential issue

Fix logical condition for enabling Hessian loss.

The condition uses and which requires both prefactors to be non-zero. This is likely too restrictive - the Hessian loss should be enabled if either prefactor is non-zero.

Apply this fix:

-        self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
+        self.has_h = start_pref_h != 0.0 or limit_pref_h != 0.0
🤖 Prompt for AI Agents
In deepmd/dpmodel/loss/ener.py at line 409, the condition for enabling Hessian
loss uses 'and' to check if both start_pref_h and limit_pref_h are non-zero,
which is too restrictive. Change the logical operator from 'and' to 'or' so that
the Hessian loss is enabled if either start_pref_h or limit_pref_h is non-zero.


self.start_pref_h = start_pref_h
self.limit_pref_h = limit_pref_h

def call(
self,
learning_rate: float,
natoms: int,
model_dict: dict[str, np.ndarray],
label_dict: dict[str, np.ndarray],
) -> dict[str, np.ndarray]:
"""Calculate loss from model results and labeled results."""
loss, more_loss = EnergyLoss.call(
self, learning_rate, natoms, model_dict, label_dict
)
xp = array_api_compat.array_namespace(model_dict["energy"])
coef = learning_rate / self.starter_learning_rate
pref_h = self.limit_pref_h + (self.start_pref_h - self.limit_pref_h) * coef

if (
self.has_h
and "energy_derv_r_derv_r" in model_dict
and "hessian" in label_dict
):
find_hessian = label_dict.get("find_hessian", 0.0)
💡 Verification agent

🧩 Analysis chain

Verify find_hessian data type consistency.

The code uses label_dict.get("find_hessian", 0.0) which returns a scalar default, but other find_* variables in the parent class appear to be arrays from the labeled data.


Ensure find_hessian is accessed like the other find_* flags

The other find_* variables are pulled directly from label_dict (no default), so using .get(..., 0.0) here is inconsistent and may silently disable Hessian loss even when "hessian" is present. Please update this so that missing flags are surfaced (or, if you really intend an optional flag, default to a matching array of ones).

• File deepmd/dpmodel/loss/ener.py, line 434

-    find_hessian = label_dict.get("find_hessian", 0.0)
+    find_hessian = label_dict["find_hessian"]

If you do need a default, consider instead:

find_hessian = label_dict.get(
    "find_hessian",
    xp.ones(label_dict["hessian"].shape[0])  # match per-frame flag shape
)
🤖 Prompt for AI Agents
In deepmd/dpmodel/loss/ener.py at line 434, the assignment of find_hessian uses
label_dict.get with a scalar default 0.0, which is inconsistent with other
find_* variables that do not use defaults and are arrays. To fix this, remove
the default value so that missing keys raise an error or, if a default is
necessary, set it to an array of ones matching the shape of
label_dict["hessian"]. This ensures data type consistency and proper handling of
the find_hessian flag.

pref_h = pref_h * find_hessian
diff_h = label_dict["hessian"].reshape(
-1,
) - model_dict["energy_derv_r_derv_r"].reshape(
-1,
)
l2_hessian_loss = xp.mean(xp.square(diff_h))
loss += pref_h * l2_hessian_loss
rmse_h = xp.sqrt(l2_hessian_loss)
more_loss["rmse_h"] = self.display_if_exist(rmse_h, find_hessian)

more_loss["rmse"] = xp.sqrt(loss)
return loss, more_loss

@property
def label_requirement(self) -> list[DataRequirementItem]:
"""Add hessian label requirement needed for this loss calculation."""
label_requirement = super().label_requirement
if self.has_h:
label_requirement.append(
DataRequirementItem(
"hessian",
ndof=1, # 9=3*3 --> 3N*3N=ndof*natoms*natoms
atomic=True,
must=False,
high_prec=False,
)
)
return label_requirement
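Note: the Hessian prefactor follows the same learning-rate-linked schedule as the other loss terms. A standalone sketch of that schedule (names are illustrative):

def hessian_prefactor(lr, start_lr, start_pref_h, limit_pref_h):
    # linear interpolation in the learning rate: the prefactor decays from
    # start_pref_h to limit_pref_h as lr decays toward zero
    coef = lr / start_lr
    return limit_pref_h + (start_pref_h - limit_pref_h) * coef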
7 changes: 5 additions & 2 deletions deepmd/dpmodel/utils/env_mat_stat.py
@@ -119,12 +119,15 @@ def iter(
"last_dim should be 1 for raial-only or 4 for full descriptor."
)
for system in data:
coord, atype, box, natoms = (
coord, atype, box = (
system["coord"],
system["atype"],
system["box"],
system["natoms"],
)
coord = xp.reshape(coord, (coord.shape[0], -1, 3)) # (nframes, nloc, 3)
atype = xp.reshape(atype, (coord.shape[0], -1)) # (nframes, nloc)
if box is not None:
box = xp.reshape(box, (coord.shape[0], 3, 3))
(
extended_coord,
extended_atype,
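Note: iter() no longer consumes system["natoms"]; the local atom count is inferred from the flat coordinate array instead. A small illustration of the reshapes with hypothetical sizes:

import numpy as np

nframes, nloc = 2, 5
coord = np.zeros((nframes, nloc * 3))          # flat per-frame layout
coord = coord.reshape(coord.shape[0], -1, 3)   # -> (nframes, nloc, 3)
atype = np.zeros((nframes, nloc)).reshape(coord.shape[0], -1)  # -> (nframes, nloc)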
7 changes: 3 additions & 4 deletions deepmd/dpmodel/utils/learning_rate.py
@@ -45,9 +45,8 @@
self.decay_rate = decay_rate
self.min_lr = stop_lr

def value(self, step) -> np.float64:
def value(self, step, xp=np) -> np.float64:
"""Get the learning rate at the given step."""
step_lr = self.start_lr * np.power(self.decay_rate, step // self.decay_steps)
if step_lr < self.min_lr:
step_lr = self.min_lr
step_lr = self.start_lr * xp.power(self.decay_rate, step // self.decay_steps)
step_lr = xp.clip(step_lr, self.min_lr, None)
return step_lr
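Note: the new xp parameter makes the schedule usable under jax.jit: the old Python-level if on step_lr would fail on a traced value, whereas xp.clip works for both NumPy arrays and JAX tracers. A standalone sketch under that assumption:

import jax.numpy as jnp

def value(step, xp=jnp, start_lr=1e-3, decay_rate=0.95, decay_steps=5000, min_lr=1e-8):
    # same arithmetic as LearningRateExp.value; xp=jnp keeps it traceable,
    # while the xp=np default preserves the old behavior
    step_lr = start_lr * xp.power(decay_rate, step // decay_steps)
    return xp.clip(step_lr, min_lr, None)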
1 change: 1 addition & 0 deletions deepmd/jax/entrypoints/__init__.py
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
39 changes: 39 additions & 0 deletions deepmd/jax/entrypoints/freeze.py
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from pathlib import (
Path,
)

from deepmd.jax.utils.serialization import (
deserialize_to_file,
serialize_from_file,
)


def freeze(
*,
checkpoint_folder: str,
output: str,
hessian: bool = False,
**kwargs,
) -> None:
"""Freeze the graph in supplied folder.

Parameters
----------
checkpoint_folder : str
location of either the folder with checkpoint or the checkpoint prefix
output : str
output file name
hessian : bool, optional
whether to freeze the hessian, by default False
**kwargs
other arguments
"""
if (Path(checkpoint_folder) / "checkpoint").is_file():
checkpoint_meta = Path(checkpoint_folder) / "checkpoint"
checkpoint_folder = checkpoint_meta.read_text().strip()
if Path(checkpoint_folder).is_dir():
data = serialize_from_file(checkpoint_folder)
deserialize_to_file(output, data, hessian=hessian)
else:
raise FileNotFoundError(f"Checkpoint {checkpoint_folder} does not exist.")
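Note: a hypothetical direct invocation, assuming a checkpoint directory produced by the JAX trainer (output-suffix handling normally happens in main() via format_model_suffix):

from deepmd.jax.entrypoints.freeze import freeze

freeze(checkpoint_folder="model-checkpoints", output="frozen_model")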
67 changes: 67 additions & 0 deletions deepmd/jax/entrypoints/main.py
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""DeePMD-Kit entry point module."""

import argparse
from pathlib import (
Path,
)
from typing import (
Optional,
Union,
)

from deepmd.backend.suffix import (
format_model_suffix,
)
from deepmd.jax.entrypoints.freeze import (
freeze,
)
from deepmd.jax.entrypoints.train import (
train,
)
from deepmd.loggers.loggers import (
set_log_handles,
)
from deepmd.main import (
parse_args,
)

__all__ = ["main"]


def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None:
"""DeePMD-Kit entry point.
Parameters
----------
args : list[str] or argparse.Namespace, optional
list of command line arguments, used to avoid calling from the subprocess,
as it is quite slow to import tensorflow; if Namespace is given, it will
be used directly
Raises
------
RuntimeError
if no command was input
"""
if not isinstance(args, argparse.Namespace):
args = parse_args(args=args)

dict_args = vars(args)
set_log_handles(
args.log_level,
Path(args.log_path) if args.log_path else None,
mpi_log=None,
)

if args.command == "train":
train(**dict_args)
elif args.command == "freeze":
dict_args["output"] = format_model_suffix(
dict_args["output"], preferred_backend=args.backend, strict_prefer=True
)
freeze(**dict_args)
elif args.command is None:
pass
else:
raise RuntimeError(f"unknown command {args.command}")
Comment on lines +64 to +67
🛠️ Refactor suggestion

Improve command handling for better error reporting.

The current implementation silently passes when no command is provided, which might hide configuration issues. Additionally, the error message for unknown commands could be more helpful.

Apply this diff to improve error handling:

-    elif args.command is None:
-        pass
+    elif args.command is None:
+        raise RuntimeError("No command specified. Available commands: train, freeze")
     else:
-        raise RuntimeError(f"unknown command {args.command}")
+        raise RuntimeError(
+            f"Unknown command '{args.command}'. Available commands: train, freeze"
+        )
🤖 Prompt for AI Agents
In deepmd/jax/entrypoints/main.py around lines 64 to 67, replace the silent pass
when args.command is None with a clear error message indicating that no command
was provided. Also, enhance the RuntimeError message for unknown commands to
suggest checking available commands or usage. This improves error reporting by
explicitly handling missing commands and providing more informative feedback for
unknown commands.

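Note: a hypothetical programmatic call, equivalent to running the corresponding CLI subcommand (assumes an input.json in the working directory):

from deepmd.jax.entrypoints.main import main

main(["train", "input.json"])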