From ce0e64d69f109b184b68c2799e57dbb707748bc1 Mon Sep 17 00:00:00 2001 From: "xiaotong.liu" Date: Fri, 24 Dec 2021 08:49:13 +0800 Subject: [PATCH] add a convertion step before GraphCrystalDisordered training Signed-off-by: xiaotong.liu --- megnet/data/graph.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/megnet/data/graph.py b/megnet/data/graph.py index 8fe2744b3..f086b2344 100644 --- a/megnet/data/graph.py +++ b/megnet/data/graph.py @@ -580,6 +580,13 @@ def _generate_inputs(self, batch_index: list) -> tuple: - [ndarray]: List of indices for the start of each bond - [ndarray]: List of indices for the end of each bond """ + import collections + if isinstance(self.atom_features[0][0], collections.defaultdict): + from megnet.data.crystal import CrystalGraphDisordered + cgd = CrystalGraphDisordered() + for i in range(len(self.atom_features)): + self.atom_features[i] = cgd.atom_converter.convert( + self.atom_features[i].tolist()).tolist() # Get the features and connectivity lists for this batch feature_list_temp = itemgetter_list(self.atom_features, batch_index)