diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 95b4a9a7..5ea4ece4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,6 @@ jobs: os: [ubuntu-latest, macos-14, windows-latest] version: - { python: "3.10", resolution: highest } - - { python: "3.12", resolution: lowest-direct } runs-on: ${{ matrix.os }} steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bc3acb2d..967d7952 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,39 +9,51 @@ repos: - id: ruff args: [--fix] types_or: [python, jupyter] + exclude: ^site/ - id: ruff-format types_or: [python, jupyter] + exclude: ^site/ - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 hooks: - id: check-case-conflict + exclude: ^site/ - id: check-symlinks + exclude: ^site/ - id: check-yaml + exclude: ^site/ - id: destroyed-symlinks + exclude: ^site/ - id: end-of-file-fixer exclude_types: [jupyter] + exclude: ^site/ - id: mixed-line-ending + exclude: ^site/ - id: trailing-whitespace + exclude: ^site/ - repo: https://github.com/codespell-project/codespell rev: v2.3.0 hooks: - id: codespell stages: [pre-commit, commit-msg] - args: [--check-filenames] + args: [--check-filenames, --skip=*.lock, package-lock.json] + exclude: ^site/ - repo: https://github.com/kynan/nbstripout rev: 0.8.0 hooks: - id: nbstripout args: [--drop-empty-cells, --keep-output] + exclude: ^site/ - repo: https://github.com/pre-commit/mirrors-prettier rev: v4.0.0-alpha.8 hooks: - id: prettier args: [--write] # edit files in-place + exclude: ^site/ additional_dependencies: - prettier - prettier-plugin-svelte @@ -54,6 +66,7 @@ repos: types: [file] args: [--fix, --config, site/eslint.config.js] files: \.(js|ts|svelte)$ + exclude: ^site/ additional_dependencies: - eslint - eslint-plugin-svelte diff --git a/chgnet/model/composition_model.py b/chgnet/model/composition_model.py index b87ce668..112bec1e 100644 --- a/chgnet/model/composition_model.py +++ b/chgnet/model/composition_model.py @@ -210,6 +210,8 @@ def initialize_from(self, dataset: str) -> None: self.initialize_from_MPtrj() elif dataset == "MPF": self.initialize_from_MPF() + elif dataset == "MP-r2SCAN": + self.initialize_from_mp_r2scan() else: raise NotImplementedError(f"{dataset=} not supported yet") @@ -423,6 +425,113 @@ def initialize_from_MPF(self) -> None: # noqa: N802 self.is_intensive = False self.fitted = True + def initialize_from_mp_r2scan(self) -> None: + """Initialize pre-fitted weights from MP-r2SCAN dataset.""" + state_dict = collections.OrderedDict() + + state_dict["weight"] = torch.tensor( + [ + -3.4690e00, + -3.0982e-01, + -3.3199e00, + -4.7963e00, + -8.0507e00, + -9.5759e00, + -9.8677e00, + -9.1242e00, + -6.7546e00, + -1.9120e00, + -4.5438e00, + -4.0474e00, + -7.2176e00, + -9.6473e00, + -9.6514e00, + -9.5449e00, + -7.9040e00, + -4.8555e00, + -7.0955e00, + -8.4121e00, + -1.2896e01, + -1.4512e01, + -1.5121e01, + -1.5248e01, + -1.4923e01, + -1.4040e01, + -1.2751e01, + -1.1945e01, + -1.0464e01, + -8.9017e00, + -1.1722e01, + -1.4170e01, + -1.5067e01, + -1.5418e01, + -1.4794e01, + -1.1486e01, + -1.5029e01, + -1.6974e01, + -2.1922e01, + -2.4265e01, + -2.5605e01, + -2.6075e01, + -2.5442e01, + -2.5286e01, + -2.4571e01, + -2.3376e01, + -2.0786e01, + -2.0013e01, + -2.2626e01, + -2.4799e01, + -2.5832e01, + -2.5982e01, + -2.5459e01, + -2.2229e01, + -2.6402e01, + -2.8426e01, + -3.1738e01, + -3.2878e01, + -3.0945e01, + -3.0967e01, + -2.9942e01, + -3.1421e01, + -4.0080e01, + -4.5251e01, + -3.2790e01, + -3.3584e01, + -3.4371e01, + -3.5534e01, + -3.6623e01, + 5.6469e-14, + -3.9644e01, + -4.6709e01, + -4.9586e01, + -5.1200e01, + -5.1762e01, + -5.2404e01, + -5.2657e01, + -5.2166e01, + -5.0671e01, + -4.8918e01, + -5.2844e01, + -5.6015e01, + -5.8066e01, + 1.8537e-14, + -1.0885e-15, + -1.0417e-16, + -2.1228e-16, + 5.6561e-16, + -6.9083e01, + -7.4960e01, + -7.8234e01, + -8.1985e01, + -8.4724e01, + -8.7538e01, + ] + ).view([1, 94]) + + self.fc.load_state_dict(state_dict) + self.is_intensive = False + self.fitted = True + def initialize_from_numpy(self, file_name: str | Path) -> None: """Initialize pre-fitted weights from numpy file.""" atom_ref_np = np.load(file_name) diff --git a/chgnet/pretrained/r2scan/README.md b/chgnet/pretrained/r2scan/README.md new file mode 100644 index 00000000..eb46cd06 --- /dev/null +++ b/chgnet/pretrained/r2scan/README.md @@ -0,0 +1,93 @@ +## Model r2SCAN + +This is the pretrained weights of CHGNet fine-tuned on the MP-r2SCAN dataset. The model was initialized from the GGA/GGA+U trained CHGNet v0.3.0 and then transferred to the R2SCAN functional dataset. This work is published in the npj Computational Materials paper titled "Cross-functional transferability in foundation machine learning interatomic potentials." + +All experiments and results shown in the paper (Method 4) were performed with this version of weights. + +Date: 9/15/2025 + +Author: Xu Huang + +## Model Parameters + +```python +model = CHGNet( + atom_fea_dim=64, + bond_fea_dim=64, + angle_fea_dim=64, + composition_model="MP-r2SCAN", + num_radial=31, + num_angular=31, + n_conv=4, + atom_conv_hidden_dim=64, + update_bond=True, + bond_conv_hidden_dim=64, + update_angle=True, + angle_layer_hidden_dim=0, + conv_dropout=0, + read_out="ave", + gMLP_norm='layer', + readout_norm='layer', + mlp_hidden_dims=[64, 64, 64], + mlp_first=True, + is_intensive=True, + non_linearity="silu", + atom_graph_cutoff=6, + bond_graph_cutoff=3, + graph_converter_algorithm="fast", + cutoff_coeff=8, + learnable_rbf=True, +) +``` + +## Dataset Used + +MP-r2SCAN dataset (https://doi.org/10.6084/m9.figshare.28245650.v2) with 8-1-1 train-val-test splitting + +## Training Configuration + +We used the pretrained CHGNet v0.3.0 as the starting model and fine-tuned it on the MP-r2SCAN dataset. + +```python +# Load pretrained CHGNet v0.3.0 +chgnet = CHGNet.load() + +# Update model_args to reflect the new composition model +chgnet.model_args['composition_model'] = "MP-r2SCAN" + +# Reinitialize composition model weights for MP-r2SCAN dataset +chgnet.composition_model.initialize_from("MP-r2SCAN") + +# Initialize trainer with specific configuration +trainer = Trainer( + model=chgnet, + targets='efsm', + energy_loss_ratio=3, + force_loss_ratio=1, + stress_loss_ratio=0.1, + mag_loss_ratio=1, + optimizer='Adam', + scheduler='CosLR', + criterion='Huber', + epochs=50, + learning_rate=1e-3, + use_device='cuda' +) + +# Fine-tune the model +trainer.train( + train_loader=train_loader, + val_loader=val_loader, + test_loader=test_loader, + save_dir=save_dir, + train_composition_model=False +) +``` + +## Mean Absolute Error (MAE) logs + +| partition | Energy (meV/atom) | Force (meV/A) | stress (GPa) | magmom (muB) | +| ---------- | ----------------- | ------------- | ------------ | ------------ | +| Train | 11.82 | 24.55 | 0.082 | 0.021 | +| Validation | 15.48 | 36.50 | 0.161 | 0.023 | +| Test | 16.76 | 38.46 | 0.167 | 0.023 | diff --git a/chgnet/pretrained/r2scan/chgnet_r2scan_transfer_learning_e15f36s161m23.pth.tar b/chgnet/pretrained/r2scan/chgnet_r2scan_transfer_learning_e15f36s161m23.pth.tar new file mode 100644 index 00000000..57889af4 Binary files /dev/null and b/chgnet/pretrained/r2scan/chgnet_r2scan_transfer_learning_e15f36s161m23.pth.tar differ diff --git a/chgnet/utils/common_utils.py b/chgnet/utils/common_utils.py index 94079902..6c53a494 100644 --- a/chgnet/utils/common_utils.py +++ b/chgnet/utils/common_utils.py @@ -117,17 +117,21 @@ def write_json(dct: dict, filepath: str) -> dict: filepath (str): file name of JSON to write. """ - def handler(obj: object) -> int | object: - """Convert numpy int64 to int. + def handler(obj: object) -> int | float | list | object: + """Convert numpy types to JSON serializable types. Fixes TypeError: Object of type int64 is not JSON serializable reported in https://github.com/CederGroupHub/chgnet/issues/168. Returns: - int | object: object for serialization + int | float | list | object: object for serialization """ if isinstance(obj, np.integer): return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() return obj with open(filepath, mode="w") as file: diff --git a/tests/test_crystal_graph.py b/tests/test_crystal_graph.py index 4022e704..908f9894 100644 --- a/tests/test_crystal_graph.py +++ b/tests/test_crystal_graph.py @@ -136,20 +136,20 @@ def test_crystal_graph_perturb_legacy(): print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201 assert list(graph.atom_frac_coord.shape) == [8, 3] - assert list(graph.atom_graph.shape) == [420, 2] - assert (graph.atom_graph[:, 0] == 3).sum().item() == 54 - assert (graph.atom_graph[:, 1] == 3).sum().item() == 54 - assert (graph.atom_graph[:, 1] == 6).sum().item() == 54 - - assert list(graph.bond_graph.shape) == [850, 5] - assert (graph.bond_graph[:, 0] == 1).sum().item() == 156 - assert (graph.bond_graph[:, 1] == 36).sum().item() == 18 - assert (graph.bond_graph[:, 3] == 36).sum().item() == 18 + assert list(graph.atom_graph.shape) == [392, 2] + assert (graph.atom_graph[:, 0] == 3).sum().item() == 48 + assert (graph.atom_graph[:, 1] == 3).sum().item() == 48 + assert (graph.atom_graph[:, 1] == 6).sum().item() == 50 + + assert list(graph.bond_graph.shape) == [732, 5] + assert (graph.bond_graph[:, 0] == 1).sum().item() == 90 + assert (graph.bond_graph[:, 1] == 36).sum().item() == 0 + assert (graph.bond_graph[:, 3] == 36).sum().item() == 0 assert (graph.bond_graph[:, 2] == 306).sum().item() == 0 assert (graph.bond_graph[:, 4] == 120).sum().item() == 0 assert list(graph.lattice.shape) == [3, 3] - assert list(graph.undirected2directed.shape) == [210] - assert list(graph.directed2undirected.shape) == [420] + assert list(graph.undirected2directed.shape) == [196] + assert list(graph.directed2undirected.shape) == [392] def test_crystal_graph_perturb_fast(): @@ -163,20 +163,20 @@ def test_crystal_graph_perturb_fast(): print("Fast test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201 assert list(graph.atom_frac_coord.shape) == [8, 3] - assert list(graph.atom_graph.shape) == [420, 2] - assert (graph.atom_graph[:, 0] == 3).sum().item() == 54 - assert (graph.atom_graph[:, 1] == 3).sum().item() == 54 - assert (graph.atom_graph[:, 1] == 6).sum().item() == 54 - - assert list(graph.bond_graph.shape) == [850, 5] - assert (graph.bond_graph[:, 0] == 1).sum().item() == 156 - assert (graph.bond_graph[:, 1] == 36).sum().item() == 18 - assert (graph.bond_graph[:, 3] == 36).sum().item() == 18 + assert list(graph.atom_graph.shape) == [392, 2] + assert (graph.atom_graph[:, 0] == 3).sum().item() == 48 + assert (graph.atom_graph[:, 1] == 3).sum().item() == 48 + assert (graph.atom_graph[:, 1] == 6).sum().item() == 50 + + assert list(graph.bond_graph.shape) == [732, 5] + assert (graph.bond_graph[:, 0] == 1).sum().item() == 90 + assert (graph.bond_graph[:, 1] == 36).sum().item() == 0 + assert (graph.bond_graph[:, 3] == 36).sum().item() == 0 assert (graph.bond_graph[:, 2] == 306).sum().item() == 0 assert (graph.bond_graph[:, 4] == 120).sum().item() == 0 assert list(graph.lattice.shape) == [3, 3] - assert list(graph.undirected2directed.shape) == [210] - assert list(graph.directed2undirected.shape) == [420] + assert list(graph.undirected2directed.shape) == [196] + assert list(graph.directed2undirected.shape) == [392] def test_crystal_graph_isotropic_strained_legacy():