Skip to content

Commit 52149db

Browse files
committed
Enable roundtripping nested dtypes through parquet and arrow
1 parent 82fa271 commit 52149db

File tree

3 files changed

+143
-1
lines changed

3 files changed

+143
-1
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,7 @@ ExtensionArray
12211221
- Bug in :meth:`api.types.is_datetime64_any_dtype` where a custom :class:`ExtensionDtype` would return ``False`` for array-likes (:issue:`57055`)
12221222
- Bug in comparison between object with :class:`ArrowDtype` and incompatible-dtyped (e.g. string vs bool) incorrectly raising instead of returning all-``False`` (for ``==``) or all-``True`` (for ``!=``) (:issue:`59505`)
12231223
- Bug in constructing pandas data structures when passing into ``dtype`` a string of the type followed by ``[pyarrow]`` while PyArrow is not installed would raise ``NameError`` rather than ``ImportError`` (:issue:`57928`)
1224+
- Bug in dtype inference when roundtripping nested arrow dtypes like ``list``, ``struct``, ``map`` through pyarrow tables or parquet (:issue:`61529`)
12241225
- Bug in various :class:`DataFrame` reductions for pyarrow temporal dtypes returning incorrect dtype when result was null (:issue:`59234`)
12251226
- Fixed flex arithmetic with :class:`ExtensionArray` operands raising when ``fill_value`` was passed. (:issue:`62467`)
12261227

pandas/core/dtypes/dtypes.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2387,16 +2387,80 @@ def construct_from_string(cls, string: str) -> ArrowDtype:
23872387
except (NotImplementedError, ValueError):
23882388
# Fall through to raise with nice exception message below
23892389
pass
2390+
binary_pattern = re.compile(r"^fixed_size_binary\[(?P<width>\d+)\]$")
2391+
if match := binary_pattern.match(base_type):
2392+
byte_width = match.group("width")
2393+
return cls(pa.binary(int(byte_width)))
23902394

23912395
raise NotImplementedError(
23922396
"Passing pyarrow type specific parameters "
23932397
f"({has_parameters.group()}) in the string is not supported. "
23942398
"Please construct an ArrowDtype object with a pyarrow_dtype "
23952399
"instance with specific parameters."
23962400
) from err
2397-
raise TypeError(f"'{base_type}' is not a valid pyarrow data type.") from err
2401+
# match maps
2402+
map_pattern = re.compile(r"^map<(?P<key>[^,<>]+),\s(?P<value>[^,<>]+)>$")
2403+
# match lists
2404+
list_inner_pattern = r"<item:\s(?P<item_type>.+)>$"
2405+
list_pattern = re.compile(rf"^list{list_inner_pattern}")
2406+
large_list_pattern = re.compile(rf"^large_list{list_inner_pattern}")
2407+
# match structs
2408+
struct_pattern = re.compile(r"^struct<(?P<fields>.+)>$")
2409+
if match := map_pattern.match(base_type):
2410+
pa_dtype = pa.map_(
2411+
pa.type_for_alias(match.group("key")),
2412+
pa.type_for_alias(match.group("value")),
2413+
)
2414+
elif match := list_pattern.match(base_type):
2415+
pa_dtype = pa.list_(
2416+
cls._resolve_inner_types(match.group("item_type") + "[pyarrow]")
2417+
)
2418+
elif match := large_list_pattern.match(base_type):
2419+
pa_dtype = pa.large_list(
2420+
cls._resolve_inner_types(match.group("item_type") + "[pyarrow]")
2421+
)
2422+
elif match := struct_pattern.match(base_type):
2423+
fields = []
2424+
for name, t in cls._split_struct(match.group("fields")):
2425+
field_dtype = cls._resolve_inner_types(t + "[pyarrow]")
2426+
fields.append((name, field_dtype))
2427+
pa_dtype = pa.struct(fields)
2428+
else:
2429+
raise TypeError(
2430+
f"'{base_type}' is not a valid pyarrow data type."
2431+
) from err
23982432
return cls(pa_dtype)
23992433

2434+
@classmethod
def _resolve_inner_types(cls, string: str) -> pa.DataType:
    """
    Resolve the dtype string of a nested element to a pyarrow data type.

    Parameters
    ----------
    string : str
        Dtype string of the inner type, e.g. ``"int64[pyarrow]"``.

    Returns
    -------
    pa.DataType
        ``pa.string()`` for the literal ``"string[pyarrow]"``; otherwise the
        ``pyarrow_dtype`` of ``cls.construct_from_string(string)``.
    """
    # "string[pyarrow]" is mapped directly instead of being routed through
    # construct_from_string.
    # NOTE(review): presumably construct_from_string does not accept this
    # alias for ArrowDtype -- confirm.
    if string != "string[pyarrow]":
        return cls.construct_from_string(string).pyarrow_dtype
    return pa.string()
