Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
default_stages: [pre-commit]

default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.4
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ from chgnet.model.model import CHGNet
# Load the latest CHGNet model (default: 0.3.0)
chgnet = CHGNet.load()
# Load specific CHGNet versions
chgnet_r2scan = CHGNet.load('r2scan')
chgnet = CHGNet.load(model_name='r2scan')
```

**Model Details:**
Expand Down
10 changes: 8 additions & 2 deletions chgnet/pretrained/r2scan/README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
## Model r2SCAN

This is the pretrained weights of CHGNet fine-tuned on the MP-r2SCAN dataset. The model was initialized from the GGA/GGA+U trained CHGNet v0.3.0 and then transferred to the R2SCAN functional dataset. This work is published in the npj Computational Materials paper titled "Cross-functional transferability in foundation machine learning interatomic potentials."
This is the pretrained weights of CHGNet fine-tuned on the MP-r2SCAN dataset. The model was initialized from the GGA/GGA+U trained CHGNet v0.3.0 and then transferred to the R2SCAN functional dataset. This work is published in the npj Computational Materials paper titled "Cross-functional transferability in foundation machine learning interatomic potentials".

All experiments and results shown in the paper (Method 4) were performed with this version of weights.

Date: 9/15/2025
Date: 9/21/2025

Author: Xu Huang

Expand Down Expand Up @@ -44,6 +44,12 @@ model = CHGNet(

MP-r2SCAN dataset (https://doi.org/10.6084/m9.figshare.28245650.v2) with 8-1-1 train-val-test splitting

## Load the Model

```python
chgnet = CHGNet.load(model_name='r2scan')
```

## Training Configuration

We used the pretrained CHGNet v0.3.0 as the starting model and fine-tuned it on the MP-r2SCAN dataset.
Expand Down
Binary file not shown.
51 changes: 51 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,54 @@ def test_model_load_version_params(
# check check_cuda_mem defaults to False
inspect_signature = inspect.signature(CHGNet.load)
assert inspect_signature.parameters["check_cuda_mem"].default is False


def test_model_load_r2scan(capsys: pytest.CaptureFixture) -> None:
"""Test loading the r2scan pretrained model."""
model = CHGNet.load(model_name="r2scan", use_device="cpu")
r2scan_key, r2scan_params = "r2scan", 412_525
assert model.version == r2scan_key
assert model.n_params == r2scan_params
stdout, stderr = capsys.readouterr()

assert stdout == (
f"CHGNet v{r2scan_key} initialized with {r2scan_params:,} parameters\n"
"CHGNet will run on cpu\n"
)
assert stderr == ""


def test_model_load_all_pretrained_models() -> None:
"""Test loading all three pretrained models."""
# Test default model (0.3.0)
model_030 = CHGNet.load(use_device="cpu")
assert model_030.version == "0.3.0"
assert model_030.n_params == 412_525

# Test 0.2.0 model
model_020 = CHGNet.load(model_name="0.2.0", use_device="cpu")
assert model_020.version == "0.2.0"
assert model_020.n_params == 400_438

# Test r2scan model
model_r2scan = CHGNet.load(model_name="r2scan", use_device="cpu")
assert model_r2scan.version == "r2scan"
assert model_r2scan.n_params == 412_525

# Test that all models can make predictions
from pymatgen.core import Structure

from chgnet import ROOT

structure = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")
converter = CrystalGraphConverter()
graph = converter(structure, graph_id="test-all-models")

# Test prediction with all models
for model in [model_030, model_020, model_r2scan]:
prediction = model.predict_graph(graph, task="e")
assert "e" in prediction
# prediction["e"] is a numpy array, convert to float for assertion
energy = float(prediction["e"])
assert isinstance(energy, float)
assert energy < 0 # Energy should be negative
Loading