Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ jobs:
os: [ubuntu-latest, macos-14, windows-latest]
version:
- { python: "3.10", resolution: highest }
- { python: "3.12", resolution: lowest-direct }
runs-on: ${{ matrix.os }}

steps:
Expand Down
15 changes: 14 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,51 @@ repos:
- id: ruff
args: [--fix]
types_or: [python, jupyter]
exclude: ^site/
- id: ruff-format
types_or: [python, jupyter]
exclude: ^site/

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-case-conflict
exclude: ^site/
- id: check-symlinks
exclude: ^site/
- id: check-yaml
exclude: ^site/
- id: destroyed-symlinks
exclude: ^site/
- id: end-of-file-fixer
exclude_types: [jupyter]
exclude: ^site/
- id: mixed-line-ending
exclude: ^site/
- id: trailing-whitespace
exclude: ^site/

- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
hooks:
- id: codespell
stages: [pre-commit, commit-msg]
args: [--check-filenames]
args: [--check-filenames, --skip=*.lock, package-lock.json]
exclude: ^site/

- repo: https://github.com/kynan/nbstripout
rev: 0.8.0
hooks:
- id: nbstripout
args: [--drop-empty-cells, --keep-output]
exclude: ^site/

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
hooks:
- id: prettier
args: [--write] # edit files in-place
exclude: ^site/
additional_dependencies:
- prettier
- prettier-plugin-svelte
Expand All @@ -54,6 +66,7 @@ repos:
types: [file]
args: [--fix, --config, site/eslint.config.js]
files: \.(js|ts|svelte)$
exclude: ^site/
additional_dependencies:
- eslint
- eslint-plugin-svelte
Expand Down
109 changes: 109 additions & 0 deletions chgnet/model/composition_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ def initialize_from(self, dataset: str) -> None:
self.initialize_from_MPtrj()
elif dataset == "MPF":
self.initialize_from_MPF()
elif dataset == "MP-r2SCAN":
self.initialize_from_mp_r2scan()
else:
raise NotImplementedError(f"{dataset=} not supported yet")

Expand Down Expand Up @@ -423,6 +425,113 @@ def initialize_from_MPF(self) -> None: # noqa: N802
self.is_intensive = False
self.fitted = True

def initialize_from_mp_r2scan(self) -> None:
"""Initialize pre-fitted weights from MP-r2SCAN dataset."""
state_dict = collections.OrderedDict()

state_dict["weight"] = torch.tensor(
[
-3.4690e00,
-3.0982e-01,
-3.3199e00,
-4.7963e00,
-8.0507e00,
-9.5759e00,
-9.8677e00,
-9.1242e00,
-6.7546e00,
-1.9120e00,
-4.5438e00,
-4.0474e00,
-7.2176e00,
-9.6473e00,
-9.6514e00,
-9.5449e00,
-7.9040e00,
-4.8555e00,
-7.0955e00,
-8.4121e00,
-1.2896e01,
-1.4512e01,
-1.5121e01,
-1.5248e01,
-1.4923e01,
-1.4040e01,
-1.2751e01,
-1.1945e01,
-1.0464e01,
-8.9017e00,
-1.1722e01,
-1.4170e01,
-1.5067e01,
-1.5418e01,
-1.4794e01,
-1.1486e01,
-1.5029e01,
-1.6974e01,
-2.1922e01,
-2.4265e01,
-2.5605e01,
-2.6075e01,
-2.5442e01,
-2.5286e01,
-2.4571e01,
-2.3376e01,
-2.0786e01,
-2.0013e01,
-2.2626e01,
-2.4799e01,
-2.5832e01,
-2.5982e01,
-2.5459e01,
-2.2229e01,
-2.6402e01,
-2.8426e01,
-3.1738e01,
-3.2878e01,
-3.0945e01,
-3.0967e01,
-2.9942e01,
-3.1421e01,
-4.0080e01,
-4.5251e01,
-3.2790e01,
-3.3584e01,
-3.4371e01,
-3.5534e01,
-3.6623e01,
5.6469e-14,
-3.9644e01,
-4.6709e01,
-4.9586e01,
-5.1200e01,
-5.1762e01,
-5.2404e01,
-5.2657e01,
-5.2166e01,
-5.0671e01,
-4.8918e01,
-5.2844e01,
-5.6015e01,
-5.8066e01,
1.8537e-14,
-1.0885e-15,
-1.0417e-16,
-2.1228e-16,
5.6561e-16,
-6.9083e01,
-7.4960e01,
-7.8234e01,
-8.1985e01,
-8.4724e01,
-8.7538e01,
]
).view([1, 94])

self.fc.load_state_dict(state_dict)
self.is_intensive = False
self.fitted = True

def initialize_from_numpy(self, file_name: str | Path) -> None:
"""Initialize pre-fitted weights from numpy file."""
atom_ref_np = np.load(file_name)
Expand Down
93 changes: 93 additions & 0 deletions chgnet/pretrained/r2scan/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
## Model r2SCAN

This is the pretrained weights of CHGNet fine-tuned on the MP-r2SCAN dataset. The model was initialized from the GGA/GGA+U trained CHGNet v0.3.0 and then transferred to the R2SCAN functional dataset. This work is published in the npj Computational Materials paper titled "Cross-functional transferability in foundation machine learning interatomic potentials."

All experiments and results shown in the paper (Method 4) were performed with this version of weights.

Date: 9/15/2025

Author: Xu Huang

## Model Parameters

