diff --git a/checkpoint/orbax/checkpoint/_src/metadata/tree.py b/checkpoint/orbax/checkpoint/_src/metadata/tree.py index 8660b9c6d..8b0730f6c 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/tree.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/tree.py @@ -116,8 +116,13 @@ def to_json(self) -> Dict[str, Union[str, int]]: def from_json( cls, json_dict: Dict[str, Union[str, int]] ) -> NestedKeyMetadataEntry: + + key_name = json_dict[_KEY_NAME] + if isinstance(key_name, str) and key_name.isdigit(): + key_name = int(key_name) + return NestedKeyMetadataEntry( - nested_key_name=json_dict[_KEY_NAME], + nested_key_name=key_name, key_type=KeyType.from_json(json_dict[_KEY_TYPE]), ) @@ -145,7 +150,7 @@ def from_json( def build(cls, keypath: KeyPath) -> KeyMetadataEntry: return KeyMetadataEntry([ NestedKeyMetadataEntry( - str(tree_utils.get_key_name(k)), _get_key_metadata_type(k) + key_name if isinstance(key_name := tree_utils.get_key_name(k), int) else str(key_name), _get_key_metadata_type(k) ) for k in keypath ])