Skip to content

Commit 370eb8b

Browse files
committed
Fix fill_value serialization of NaN; add property-based tests
1 parent e8bfb64 commit 370eb8b

File tree

4 files changed

+278
-14
lines changed

4 files changed

+278
-14
lines changed

src/zarr/core/metadata/v2.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from zarr.core.common import ChunkCoords
2020

2121
import json
22+
import numbers
2223
from dataclasses import dataclass, field, fields, replace
2324

2425
import numcodecs
@@ -149,18 +150,20 @@ def _json_convert(
149150
json_indent = config.get("json_indent")
150151
return {
151152
ZARRAY_JSON: prototype.buffer.from_bytes(
152-
json.dumps(zarray_dict, default=_json_convert, indent=json_indent).encode()
153+
json.dumps(
154+
zarray_dict, default=_json_convert, indent=json_indent, allow_nan=False
155+
).encode()
153156
),
154157
ZATTRS_JSON: prototype.buffer.from_bytes(
155-
json.dumps(zattrs_dict, indent=json_indent).encode()
158+
json.dumps(zattrs_dict, indent=json_indent, allow_nan=False).encode()
156159
),
157160
}
158161

159162
@classmethod
160163
def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
161-
# make a copy to protect the original from modification
164+
# Make a copy to protect the original from modification.
162165
_data = data.copy()
163-
# check that the zarr_format attribute is correct
166+
# Check that the zarr_format attribute is correct.
164167
_ = parse_zarr_format(_data.pop("zarr_format"))
165168
dtype = parse_dtype(_data["dtype"])
166169

@@ -169,20 +172,46 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
169172
if fill_value_encoded is not None:
170173
fill_value = base64.standard_b64decode(fill_value_encoded)
171174
_data["fill_value"] = fill_value
172-
173-
# zarr v2 allowed arbitrary keys here.
174-
# We don't want the ArrayV2Metadata constructor to fail just because someone put an
175-
# extra key in the metadata.
175+
else:
176+
fill_value = _data.get("fill_value")
177+
if fill_value is not None:
178+
if np.issubdtype(dtype, np.datetime64):
179+
if fill_value == "NaT":
180+
_data["fill_value"] = np.array("NaT", dtype=dtype)[()]
181+
else:
182+
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
183+
elif dtype.kind == "c" and isinstance(fill_value, list):
184+
if len(fill_value) == 2:
185+
val = complex(float(fill_value[0]), float(fill_value[1]))
186+
_data["fill_value"] = np.array(val, dtype=dtype)[()]
187+
elif dtype.kind in "f" and isinstance(fill_value, str):
188+
if fill_value in {"NaN", "Infinity", "-Infinity"}:
189+
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
190+
# zarr v2 allowed arbitrary keys in the metadata.
191+
# Filter the keys to only those expected by the constructor.
176192
expected = {x.name for x in fields(cls)}
177-
# https://github.com/zarr-developers/zarr-python/issues/2269
178-
# handle the renames
179193
expected |= {"dtype", "chunks"}
180-
181194
_data = {k: v for k, v in _data.items() if k in expected}
182195

183196
return cls(**_data)
184197

185198
def to_dict(self) -> dict[str, JSON]:
199+
def _sanitize_fill_value(fv: Any) -> Any:
    """Convert a fill value into something strict JSON can represent.

    Parameters
    ----------
    fv
        The fill value to convert. May be ``None``, a ``numpy.datetime64``,
        a real or complex number (Python or NumPy), or any other object.

    Returns
    -------
    ``None`` for ``None``; ``"NaT"`` or an ISO-8601 string for a
    ``numpy.datetime64``; ``"NaN"``, ``"Infinity"`` or ``"-Infinity"`` for a
    non-finite real number; a two-element ``[real, imag]`` list (each part
    converted by the same rules) for a complex number; otherwise ``fv``
    unchanged.
    """
    if fv is None:
        return None

    if isinstance(fv, np.datetime64):
        if np.isnat(fv):
            return "NaT"
        # ``datetime_as_string`` yields ``np.str_``; hand back a plain ``str``.
        return str(np.datetime_as_string(fv))

    # ``Real`` must be tested before ``Complex``: every real number is also
    # an instance of ``numbers.Complex``.
    if isinstance(fv, numbers.Real):
        if np.isnan(fv):
            return "NaN"
        if np.isinf(fv):
            return "Infinity" if fv > 0 else "-Infinity"
        return fv

    if isinstance(fv, numbers.Complex):
        # Each component may itself be NaN or infinite, so convert both.
        return [_sanitize_fill_value(fv.real), _sanitize_fill_value(fv.imag)]

    return fv
214+
186215
zarray_dict = super().to_dict()
187216

188217
if self.dtype.kind in "SV" and self.fill_value is not None:
@@ -192,6 +221,7 @@ def to_dict(self) -> dict[str, JSON]:
192221
fill_value = base64.standard_b64encode(cast(bytes, self.fill_value)).decode("ascii")
193222
zarray_dict["fill_value"] = fill_value
194223

224+
zarray_dict["fill_value"] = _sanitize_fill_value(zarray_dict["fill_value"])
195225
_ = zarray_dict.pop("dtype")
196226
dtype_json: JSON
197227
# In the case of zarr v2, the simplest i.e., '|VXX' dtype is represented as a string
@@ -300,7 +330,6 @@ def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any:
300330
-------
301331
An instance of `dtype`, or `None`, or any python object (in the case of an object dtype)
302332
"""
303-
304333
if fill_value is None or dtype.hasobject:
305334
# no fill value
306335
pass

src/zarr/testing/stateful.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def add_group(self, name: str, data: DataObject) -> None:
8585
@rule(
8686
data=st.data(),
8787
name=node_names,
88-
array_and_chunks=np_array_and_chunks(arrays=numpy_arrays(zarr_formats=st.just(3))),
88+
array_and_chunks=np_array_and_chunks(nparrays=numpy_arrays(zarr_formats=st.just(3))),
8989
)
9090
def add_array(
9191
self,

src/zarr/testing/strategies.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44
import hypothesis.extra.numpy as npst
55
import hypothesis.strategies as st
6+
import numcodecs
67
import numpy as np
7-
from hypothesis import given, settings # noqa: F401
8+
from hypothesis import assume, given, settings # noqa: F401
89
from hypothesis.strategies import SearchStrategy
910

1011
import zarr
@@ -344,3 +345,136 @@ def make_request(start: int, length: int) -> RangeByteRequest:
344345
)
345346
key_tuple = st.tuples(keys, byte_ranges)
346347
return st.lists(key_tuple, min_size=1, max_size=10)
348+
349+
350+
def simple_text() -> st.SearchStrategy[str]:
    """Return a strategy for short strings of printable ASCII.

    Each string has between 1 and 10 characters, drawn from code points
    32 (space) through 126 (``~``).
    """
    printable_ascii = st.characters(min_codepoint=32, max_codepoint=126)
    return st.text(printable_ascii, min_size=1, max_size=10)
353+
354+
355+
def simple_attrs() -> st.SearchStrategy[dict[str, int | float | bool | str]]:
    """Return a strategy for flat attribute dictionaries.

    Keys come from ``simple_text()``. Values are integers, finite floats,
    booleans, or ``simple_text()`` strings.
    """
    attr_values = st.one_of(
        st.integers(),
        # NaN and infinities are excluded from attribute values.
        st.floats(allow_nan=False, allow_infinity=False),
        st.booleans(),
        simple_text(),
    )
    return st.dictionaries(simple_text(), attr_values)
366+
367+
368+
def array_shapes(
    min_dims: int = 1, max_dims: int = 3, max_len: int = 100
) -> st.SearchStrategy[list[int]]:
    """Return a strategy for array shapes, as lists of axis lengths.

    Parameters
    ----------
    min_dims
        Smallest number of axes to generate.
    max_dims
        Largest number of axes to generate.
    max_len
        Largest length of any single axis; the smallest is always 1.
    """
    axis_length = st.integers(min_value=1, max_value=max_len)
    return st.lists(axis_length, min_size=min_dims, max_size=max_dims)
373+
374+
375+
# def zarr_compressors():
376+
# """A strategy for generating Zarr compressors."""
377+
# return st.sampled_from([None, Blosc(), GZip(), Zstd(), LZ4()])
378+
379+
380+
# def zarr_codecs():
381+
# """A strategy for generating Zarr codecs."""
382+
# return st.sampled_from([BytesCodec(), Blosc(), GZip(), Zstd(), LZ4()])
383+
384+
385+
def zarr_filters() -> st.SearchStrategy[list[numcodecs.Delta]]:
    """Return a strategy for lists of Zarr filters.

    Generates between 0 and 2 filters. The only filter produced is
    ``numcodecs.Delta(dtype="i4")``; it is an example to be expanded as
    needed.
    """
    example_filter = st.just(numcodecs.Delta(dtype="i4"))
    return st.lists(example_filter, min_size=0, max_size=2)
390+
391+
392+
def zarr_storage_transformers() -> (
    st.SearchStrategy[list[dict[str, int | float | bool | str]]]
):
    """Return a strategy for lists of storage-transformer configurations.

    Generates between 0 and 2 dictionaries. Keys come from
    ``simple_text()``. Values are integers, finite floats, booleans, or
    ``simple_text()`` strings.
    """
    config_values = st.one_of(
        st.integers(),
        # Keep floats finite, matching ``simple_attrs``: NaN never compares
        # equal to itself and NaN/Infinity are not valid strict JSON.
        st.floats(allow_nan=False, allow_infinity=False),
        st.booleans(),
        simple_text(),
    )
    transformer_config = st.dictionaries(simple_text(), config_values)
    return st.lists(transformer_config, min_size=0, max_size=2)
401+
402+
403+
@st.composite
def array_metadata_v2(draw: st.DrawFn) -> ArrayV2Metadata:
    """Generate ``ArrayV2Metadata`` objects for property-based testing.

    The array has 1 to 3 axes of length 1 to 100. Every chunk length lies
    between 1 and the length of its own axis. The fill value is either
    ``None`` or a value drawn for the chosen dtype.
    """
    # Limit the number of dimensions to keep examples small.
    ndim = draw(st.integers(min_value=1, max_value=3))
    shape = tuple(draw(array_shapes(min_dims=ndim, max_dims=ndim, max_len=100)))

    # Bound each chunk length by its own axis so that every draw is valid,
    # instead of drawing freely and rejecting oversized chunks afterwards.
    chunks = tuple(draw(st.integers(min_value=1, max_value=axis_len)) for axis_len in shape)

    dtype = draw(v2_dtypes())
    fill_value = draw(st.one_of([st.none(), npst.from_dtype(dtype)]))
    order = draw(st.sampled_from(["C", "F"]))
    dimension_separator = draw(st.sampled_from([".", "/"]))
    # TODO: draw a compressor once a compressor strategy exists.
    filters = tuple(draw(zarr_filters())) if draw(st.booleans()) else None
    attributes = draw(simple_attrs())

    return ArrayV2Metadata(
        shape=shape,
        dtype=dtype,
        chunks=chunks,
        fill_value=fill_value,
        order=order,
        dimension_separator=dimension_separator,
        filters=filters,
        attributes=attributes,
    )
440+
441+
442+
@st.composite
def array_metadata_v3(draw: st.DrawFn) -> ArrayV3Metadata:
    """Generate ``ArrayV3Metadata`` objects for property-based testing.

    The array has 1 to 3 axes of length 1 to 100. Every chunk length lies
    between 1 and the length of its own axis. Dimension names are either
    ``None`` or one entry (``None`` or a short string) per axis.
    """
    ndim = draw(st.integers(min_value=1, max_value=3))
    shape = tuple(draw(array_shapes(min_dims=ndim, max_dims=ndim, max_len=100)))

    # Bound each chunk length by its own axis so that every draw is valid,
    # instead of drawing freely and rejecting oversized chunks afterwards.
    chunks = tuple(draw(st.integers(min_value=1, max_value=axis_len)) for axis_len in shape)

    dtype = draw(v3_dtypes())
    fill_value = draw(npst.from_dtype(dtype))
    chunk_grid = RegularChunkGrid(chunks)
    # The separator is fixed to "/"; st.sampled_from(["/", "."]) would vary it.
    chunk_key_encoding = DefaultChunkKeyEncoding(separator="/")
    attributes = draw(simple_attrs())

    axis_name = st.one_of(st.none(), simple_text())
    dimension_names = (
        tuple(draw(st.lists(axis_name, min_size=ndim, max_size=ndim)))
        if draw(st.booleans())
        else None
    )
    storage_transformers = tuple(draw(zarr_storage_transformers()))

    # NOTE(review): no ``codecs`` argument is passed below; confirm that
    # ``ArrayV3Metadata`` accepts construction without it.
    return ArrayV3Metadata(
        shape=shape,
        data_type=dtype,
        chunk_grid=chunk_grid,
        chunk_key_encoding=chunk_key_encoding,
        fill_value=fill_value,
        attributes=attributes,
        dimension_names=dimension_names,
        storage_transformers=storage_transformers,
    )

tests/test_properties.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import dataclasses
2+
import json
3+
4+
import numpy as np
15
import pytest
26
from numpy.testing import assert_array_equal
37

@@ -10,10 +14,12 @@
1014
from hypothesis import assume, given
1115

1216
from zarr.abc.store import Store
17+
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON
1318
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
1419
from zarr.core.sync import sync
1520
from zarr.testing.strategies import (
1621
array_metadata,
22+
array_metadata_v2,
1723
arrays,
1824
basic_indices,
1925
numpy_arrays,
@@ -23,6 +29,60 @@
2329
)
2430

2531

32+
def deep_equal(a, b):
33+
"""Deep equality check w/ NaN e to handle array metadata serialization and deserialization behaviors"""
34+
if isinstance(a, (complex, np.complexfloating)) and isinstance(
35+
b, (complex, np.complexfloating)
36+
):
37+
# Convert to Python float to force standard NaN handling.
38+
a_real, a_imag = float(a.real), float(a.imag)
39+
b_real, b_imag = float(b.real), float(b.imag)
40+
# If both parts are NaN, consider them equal.
41+
if np.isnan(a_real) and np.isnan(b_real):
42+
real_eq = True
43+
else:
44+
real_eq = a_real == b_real
45+
if np.isnan(a_imag) and np.isnan(b_imag):
46+
imag_eq = True
47+
else:
48+
imag_eq = a_imag == b_imag
49+
return real_eq and imag_eq
50+
51+
# Handle floats (including numpy floating types) and treat NaNs as equal.
52+
if isinstance(a, (float, np.floating)) and isinstance(b, (float, np.floating)):
53+
if np.isnan(a) and np.isnan(b):
54+
return True
55+
return a == b
56+
57+
# Handle numpy.datetime64 values, treating NaT as equal.
58+
if isinstance(a, np.datetime64) and isinstance(b, np.datetime64):
59+
if np.isnat(a) and np.isnat(b):
60+
return True
61+
return a == b
62+
63+
# Handle numpy arrays.
64+
if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
65+
if a.shape != b.shape:
66+
return False
67+
# Compare elementwise.
68+
return all(deep_equal(x, y) for x, y in zip(a.flat, b.flat, strict=False))
69+
70+
# Handle dictionaries.
71+
if isinstance(a, dict) and isinstance(b, dict):
72+
if set(a.keys()) != set(b.keys()):
73+
return False
74+
return all(deep_equal(a[k], b[k]) for k in a)
75+
76+
# Handle lists and tuples.
77+
if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
78+
if len(a) != len(b):
79+
return False
80+
return all(deep_equal(x, y) for x, y in zip(a, b, strict=False))
81+
82+
# Fallback to default equality.
83+
return a == b
84+
85+
2686
@given(data=st.data(), zarr_format=zarr_formats)
2787
def test_roundtrip(data: st.DataObject, zarr_format: int) -> None:
2888
nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format)))
@@ -135,3 +195,44 @@ async def test_roundtrip_array_metadata(
135195
# nparray = data.draw(np_arrays)
136196
# zarray = data.draw(arrays(arrays=st.just(nparray)))
137197
# assert_array_equal(nparray, zarray[:])
198+
199+
200+
@given(array_metadata_v2())
def test_v2meta_roundtrip(metadata: ArrayV2Metadata) -> None:
    """Check that v2 metadata survives serialization to JSON and back."""
    buffers = metadata.to_buffer_dict(prototype=default_buffer_prototype())
    zarray_dict = json.loads(buffers[ZARRAY_JSON].to_bytes().decode())
    zattrs_dict = json.loads(buffers[ZATTRS_JSON].to_bytes().decode())

    # zattrs and zarray are separate documents in v2, so the attributes have
    # to be added back before calling `from_dict`.
    zarray_dict["attributes"] = zattrs_dict
    metadata_roundtripped = ArrayV2Metadata.from_dict(zarray_dict)

    # Compare as plain dictionaries so that `deep_equal` can apply its
    # NaN/NaT-tolerant rules to every field.
    orig = dataclasses.asdict(metadata)
    rt = dataclasses.asdict(metadata_roundtripped)

    assert deep_equal(orig, rt), f"Roundtrip mismatch:\nOriginal: {orig}\nRoundtripped: {rt}"
216+
217+
218+
@given(npst.from_dtype(dtype=np.dtype("float64"), allow_nan=True, allow_infinity=True))
def test_v2meta_nan_and_infinity(fill_value: float) -> None:
    """Check how float64 fill values are written to the v2 array JSON.

    NaN and the infinities must appear as the strings ``"NaN"``,
    ``"Infinity"`` and ``"-Infinity"``; finite values must appear as-is.
    """
    metadata = ArrayV2Metadata(
        shape=[10],
        dtype=np.dtype("float64"),
        chunks=[5],
        fill_value=fill_value,
        order="C",
    )

    buffers = metadata.to_buffer_dict(prototype=default_buffer_prototype())
    zarray_dict = json.loads(buffers[ZARRAY_JSON].to_bytes().decode())

    # Work out the expected encoding first so a single assertion covers
    # every case.
    if np.isnan(fill_value):
        expected = "NaN"
    elif np.isinf(fill_value):
        expected = "Infinity" if fill_value > 0 else "-Infinity"
    else:
        expected = fill_value

    assert zarray_dict["fill_value"] == expected

0 commit comments

Comments
 (0)