```python
model = CHGNet(
atom_fea_dim=64,
bond_fea_dim=64,
angle_fea_dim=64,
composition_model="MP-r2SCAN",
num_radial=31,
num_angular=31,
n_conv=4,
atom_conv_hidden_dim=64,
update_bond=True,
bond_conv_hidden_dim=64,
update_angle=True,
angle_layer_hidden_dim=0,
conv_dropout=0,
read_out="ave",
gMLP_norm='layer',
readout_norm='layer',
mlp_hidden_dims=[64, 64, 64],
mlp_first=True,
is_intensive=True,
non_linearity="silu",
atom_graph_cutoff=6,
bond_graph_cutoff=3,
graph_converter_algorithm="fast",
cutoff_coeff=8,
learnable_rbf=True,
)
```

## Dataset Used

MP-r2SCAN dataset (https://doi.org/10.6084/m9.figshare.28245650.v2) with 8-1-1 train-val-test splitting

## Training Configuration

We used the pretrained CHGNet v0.3.0 as the starting model and fine-tuned it on the MP-r2SCAN dataset.

```python
# Load pretrained CHGNet v0.3.0
chgnet = CHGNet.load()

# Update model_args to reflect the new composition model
chgnet.model_args['composition_model'] = "MP-r2SCAN"

# Reinitialize composition model weights for MP-r2SCAN dataset
chgnet.composition_model.initialize_from("MP-r2SCAN")

# Initialize trainer with specific configuration
trainer = Trainer(
model=chgnet,
targets='efsm',
energy_loss_ratio=3,
force_loss_ratio=1,
stress_loss_ratio=0.1,
mag_loss_ratio=1,
optimizer='Adam',
scheduler='CosLR',
criterion='Huber',
epochs=50,
learning_rate=1e-3,
use_device='cuda'
)

# Fine-tune the model
trainer.train(
train_loader=train_loader,
val_loader=val_loader,
test_loader=test_loader,
save_dir=save_dir,
train_composition_model=False
)
```

## Mean Absolute Error (MAE) logs

| partition | Energy (meV/atom) | Force (meV/A) | stress (GPa) | magmom (muB) |
| ---------- | ----------------- | ------------- | ------------ | ------------ |
| Train | 11.82 | 24.55 | 0.082 | 0.021 |
| Validation | 15.48 | 36.50 | 0.161 | 0.023 |
| Test | 16.76 | 38.46 | 0.167 | 0.023 |
Binary file not shown.
10 changes: 7 additions & 3 deletions chgnet/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,21 @@ def write_json(dct: dict, filepath: str) -> dict:
filepath (str): file name of JSON to write.
"""

def handler(obj: object) -> int | object:
"""Convert numpy int64 to int.
def handler(obj: object) -> int | float | list | object:
"""Convert numpy types to JSON serializable types.

Fixes TypeError: Object of type int64 is not JSON serializable
reported in https://github.com/CederGroupHub/chgnet/issues/168.

Returns:
int | object: object for serialization
int | float | list | object: object for serialization
"""
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, np.ndarray):
return obj.tolist()
return obj

with open(filepath, mode="w") as file:
Expand Down
44 changes: 22 additions & 22 deletions tests/test_crystal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,20 @@ def test_crystal_graph_perturb_legacy():
print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201

assert list(graph.atom_frac_coord.shape) == [8, 3]
assert list(graph.atom_graph.shape) == [420, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 6).sum().item() == 54

assert list(graph.bond_graph.shape) == [850, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 156
assert (graph.bond_graph[:, 1] == 36).sum().item() == 18
assert (graph.bond_graph[:, 3] == 36).sum().item() == 18
assert list(graph.atom_graph.shape) == [392, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 48
assert (graph.atom_graph[:, 1] == 3).sum().item() == 48
assert (graph.atom_graph[:, 1] == 6).sum().item() == 50

assert list(graph.bond_graph.shape) == [732, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 90
assert (graph.bond_graph[:, 1] == 36).sum().item() == 0
assert (graph.bond_graph[:, 3] == 36).sum().item() == 0
assert (graph.bond_graph[:, 2] == 306).sum().item() == 0
assert (graph.bond_graph[:, 4] == 120).sum().item() == 0
assert list(graph.lattice.shape) == [3, 3]
assert list(graph.undirected2directed.shape) == [210]
assert list(graph.directed2undirected.shape) == [420]
assert list(graph.undirected2directed.shape) == [196]
assert list(graph.directed2undirected.shape) == [392]


def test_crystal_graph_perturb_fast():
Expand All @@ -163,20 +163,20 @@ def test_crystal_graph_perturb_fast():
print("Fast test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201

assert list(graph.atom_frac_coord.shape) == [8, 3]
assert list(graph.atom_graph.shape) == [420, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 6).sum().item() == 54

assert list(graph.bond_graph.shape) == [850, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 156
assert (graph.bond_graph[:, 1] == 36).sum().item() == 18
assert (graph.bond_graph[:, 3] == 36).sum().item() == 18
assert list(graph.atom_graph.shape) == [392, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 48
assert (graph.atom_graph[:, 1] == 3).sum().item() == 48
assert (graph.atom_graph[:, 1] == 6).sum().item() == 50

assert list(graph.bond_graph.shape) == [732, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 90
assert (graph.bond_graph[:, 1] == 36).sum().item() == 0
assert (graph.bond_graph[:, 3] == 36).sum().item() == 0
assert (graph.bond_graph[:, 2] == 306).sum().item() == 0
assert (graph.bond_graph[:, 4] == 120).sum().item() == 0
assert list(graph.lattice.shape) == [3, 3]
assert list(graph.undirected2directed.shape) == [210]
assert list(graph.directed2undirected.shape) == [420]
assert list(graph.undirected2directed.shape) == [196]
assert list(graph.directed2undirected.shape) == [392]


def test_crystal_graph_isotropic_strained_legacy():
Expand Down
Loading