Skip to content

Fix charges output and total_charge handling for QET in PESCalculator and JAX path#803

Merged
shyuep merged 1 commit into
materialyzeai:mainfrom
YiqingChen524:fix-pescalculator-charges
Jun 10, 2026
Merged

Fix charges output and total_charge handling for QET in PESCalculator and JAX path#803
shyuep merged 1 commit into
materialyzeai:mainfrom
YiqingChen524:fix-pescalculator-charges

Conversation

@YiqingChen524

Copy link
Copy Markdown
Collaborator

When the DGL backend was removed, the only calculator with charge plumbing went with it. For charge-predicting potentials (QET):

  • get_charges() raised PropertyNotImplementedError on both PESCalculator and JAXPESCalculator: "charges" was missing from implemented_properties and never stored in results, even though the torch Potential returns them and the JAX QEq solve computes them.
  • total_charge was silently ignored everywhere: the torch calculator never passed it (LinearQeq treats None as zero) and the JAX conversion baked it as a static 0.0, so charged cells ran as neutral with no error. atoms.set_initial_charges(), which the DGL calculator consumed, was also ignored.

Changes:

  • PESCalculator: "charges" in implemented_properties, results populated when potential.calc_charge is set, total_charge/ext_pot kwargs restored with the atoms.get_initial_charges() fallback. Relaxer and MolecularDynamics inherit this automatically.
  • make_potential_fn(): public (E, forces, stress) 3-tuple contract unchanged by default. with_charges=True (QET only) returns a callable with a trailing total_charge argument that also returns per-atom charges; total_charge is traced, so varying it triggers no recompilation.
  • qet_energy_and_charges() returns (energy, charges) and accepts a traced total_charge; qet_energy() stays a thin wrapper.
  • JAXPESCalculator mirrors PESCalculator: "charges" exposed, total_charge kwarg with the initial-charges fallback.

… and JAX path

When the DGL backend was removed, the only calculator with charge
plumbing went with it. For charge-predicting potentials (QET):

- get_charges() raised PropertyNotImplementedError on both
  PESCalculator and JAXPESCalculator: "charges" was missing from
  implemented_properties and never stored in results, even though the
  torch Potential returns them and the JAX QEq solve computes them.
- total_charge was silently ignored everywhere: the torch calculator
  never passed it (LinearQeq treats None as zero) and the JAX
  conversion baked it as a static 0.0, so charged cells ran as neutral
  with no error. atoms.set_initial_charges(), which the DGL calculator
  consumed, was also ignored.

Changes:

- PESCalculator: "charges" in implemented_properties, results
  populated when potential.calc_charge is set, total_charge/ext_pot
  kwargs restored with the atoms.get_initial_charges() fallback.
  Relaxer and MolecularDynamics inherit this automatically.
- make_potential_fn(): public (E, forces, stress) 3-tuple contract
  unchanged by default. with_charges=True (QET only) returns a callable
  with a trailing total_charge argument that also returns per-atom
  charges; total_charge is traced, so varying it triggers no
  recompilation.
- qet_energy_and_charges() returns (energy, charges) and accepts a
  traced total_charge; qet_energy() stays a thin wrapper.
- JAXPESCalculator mirrors PESCalculator: "charges" exposed,
  total_charge kwarg with the initial-charges fallback.
@shyuep shyuep merged commit d698b21 into materialyzeai:main Jun 10, 2026
6 of 8 checks passed
@YiqingChen524 YiqingChen524 mentioned this pull request Jun 10, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants