From 6193237c21fffdcc8966932a91e2cc6a6817f317 Mon Sep 17 00:00:00 2001 From: Emmanuel-Arokiaraj Date: Sun, 23 Nov 2025 19:15:03 +0530 Subject: [PATCH 1/2] fix(metadata): preserve integer dict keys when restoring from JSON (#2561) Integer-based PyTree keys were being converted to strings during metadata round-trip operations, causing incorrect key restoration on load. This change ensures numeric key names are serialized and deserialized correctly by converting digit-only strings back into integers. Adds KeyMetadataEntry.build and NestedKeyMetadataEntry.from_json fixes and corresponding tests verifying correct integer key reconstruction. --- .../orbax/checkpoint/_src/metadata/tree.py | 9 +++- .../export_type_conversions_check.py | 49 +++++++++++++++++++ checkpoint/orbax/checkpoint/metadata/tree.py | 2 +- 3 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/export_type_conversions_check.py diff --git a/checkpoint/orbax/checkpoint/_src/metadata/tree.py b/checkpoint/orbax/checkpoint/_src/metadata/tree.py index 8660b9c6d..8b0730f6c 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/tree.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/tree.py @@ -116,8 +116,13 @@ def to_json(self) -> Dict[str, Union[str, int]]: def from_json( cls, json_dict: Dict[str, Union[str, int]] ) -> NestedKeyMetadataEntry: + + key_name = json_dict[_KEY_NAME] + if isinstance(key_name, str) and key_name.isdigit(): + key_name = int(key_name) + return NestedKeyMetadataEntry( - nested_key_name=json_dict[_KEY_NAME], + nested_key_name=key_name, key_type=KeyType.from_json(json_dict[_KEY_TYPE]), ) @@ -145,7 +150,7 @@ def from_json( def build(cls, keypath: KeyPath) -> KeyMetadataEntry: return KeyMetadataEntry([ NestedKeyMetadataEntry( - str(tree_utils.get_key_name(k)), _get_key_metadata_type(k) + key_name if isinstance(key_name := tree_utils.get_key_name(k), int) else str(key_name), _get_key_metadata_type(k) ) for k in keypath ]) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/export_type_conversions_check.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/export_type_conversions_check.py new file mode 100644 index 000000000..699382b0a --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/export_type_conversions_check.py @@ -0,0 +1,49 @@ +import jax +import jax.numpy as jnp +import pytest + +from orbax.checkpoint.metadata.tree import ( + KeyMetadataEntry, + NestedKeyMetadataEntry +) +from orbax.checkpoint.metadata import tree_utils + + +class TestExportTypeConversions: + """ + Integration test using real JAX PyTrees. + Ensures that numeric array-axis keys serialize and deserialize correctly. + """ + + def test_export_type_conversions(self): + """ + Create a JAX PyTree, extract key metadata, serialize and deserialize, + and ensure numeric string keys are converted back into integers. + """ + # Create a sample JAX structure + pytree = { + "layer": [ + jnp.array([1, 2, 3]), # index 0 → numeric key + jnp.array([4, 5, 6]) # index 1 → numeric key + ] + } + + # Build the keypath from the pytree + keypaths = list(tree_utils.flatten_with_path(pytree)) + keypath = keypaths[0][0] # take first path e.g. ("layer", 0) + + # Build metadata entry for that keypath + metadata_entry = KeyMetadataEntry.build(keypath) + + # Serialize to JSON + json_data = metadata_entry.to_json() + + # Deserialize back + restored = KeyMetadataEntry.from_json(json_data) + + # Extract nested entry for the second level (numeric key index) + numeric_entry: NestedKeyMetadataEntry = restored.nested_key_metadata_entries[1] + + # Assertions validating type conversion behavior + assert isinstance(numeric_entry.nested_key_name, int) + assert numeric_entry.nested_key_name == 0 \ No newline at end of file diff --git a/checkpoint/orbax/checkpoint/metadata/tree.py b/checkpoint/orbax/checkpoint/metadata/tree.py index 72243bf8b..311ca034b 100644 --- a/checkpoint/orbax/checkpoint/metadata/tree.py +++ b/checkpoint/orbax/checkpoint/metadata/tree.py @@ -19,4 +19,4 @@ # The following symbols are provided for legacy use and WILL be removed in the # future. Please do not use. -from orbax.checkpoint._src.metadata.tree import ValueMetadataEntry +from orbax.checkpoint._src.metadata.tree import ValueMetadataEntry, NestedKeyMetadataEntry From 8b2eb6ccc1cd012c2a3c4f69d6f1076080888005 Mon Sep 17 00:00:00 2001 From: Emmanuel-Arokiaraj Date: Sun, 23 Nov 2025 19:24:56 +0530 Subject: [PATCH 2/2] removed testing modules written for fix --- .../export_type_conversions_check.py | 49 ------------------- checkpoint/orbax/checkpoint/metadata/tree.py | 2 +- 2 files changed, 1 insertion(+), 50 deletions(-) delete mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/export_type_conversions_check.py diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/export_type_conversions_check.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/export_type_conversions_check.py deleted file mode 100644 index 699382b0a..000000000 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/export_type_conversions_check.py +++ /dev/null @@ -1,49 +0,0 @@ -import jax -import jax.numpy as jnp -import pytest - -from orbax.checkpoint.metadata.tree import ( - KeyMetadataEntry, - NestedKeyMetadataEntry -) -from orbax.checkpoint.metadata import tree_utils - - -class TestExportTypeConversions: - """ - Integration test using real JAX PyTrees. - Ensures that numeric array-axis keys serialize and deserialize correctly. - """ - - def test_export_type_conversions(self): - """ - Create a JAX PyTree, extract key metadata, serialize and deserialize, - and ensure numeric string keys are converted back into integers. - """ - # Create a sample JAX structure - pytree = { - "layer": [ - jnp.array([1, 2, 3]), # index 0 → numeric key - jnp.array([4, 5, 6]) # index 1 → numeric key - ] - } - - # Build the keypath from the pytree - keypaths = list(tree_utils.flatten_with_path(pytree)) - keypath = keypaths[0][0] # take first path e.g. ("layer", 0) - - # Build metadata entry for that keypath - metadata_entry = KeyMetadataEntry.build(keypath) - - # Serialize to JSON - json_data = metadata_entry.to_json() - - # Deserialize back - restored = KeyMetadataEntry.from_json(json_data) - - # Extract nested entry for the second level (numeric key index) - numeric_entry: NestedKeyMetadataEntry = restored.nested_key_metadata_entries[1] - - # Assertions validating type conversion behavior - assert isinstance(numeric_entry.nested_key_name, int) - assert numeric_entry.nested_key_name == 0 \ No newline at end of file diff --git a/checkpoint/orbax/checkpoint/metadata/tree.py b/checkpoint/orbax/checkpoint/metadata/tree.py index 311ca034b..72243bf8b 100644 --- a/checkpoint/orbax/checkpoint/metadata/tree.py +++ b/checkpoint/orbax/checkpoint/metadata/tree.py @@ -19,4 +19,4 @@ # The following symbols are provided for legacy use and WILL be removed in the # future. Please do not use. -from orbax.checkpoint._src.metadata.tree import ValueMetadataEntry, NestedKeyMetadataEntry +from orbax.checkpoint._src.metadata.tree import ValueMetadataEntry