Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ New Features
- :py:func:`merge` and :py:func:`concat` now support :py:class:`DataTree`
objects (:issue:`9790`, :issue:`9778`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- :py:class:`DataTree` now supports indexing by lists of paths, similar to
:py:class:`DataTree` (:pull:`10854`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- The ``h5netcdf`` engine has support for pseudo ``NETCDF4_CLASSIC`` files, meaning variables and attributes are cast to supported types. Note that the saved files won't be recognized as genuine ``NETCDF4_CLASSIC`` files until ``h5netcdf`` adds support with version 1.7.0. (:issue:`10676`, :pull:`10686`).
By `David Huard <https://github.com/huard>`_.

Expand Down
4 changes: 2 additions & 2 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,9 @@ def _decode_variable_name(name):


def _iter_nc_groups(root, parent="/"):
from xarray.core.treenode import NodePath
from xarray.core.treenode import TreePath

parent = NodePath(parent)
parent = TreePath(parent)
yield str(parent)
for path, group in root.groups.items():
gpath = parent / path
Expand Down
10 changes: 5 additions & 5 deletions xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def open_groups_as_dict(
**kwargs,
) -> dict[str, Dataset]:
from xarray.backends.common import _iter_nc_groups
from xarray.core.treenode import NodePath
from xarray.core.treenode import TreePath
from xarray.core.utils import close_on_error

# Keep this message for some versions
Expand All @@ -652,9 +652,9 @@ def open_groups_as_dict(

# Check for a group and make it a parent if it exists
if group:
parent = NodePath("/") / NodePath(group)
parent = TreePath("/") / TreePath(group)
else:
parent = NodePath("/")
parent = TreePath("/")

manager = store._manager
groups_dict = {}
Expand All @@ -674,9 +674,9 @@ def open_groups_as_dict(
)

if group:
group_name = str(NodePath(path_group).relative_to(parent))
group_name = str(TreePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
group_name = str(TreePath(path_group))
groups_dict[group_name] = group_ds

# only warn if phony_dims exist in file
Expand Down
10 changes: 5 additions & 5 deletions xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ def open_groups_as_dict(
**kwargs,
) -> dict[str, Dataset]:
from xarray.backends.common import _iter_nc_groups
from xarray.core.treenode import NodePath
from xarray.core.treenode import TreePath

filename_or_obj = _normalize_path(filename_or_obj)
store = NetCDF4DataStore.open(
Expand All @@ -850,9 +850,9 @@ def open_groups_as_dict(

# Check for a group and make it a parent if it exists
if group:
parent = NodePath("/") / NodePath(group)
parent = TreePath("/") / TreePath(group)
else:
parent = NodePath("/")
parent = TreePath("/")

manager = store._manager
groups_dict = {}
Expand All @@ -871,9 +871,9 @@ def open_groups_as_dict(
decode_timedelta=decode_timedelta,
)
if group:
group_name = str(NodePath(path_group).relative_to(parent))
group_name = str(TreePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
group_name = str(TreePath(path_group))
groups_dict[group_name] = group_ds

return groups_dict
Expand Down
12 changes: 6 additions & 6 deletions xarray/backends/pydap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def open_groups_as_dict(
verify=None,
user_charset=None,
) -> dict[str, Dataset]:
from xarray.core.treenode import NodePath
from xarray.core.treenode import TreePath

filename_or_obj = _normalize_path(filename_or_obj)
store = PydapDataStore.open(
Expand All @@ -325,9 +325,9 @@ def open_groups_as_dict(

# Check for a group and make it a parent if it exists
if group:
parent = str(NodePath("/") / NodePath(group))
parent = str(TreePath("/") / TreePath(group))
else:
parent = str(NodePath("/"))
parent = str(TreePath("/"))

groups_dict = {}
group_names = [parent]
Expand Down Expand Up @@ -365,7 +365,7 @@ def group_fqn(store, path=None, g_fqn=None) -> dict[str, str]:

Groups = group_fqn(store.ds)
group_names += [
str(NodePath(path_to_group) / NodePath(group))
str(TreePath(path_to_group) / TreePath(group))
for group, path_to_group in Groups.items()
]
for path_group in group_names:
Expand All @@ -384,9 +384,9 @@ def group_fqn(store, path=None, g_fqn=None) -> dict[str, str]:
decode_timedelta=decode_timedelta,
)
if group:
group_name = str(NodePath(path_group).relative_to(parent))
group_name = str(TreePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
group_name = str(TreePath(path_group))
groups_dict[group_name] = group_ds

return groups_dict
Expand Down
16 changes: 8 additions & 8 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core import indexing
from xarray.core.treenode import NodePath
from xarray.core.treenode import TreePath
from xarray.core.types import ZarrWriteModes
from xarray.core.utils import (
FrozenDict,
Expand Down Expand Up @@ -1752,9 +1752,9 @@ def open_groups_as_dict(

# Check for a group and make it a parent if it exists
if group:
parent = str(NodePath("/") / NodePath(group))
parent = str(TreePath("/") / TreePath(group))
else:
parent = str(NodePath("/"))
parent = str(TreePath("/"))

stores = ZarrStore.open_store(
filename_or_obj,
Expand Down Expand Up @@ -1785,18 +1785,18 @@ def open_groups_as_dict(
decode_timedelta=decode_timedelta,
)
if group:
group_name = str(NodePath(path_group).relative_to(parent))
group_name = str(TreePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
group_name = str(TreePath(path_group))
groups_dict[group_name] = group_ds
return groups_dict


def _iter_zarr_groups(root: ZarrGroup, parent: str = "/") -> Iterable[str]:
parent_nodepath = NodePath(parent)
yield str(parent_nodepath)
parent_TreePath = TreePath(parent)
yield str(parent_TreePath)
for path, group in root.groups():
gpath = parent_nodepath / path
gpath = parent_TreePath / path
yield from _iter_zarr_groups(group, parent=str(gpath))


Expand Down
72 changes: 49 additions & 23 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from xarray.core.indexes import Index, Indexes
from xarray.core.options import OPTIONS as XR_OPTS
from xarray.core.options import _get_keep_attrs
from xarray.core.treenode import NamedNode, NodePath, zip_subtrees
from xarray.core.treenode import NamedNode, TreePath, zip_subtrees
from xarray.core.types import Self
from xarray.core.utils import (
Default,
Expand Down Expand Up @@ -114,7 +114,7 @@
# """


T_Path = Union[str, NodePath]
T_Path = Union[str, TreePath]
T = TypeVar("T")
P = ParamSpec("P")

Expand Down Expand Up @@ -188,7 +188,7 @@ def check_alignment(
base_ds = node_ds

for child_name, child in children.items():
child_path = str(NodePath(path) / child_name)
child_path = str(TreePath(path) / child_name)
child_ds = child.to_dataset(inherit=False)
check_alignment(child_path, child_ds, base_ds, child.children)

Expand Down Expand Up @@ -566,7 +566,7 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
raise KeyError(
f"parent {parent.name} already contains a variable named {name}"
)
path = str(NodePath(parent.path) / name)
path = str(TreePath(parent.path) / name)
node_ds = self.to_dataset(inherit=False)
parent_ds = parent._to_dataset_view(rebuild_dims=False, inherit=True)
check_alignment(path, node_ds, parent_ds, self.children)
Expand Down Expand Up @@ -943,17 +943,48 @@ def get( # type: ignore[override]
else:
return default

def __getitem__(self: DataTree, key: str) -> DataTree | DataArray:
def _copy_listed(self, keys: list[str]) -> Self:
"""Get multiple items as a DataTree."""
base = TreePath(self.path)
nodes: dict[str, DataTree | Dataset] = {}
keys_by_node: defaultdict[str, list[str]] = defaultdict(list)
for key in keys:
path = base / key
assert path.root
_, *parts, name = path.parts
current_node = self.root
for part in parts:
current_node = current_node.children[part]
if name in current_node.children:
target = str(path.relative_to(base))
nodes[target] = current_node.children[name]
elif name in current_node.variables:
target = str(base.joinpath(*parts).relative_to(base))
keys_by_node[target].append(name) # DataArray
else:
raise KeyError(key)
for target, names in keys_by_node.items():
nodes[target] = self.root[target].dataset[names]
return self.from_dict(nodes, name=self.name)

@overload
def __getitem__(self, key: list[str]) -> Self: ...

@overload
def __getitem__(self, key: str) -> Self | DataArray: ...

def __getitem__(self, key: str | list[str]) -> Self | DataArray:
"""
Access child nodes, variables, or coordinates stored anywhere in this tree.

Returned object will be either a DataTree or DataArray object depending on whether the key given points to a
child or variable.
Returned object will be either a DataTree or DataArray object depending on
whether the key given points to a child or variable.

Parameters
----------
key : str
Name of variable / child within this node, or unix-like path to variable / child within another node.
Name of variable / child within this node, unix-like path to variable
/ child within another node, or a list of names/paths.

Returns
-------
Expand All @@ -967,14 +998,9 @@ def __getitem__(self: DataTree, key: str) -> DataTree | DataArray:
elif isinstance(key, str):
# TODO should possibly deal with hashables in general?
# path-like: a name of a node/variable, or path to a node/variable
path = NodePath(key)
return self._get_item(path)
elif utils.is_list_like(key):
# iterable of variable names
raise NotImplementedError(
"Selecting via tags is deprecated, and selecting multiple items should be "
"implemented via .subset"
)
return self._get_item(key)
elif isinstance(key, list):
return self._copy_listed(key)
else:
raise ValueError(f"Invalid format for key: {key}")

Expand Down Expand Up @@ -1015,7 +1041,7 @@ def __setitem__(
elif isinstance(key, str):
# TODO should possibly deal with hashables in general?
# path-like: a name of a node/variable, or path to a node/variable
path = NodePath(key)
path = TreePath(key)
if isinstance(value, Dataset):
value = DataTree(dataset=value)
return self._set_item(path, value, new_nodes_along_path=True)
Expand Down Expand Up @@ -1341,17 +1367,17 @@ def from_dict(
data_items,
((k, _CoordWrapper(v)) for k, v in coords_items),
)
nodes: dict[NodePath, _CoordWrapper | FromDictDataValue] = {}
nodes: dict[TreePath, _CoordWrapper | FromDictDataValue] = {}
for key, value in flat_data_and_coords:
path = NodePath(key).absolute()
path = TreePath(key).absolute()
if path in nodes:
raise ValueError(
f"multiple entries found corresponding to node {str(path)!r}"
)
nodes[path] = value

# Merge nodes corresponding to DataArrays into Datasets
dataset_args: defaultdict[NodePath, _DatasetArgs] = defaultdict(_DatasetArgs)
dataset_args: defaultdict[TreePath, _DatasetArgs] = defaultdict(_DatasetArgs)
for path in list(nodes):
node = nodes[path]
if node is not None and not isinstance(node, Dataset | DataTree):
Expand All @@ -1378,7 +1404,7 @@ def from_dict(
) from e

# Create the root node
root_data = nodes.pop(NodePath("/"), None)
root_data = nodes.pop(TreePath("/"), None)
if isinstance(root_data, cls):
# use cls so type-checkers understand this method returns Self
obj = root_data.copy()
Expand All @@ -1391,7 +1417,7 @@ def from_dict(
f"or DataTree, got {type(root_data)}"
)

def depth(item: tuple[NodePath, object]) -> int:
def depth(item: tuple[TreePath, object]) -> int:
node_path, _ = item
return len(node_path.parts)

Expand Down Expand Up @@ -1745,7 +1771,7 @@ def match(self, pattern: str) -> DataTree:
matching_nodes = {
path: node.dataset
for path, node in self.subtree_with_keys
if NodePath(node.path).match(pattern)
if TreePath(node.path).match(pattern)
}
return DataTree.from_dict(matching_nodes, name=self.name)

Expand Down
Loading
Loading