2440+
2441+
@staticmethod
def _split_struct(fields: str):
    """
    Split the body of a ``struct<...>`` type string into its fields.

    The string is split on commas that are not enclosed in any bracket
    pair. Angle brackets, square brackets and parentheses all count as
    nesting, so parametrized field types such as ``list<item: int64>``,
    ``timestamp[ns, tz=UTC]`` or ``decimal128(10, 2)`` stay in one piece.

    Parameters
    ----------
    fields : str
        The text between the outer ``struct<`` and ``>``, e.g.
        ``"f1: int64, f2: list<item: double>"``.

    Yields
    ------
    tuple of (str, str)
        Field name and field type string, in the order they appear.

    Raises
    ------
    TypeError
        If a piece does not have the form ``name: type``. Because this is
        a generator, the error is raised while iterating, not at call time.

    Notes
    -----
    The name is everything before the first colon of a piece, so field
    names containing ``:``, ``,`` or unbalanced brackets are not supported.
    """
    field_pattern = re.compile(r"^\s*(?P<name>[^:]+):\s*(?P<type>.+)\s*$")
    openers = ("<", "[", "(")
    closers = (">", "]", ")")

    parts, start, depth = [], 0, 0
    for i, char in enumerate(fields):
        if char in openers:
            depth += 1
        elif char in closers:
            depth -= 1
        elif char == "," and depth == 0:
            # Only a comma outside every bracket pair separates two fields.
            parts.append(fields[start:i].strip())
            start = i + 1

    # Whatever follows the last top-level comma is the final field.
    if start < len(fields):
        parts.append(fields[start:].strip())

    for field in parts:
        if match := field_pattern.match(field):
            yield match.group("name"), match.group("type")
        else:
            raise TypeError(f"Could not parse struct field definition: '{field}'")
2463+
24002464
# TODO(arrow#33642): This can be removed once supported by pyarrow
24012465
@classmethod
24022466
def _parse_temporal_dtype_string(cls, string: str) -> ArrowDtype:

pandas/tests/extension/test_arrow.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3780,6 +3780,83 @@ def test_arrow_dtype_itemsize_fixed_width(type_name, expected_size):
37803780
)
37813781

37823782

3783+
def test_roundtrip_of_nested_types():
    # GH 61529
    # A frame of nested pyarrow-backed columns must compare equal to itself
    # after pandas -> pyarrow Table -> pandas.
    str_float_map = pa.map_(pa.string(), pa.float64())
    inner_struct = pa.struct(
        [
            ("int_list", pa.list_(pa.int64())),
            ("text", pa.string()),
        ]
    )
    nested_struct = pa.struct(
        [
            ("outer_int", pa.int64()),
            ("inner", inner_struct),
        ]
    )
    id_value_struct = pa.struct([("id", pa.int64()), ("value", pa.float64())])

    # column name -> (values, pyarrow type); insertion order is column order
    columns = {
        "list_int": ([[1, 2, 3], [4, 5]], pa.list_(pa.int64())),
        "list_string": ([["a", "b"], ["c"]], pa.list_(pa.string())),
        "large_list_int": ([[1, 2], [3, 4, 5]], pa.large_list(pa.int64())),
        "large_list_string": ([["x", "y"], ["z"]], pa.large_list(pa.string())),
        "list_map": (
            [[{"a": 1.0, "b": 2.0}], [{"c": 3.0}]],
            pa.list_(str_float_map),
        ),
        "large_list_map": (
            [[{"x": 1.5}], [{"y": 2.5, "z": 3.5}]],
            pa.large_list(str_float_map),
        ),
        "map_int_float": (
            [{1: 1.1, 2: 2.2}, {3: 3.3}],
            pa.map_(pa.int64(), pa.float64()),
        ),
        "struct_simple": (
            [{"f1": 1, "f2": 1.5}, {"f1": 2, "f2": 2.5}],
            pa.struct([("f1", pa.int64()), ("f2", pa.float64())]),
        ),
        "struct_nested": (
            [
                {
                    "outer_int": 10,
                    "inner": {"int_list": [1, 2, 3], "text": "hello"},
                },
                {"outer_int": 20, "inner": {"int_list": [4, 5], "text": "world"}},
            ],
            nested_struct,
        ),
        "binary_16": (
            [b"0123456789abcdef", b"fedcba9876543210"],
            pa.binary(16),
        ),
        "list_struct": (
            [
                [{"id": 1, "value": 10.5}, {"id": 2, "value": 20.5}],
                [{"id": 3, "value": 30.5}],
            ],
            pa.list_(id_value_struct),
        ),
    }
    df = pd.DataFrame(
        {
            name: pd.Series(values, dtype=ArrowDtype(pa_type))
            for name, (values, pa_type) in columns.items()
        }
    )

    result = pa.Table.from_pandas(df).to_pandas()
    tm.assert_frame_equal(result, df)
3858+
3859+
37833860
@pytest.mark.parametrize("type_name", ["string", "binary", "large_string"])
37843861
def test_arrow_dtype_itemsize_variable_width(type_name):
37853862
# GH 57948

0 commit comments

Comments
 (0)