diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0497c4a031f..1342568908c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -16,6 +16,9 @@ New Features - :py:func:`merge` and :py:func:`concat` now support :py:class:`DataTree` objects (:issue:`9790`, :issue:`9778`). By `Stephan Hoyer `_. +- :py:class:`DataTree` now supports indexing by lists of paths, similar to + :py:class:`DataTree` (:pull:`10854`). + By `Stephan Hoyer `_. - The ``h5netcdf`` engine has support for pseudo ``NETCDF4_CLASSIC`` files, meaning variables and attributes are cast to supported types. Note that the saved files won't be recognized as genuine ``NETCDF4_CLASSIC`` files until ``h5netcdf`` adds support with version 1.7.0. (:issue:`10676`, :pull:`10686`). By `David Huard `_. diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 92694c16a52..88c9ce0cae8 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -236,9 +236,9 @@ def _decode_variable_name(name): def _iter_nc_groups(root, parent="/"): - from xarray.core.treenode import NodePath + from xarray.core.treenode import TreePath - parent = NodePath(parent) + parent = TreePath(parent) yield str(parent) for path, group in root.groups.items(): gpath = parent / path diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 8967ae97802..84554ef08f5 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -630,7 +630,7 @@ def open_groups_as_dict( **kwargs, ) -> dict[str, Dataset]: from xarray.backends.common import _iter_nc_groups - from xarray.core.treenode import NodePath + from xarray.core.treenode import TreePath from xarray.core.utils import close_on_error # Keep this message for some versions @@ -652,9 +652,9 @@ def open_groups_as_dict( # Check for a group and make it a parent if it exists if group: - parent = NodePath("/") / NodePath(group) + parent = TreePath("/") / TreePath(group) else: - parent = NodePath("/") + parent = TreePath("/") manager = store._manager groups_dict = {} @@ -674,9 +674,9 @@ def open_groups_as_dict( ) if group: - group_name = str(NodePath(path_group).relative_to(parent)) + group_name = str(TreePath(path_group).relative_to(parent)) else: - group_name = str(NodePath(path_group)) + group_name = str(TreePath(path_group)) groups_dict[group_name] = group_ds # only warn if phony_dims exist in file diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 8d4ca6441c9..0669567339b 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -833,7 +833,7 @@ def open_groups_as_dict( **kwargs, ) -> dict[str, Dataset]: from xarray.backends.common import _iter_nc_groups - from xarray.core.treenode import NodePath + from xarray.core.treenode import TreePath filename_or_obj = _normalize_path(filename_or_obj) store = NetCDF4DataStore.open( @@ -850,9 +850,9 @@ def open_groups_as_dict( # Check for a group and make it a parent if it exists if group: - parent = NodePath("/") / NodePath(group) + parent = TreePath("/") / TreePath(group) else: - parent = NodePath("/") + parent = TreePath("/") manager = store._manager groups_dict = {} @@ -871,9 +871,9 @@ def open_groups_as_dict( decode_timedelta=decode_timedelta, ) if group: - group_name = str(NodePath(path_group).relative_to(parent)) + group_name = str(TreePath(path_group).relative_to(parent)) else: - group_name = str(NodePath(path_group)) + group_name = str(TreePath(path_group)) groups_dict[group_name] = group_ds return groups_dict diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index 4fbfe8ee210..1f67072cc41 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -311,7 +311,7 @@ def open_groups_as_dict( verify=None, user_charset=None, ) -> dict[str, Dataset]: - from xarray.core.treenode import NodePath + from xarray.core.treenode import TreePath filename_or_obj = _normalize_path(filename_or_obj) store = PydapDataStore.open( @@ -325,9 +325,9 @@ def open_groups_as_dict( # Check for a group and make it a parent if it exists if group: - parent = str(NodePath("/") / NodePath(group)) + parent = str(TreePath("/") / TreePath(group)) else: - parent = str(NodePath("/")) + parent = str(TreePath("/")) groups_dict = {} group_names = [parent] @@ -365,7 +365,7 @@ def group_fqn(store, path=None, g_fqn=None) -> dict[str, str]: Groups = group_fqn(store.ds) group_names += [ - str(NodePath(path_to_group) / NodePath(group)) + str(TreePath(path_to_group) / TreePath(group)) for group, path_to_group in Groups.items() ] for path_group in group_names: @@ -384,9 +384,9 @@ def group_fqn(store, path=None, g_fqn=None) -> dict[str, str]: decode_timedelta=decode_timedelta, ) if group: - group_name = str(NodePath(path_group).relative_to(parent)) + group_name = str(TreePath(path_group).relative_to(parent)) else: - group_name = str(NodePath(path_group)) + group_name = str(TreePath(path_group)) groups_dict[group_name] = group_ds return groups_dict diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index fe004c212b6..4a15554e5ac 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -25,7 +25,7 @@ ) from xarray.backends.store import StoreBackendEntrypoint from xarray.core import indexing -from xarray.core.treenode import NodePath +from xarray.core.treenode import TreePath from xarray.core.types import ZarrWriteModes from xarray.core.utils import ( FrozenDict, @@ -1752,9 +1752,9 @@ def open_groups_as_dict( # Check for a group and make it a parent if it exists if group: - parent = str(NodePath("/") / NodePath(group)) + parent = str(TreePath("/") / TreePath(group)) else: - parent = str(NodePath("/")) + parent = str(TreePath("/")) stores = ZarrStore.open_store( filename_or_obj, @@ -1785,18 +1785,18 @@ def open_groups_as_dict( decode_timedelta=decode_timedelta, ) if group: - group_name = str(NodePath(path_group).relative_to(parent)) + group_name = str(TreePath(path_group).relative_to(parent)) else: - group_name = str(NodePath(path_group)) + group_name = str(TreePath(path_group)) groups_dict[group_name] = group_ds return groups_dict def _iter_zarr_groups(root: ZarrGroup, parent: str = "/") -> Iterable[str]: - parent_nodepath = NodePath(parent) - yield str(parent_nodepath) + parent_TreePath = TreePath(parent) + yield str(parent_TreePath) for path, group in root.groups(): - gpath = parent_nodepath / path + gpath = parent_TreePath / path yield from _iter_zarr_groups(group, parent=str(gpath)) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 6fb387e3c15..33230864917 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -51,7 +51,7 @@ from xarray.core.indexes import Index, Indexes from xarray.core.options import OPTIONS as XR_OPTS from xarray.core.options import _get_keep_attrs -from xarray.core.treenode import NamedNode, NodePath, zip_subtrees +from xarray.core.treenode import NamedNode, TreePath, zip_subtrees from xarray.core.types import Self from xarray.core.utils import ( Default, @@ -114,7 +114,7 @@ # """ -T_Path = Union[str, NodePath] +T_Path = Union[str, TreePath] T = TypeVar("T") P = ParamSpec("P") @@ -188,7 +188,7 @@ def check_alignment( base_ds = node_ds for child_name, child in children.items(): - child_path = str(NodePath(path) / child_name) + child_path = str(TreePath(path) / child_name) child_ds = child.to_dataset(inherit=False) check_alignment(child_path, child_ds, base_ds, child.children) @@ -566,7 +566,7 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: raise KeyError( f"parent {parent.name} already contains a variable named {name}" ) - path = str(NodePath(parent.path) / name) + path = str(TreePath(parent.path) / name) node_ds = self.to_dataset(inherit=False) parent_ds = parent._to_dataset_view(rebuild_dims=False, inherit=True) check_alignment(path, node_ds, parent_ds, self.children) @@ -943,17 +943,48 @@ def get( # type: ignore[override] else: return default - def __getitem__(self: DataTree, key: str) -> DataTree | DataArray: + def _copy_listed(self, keys: list[str]) -> Self: + """Get multiple items as a DataTree.""" + base = TreePath(self.path) + nodes: dict[str, DataTree | Dataset] = {} + keys_by_node: defaultdict[str, list[str]] = defaultdict(list) + for key in keys: + path = base / key + assert path.root + _, *parts, name = path.parts + current_node = self.root + for part in parts: + current_node = current_node.children[part] + if name in current_node.children: + target = str(path.relative_to(base)) + nodes[target] = current_node.children[name] + elif name in current_node.variables: + target = str(base.joinpath(*parts).relative_to(base)) + keys_by_node[target].append(name) # DataArray + else: + raise KeyError(key) + for target, names in keys_by_node.items(): + nodes[target] = self.root[target].dataset[names] + return self.from_dict(nodes, name=self.name) + + @overload + def __getitem__(self, key: list[str]) -> Self: ... + + @overload + def __getitem__(self, key: str) -> Self | DataArray: ... + + def __getitem__(self, key: str | list[str]) -> Self | DataArray: """ Access child nodes, variables, or coordinates stored anywhere in this tree. - Returned object will be either a DataTree or DataArray object depending on whether the key given points to a - child or variable. + Returned object will be either a DataTree or DataArray object depending on + whether the key given points to a child or variable. Parameters ---------- key : str - Name of variable / child within this node, or unix-like path to variable / child within another node. + Name of variable / child within this node, unix-like path to variable + / child within another node, or a list of names/paths. Returns ------- @@ -967,14 +998,9 @@ def __getitem__(self: DataTree, key: str) -> DataTree | DataArray: elif isinstance(key, str): # TODO should possibly deal with hashables in general? # path-like: a name of a node/variable, or path to a node/variable - path = NodePath(key) - return self._get_item(path) - elif utils.is_list_like(key): - # iterable of variable names - raise NotImplementedError( - "Selecting via tags is deprecated, and selecting multiple items should be " - "implemented via .subset" - ) + return self._get_item(key) + elif isinstance(key, list): + return self._copy_listed(key) else: raise ValueError(f"Invalid format for key: {key}") @@ -1015,7 +1041,7 @@ def __setitem__( elif isinstance(key, str): # TODO should possibly deal with hashables in general? # path-like: a name of a node/variable, or path to a node/variable - path = NodePath(key) + path = TreePath(key) if isinstance(value, Dataset): value = DataTree(dataset=value) return self._set_item(path, value, new_nodes_along_path=True) @@ -1341,9 +1367,9 @@ def from_dict( data_items, ((k, _CoordWrapper(v)) for k, v in coords_items), ) - nodes: dict[NodePath, _CoordWrapper | FromDictDataValue] = {} + nodes: dict[TreePath, _CoordWrapper | FromDictDataValue] = {} for key, value in flat_data_and_coords: - path = NodePath(key).absolute() + path = TreePath(key).absolute() if path in nodes: raise ValueError( f"multiple entries found corresponding to node {str(path)!r}" @@ -1351,7 +1377,7 @@ def from_dict( nodes[path] = value # Merge nodes corresponding to DataArrays into Datasets - dataset_args: defaultdict[NodePath, _DatasetArgs] = defaultdict(_DatasetArgs) + dataset_args: defaultdict[TreePath, _DatasetArgs] = defaultdict(_DatasetArgs) for path in list(nodes): node = nodes[path] if node is not None and not isinstance(node, Dataset | DataTree): @@ -1378,7 +1404,7 @@ def from_dict( ) from e # Create the root node - root_data = nodes.pop(NodePath("/"), None) + root_data = nodes.pop(TreePath("/"), None) if isinstance(root_data, cls): # use cls so type-checkers understand this method returns Self obj = root_data.copy() @@ -1391,7 +1417,7 @@ def from_dict( f"or DataTree, got {type(root_data)}" ) - def depth(item: tuple[NodePath, object]) -> int: + def depth(item: tuple[TreePath, object]) -> int: node_path, _ = item return len(node_path.parts) @@ -1745,7 +1771,7 @@ def match(self, pattern: str) -> DataTree: matching_nodes = { path: node.dataset for path, node in self.subtree_with_keys - if NodePath(node.path).match(pattern) + if TreePath(node.path).match(pattern) } return DataTree.from_dict(matching_nodes, name=self.name) diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 90f817ed017..8434201e396 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -21,20 +21,65 @@ class NotFoundInTreeError(ValueError): """Raised when operation can't be completed because one node is not part of the expected tree.""" -class NodePath(PurePosixPath): +class TreePath(PurePosixPath): """Represents a path from one node to another within a tree.""" + def __new__(cls, *pathsegments): + if sys.version_info >= (3, 12): + unnormalized = super().__new__(cls) + else: + unnormalized = super().__new__(cls, *pathsegments) + unnormalized.__init__(*pathsegments) + + # TreePath does not support symlinks, so we can resolve segments like + # "." and ".." + if unnormalized.is_absolute(): + parts_without_root = unnormalized.parts[1:] + else: + parts_without_root = unnormalized.parts + parts = [] + for part in parts_without_root: + if part == "..": + if not parts or parts[-1] == "..": + if unnormalized.is_absolute(): + raise ValueError( + f"path accesses node before root: {unnormalized}" + ) + parts.append(part) + else: + parts.pop() # remove parent + elif part not in ("", "."): + parts.append(part) + + if unnormalized.is_absolute(): + parts = ["/"] + parts + + if sys.version_info >= (3, 12): + normalized = super().__new__(cls) + else: + normalized = super().__new__(cls, *parts) + normalized.__init__(*parts) + + return normalized + + # Implementing _from_parsed_parts is required for __div__ on Python 3.11 + if sys.version_info < (3, 12): + + @classmethod + def _from_parsed_parts(cls, drv, root, parts): + return cls(*parts) + def __init__(self, *pathsegments): if sys.version_info >= (3, 12): + # No __init__ method on base-class in Python 3.11 super().__init__(*pathsegments) - else: - super().__new__(PurePosixPath, *pathsegments) + if self.drive: - raise ValueError("NodePaths cannot have drives") + raise ValueError("TreePaths cannot have drives") if self.root not in ["/", ""]: raise ValueError( - 'Root of NodePath can only be either "/" or "", with "" meaning the path is relative.' + 'Root of TreePath can only be either "/" or "", with "" meaning the path is relative.' ) # TODO should we also forbid suffixes to avoid node names with dots in them? @@ -90,7 +135,8 @@ def parent(self) -> Self | None: @parent.setter def parent(self, new_parent: Self) -> None: raise AttributeError( - "Cannot set parent attribute directly, you must modify the children of the other node instead using dict-like syntax" + "Cannot set parent attribute directly, you must modify the children of " + "the other node instead using dict-like syntax" ) def _set_parent( @@ -425,7 +471,7 @@ def subtree_with_keys(self) -> Iterator[tuple[str, Self]]: DataTree.descendants group_subtrees """ - queue = collections.deque([(NodePath(), self)]) + queue = collections.deque([(TreePath(), self)]) while queue: path, node = queue.popleft() yield str(path), node @@ -529,39 +575,6 @@ def get(self, key: str, default: Self | None = None) -> Self | None: else: return default - # TODO `._walk` method to be called by both `_get_item` and `_set_item` - - def _get_item(self, path: str | NodePath) -> Self | DataArray: - """ - Returns the object lying at the given path. - - Raises a KeyError if there is no object at the given path. - """ - if isinstance(path, str): - path = NodePath(path) - - if path.root: - current_node = self.root - _root, *parts = list(path.parts) - else: - current_node = self - parts = list(path.parts) - - for part in parts: - if part == "..": - if current_node.parent is None: - raise KeyError(f"Could not find node at {path}") - else: - current_node = current_node.parent - elif part in ("", "."): - pass - else: - child = current_node.get(part) - if child is None: - raise KeyError(f"Could not find node at {path}") - current_node = child - return current_node - def _set(self, key: str, val: Any) -> None: """ Set the child node with the specified key to value. @@ -571,81 +584,6 @@ def _set(self, key: str, val: Any) -> None: new_children = {**self.children, key: val} self.children = new_children - def _set_item( - self, - path: str | NodePath, - item: Any, - new_nodes_along_path: bool = False, - allow_overwrite: bool = True, - ) -> None: - """ - Set a new item in the tree, overwriting anything already present at that path. - - The given value either forms a new node of the tree or overwrites an - existing item at that location. - - Parameters - ---------- - path - item - new_nodes_along_path : bool - If true, then if necessary new nodes will be created along the - given path, until the tree can reach the specified location. - allow_overwrite : bool - Whether or not to overwrite any existing node at the location given - by path. - - Raises - ------ - KeyError - If node cannot be reached, and new_nodes_along_path=False. - Or if a node already exists at the specified path, and allow_overwrite=False. - """ - if isinstance(path, str): - path = NodePath(path) - - if not path.name: - raise ValueError("Can't set an item under a path which has no name") - - if path.root: - # absolute path - current_node = self.root - _root, *parts, name = path.parts - else: - # relative path - current_node = self - *parts, name = path.parts - - if parts: - # Walk to location of new node, creating intermediate node objects as we go if necessary - for part in parts: - if part == "..": - if current_node.parent is None: - # We can't create a parent if `new_nodes_along_path=True` as we wouldn't know what to name it - raise KeyError(f"Could not reach node at path {path}") - else: - current_node = current_node.parent - elif part in ("", "."): - pass - elif part in current_node.children: - current_node = current_node.children[part] - elif new_nodes_along_path: - # Want child classes (i.e. DataTree) to populate tree with their own types - new_node = type(self)() - current_node._set(part, new_node) - current_node = current_node.children[part] - else: - raise KeyError(f"Could not reach node at path {path}") - - if name in current_node.children: - # Deal with anything already existing at this location - if allow_overwrite: - current_node._set(name, item) - else: - raise KeyError(f"Already a node object at path {path}") - else: - current_node._set(name, item) - def __delitem__(self, key: str) -> None: """Remove a child node from this tree object.""" if key in self.children: @@ -750,7 +688,7 @@ def relative_to(self, other: Self) -> str: "Cannot find relative path because nodes do not lie within the same tree" ) - this_path = NodePath(self.path) + this_path = TreePath(self.path) if any(other.path == parent.path for parent in (self, *self.parents)): return str(this_path.relative_to(other.path)) else: @@ -778,7 +716,7 @@ def find_common_ancestor(self, other: Self) -> Self: "Cannot find common ancestor because nodes do not lie within the same tree" ) - def _path_to_ancestor(self, ancestor: Self) -> NodePath: + def _path_to_ancestor(self, ancestor: Self) -> TreePath: """Return the relative path from this node to the given ancestor node""" if not self.same_tree(ancestor): @@ -793,7 +731,87 @@ def _path_to_ancestor(self, ancestor: Self) -> NodePath: parents_paths = [parent.path for parent in (self, *self.parents)] generation_gap = list(parents_paths).index(ancestor.path) path_upwards = "../" * generation_gap if generation_gap > 0 else "." - return NodePath(path_upwards) + return TreePath(path_upwards) + + # TODO `._walk` method to be called by both `_get_item` and `_set_item` + + def _get_item(self, path: str | TreePath) -> Self | DataArray: + """ + Returns the object lying at the given path. + + Raises a KeyError if there is no object at the given path. + """ + path = TreePath(self.path) / path + assert path.root + current_node = self.root + _, *parts = path.parts + for part in parts: + child = current_node.get(part) + if child is None: + raise KeyError(f"Could not find node at {path}") + current_node = child + return current_node + + def _set_item( + self, + path: str | TreePath, + item: Any, + new_nodes_along_path: bool = False, + allow_overwrite: bool = True, + ) -> None: + """ + Set a new item in the tree, overwriting anything already present at that path. + + The given value either forms a new node of the tree or overwrites an + existing item at that location. + + Parameters + ---------- + path + item + new_nodes_along_path : bool + If true, then if necessary new nodes will be created along the + given path, until the tree can reach the specified location. + allow_overwrite : bool + Whether or not to overwrite any existing node at the location given + by path. + + Raises + ------ + KeyError + If node cannot be reached, and new_nodes_along_path=False. + Or if a node already exists at the specified path, and allow_overwrite=False. + """ + path = TreePath(self.path) / path + + if not path.name: + raise ValueError("Can't set an item under a path which has no name") + + assert path.root + current_node = self.root + _, *parts, name = path.parts + + if parts: + # Walk to location of new node, creating intermediate node objects as we go if necessary + for part in parts: + if part in current_node.children: + current_node = current_node.children[part] + elif new_nodes_along_path: + # Want child classes (i.e. DataTree) to populate tree with their own types + new_node = type(self)() + current_node._set(part, new_node) + current_node = current_node.children[part] + else: + raise KeyError(f"Could not reach node at path {path}") + + if name in current_node.children: + # Deal with anything already existing at this location + if allow_overwrite: + current_node._set(name, item) + else: + raise KeyError(f"Already a node object at path {path}") + else: + current_node._set(name, item) class TreeIsomorphismError(ValueError): @@ -839,7 +857,7 @@ def group_subtrees( raise TypeError("must pass at least one tree object") # https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode - queue = collections.deque([(NodePath(), trees)]) + queue = collections.deque([(TreePath(), trees)]) while queue: path, active_nodes = queue.popleft() diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 0cd888f5782..749605c12e9 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -314,11 +314,39 @@ def test_getitem_nonexistent_variable(self) -> None: with pytest.raises(KeyError): results["pressure"] - @pytest.mark.xfail(reason="Should be deprecated in favour of .subset") - def test_getitem_multiple_data_variables(self) -> None: - data = xr.Dataset({"temp": [0, 50], "p": [5, 8, 7]}) - results = DataTree(name="results", dataset=data) - assert_identical(results[["temp", "p"]], data[["temp", "p"]]) # type: ignore[index] + def test_getitem_multiple_variables(self) -> None: + data = DataTree.from_dict({"x": 0, "y": 1}) + + expected = DataTree.from_dict({"x": 0}) + actual = data[["x"]] + assert_identical(actual, expected) + + expected = data + actual = data[["x", "y"]] + assert_identical(actual, expected) + + def test_getitem_nested(self) -> None: + data = DataTree.from_dict({"a/b/c": 0, "a/d": 1, "e": 2}) + + expected = DataTree.from_dict({"a/b/c": 0, "a/d": 1}) + actual = data[["a"]] + assert_identical(actual, expected) + actual = data[["a/b", "a/d"]] + assert_identical(actual, expected) + + expected = DataTree.from_dict({"a/d": 1, "e": 2}) + actual = data[["a/d", "e"]] + assert_identical(actual, expected) + + expected = DataTree.from_dict({"e": 2}, name="a") + actual = data.children["a"][["../e"]] + assert_identical(actual, expected) + + with pytest.raises(KeyError, match="Already a node object at path /a/b"): + data[["a", "a/b"]] + + with pytest.raises(KeyError, match="'not_found'"): + data[["not_found"]] @pytest.mark.xfail( reason="Indexing needs to return whole tree (GH https://github.com/xarray-contrib/datatree/issues/77)" @@ -326,7 +354,7 @@ def test_getitem_multiple_data_variables(self) -> None: def test_getitem_dict_like_selection_access_to_dataset(self) -> None: data = xr.Dataset({"temp": [0, 50]}) results = DataTree(name="results", dataset=data) - assert_identical(results[{"temp": 1}], data[{"temp": 1}]) # type: ignore[index] + assert_identical(results[{"temp": 1}], data[{"temp": 1}]) # type: ignore[call-overload] class TestUpdate: diff --git a/xarray/tests/test_treenode.py b/xarray/tests/test_treenode.py index 7eb715630bb..36c66cc539b 100644 --- a/xarray/tests/test_treenode.py +++ b/xarray/tests/test_treenode.py @@ -7,13 +7,84 @@ from xarray.core.treenode import ( InvalidTreeError, NamedNode, - NodePath, TreeNode, + TreePath, group_subtrees, zip_subtrees, ) +class TestTreePath: + def test_treepath_simple(self): + path = TreePath("/Mary") + assert path.root == "/" + assert path.stem == "Mary" + + @pytest.mark.parametrize( + "path_str, expected", + [ + # Test cases with '.', '..', and simple paths + ("/", "/"), + ("/.", "/"), + (".", "."), + ("", "."), + ("..", ".."), + ("../", ".."), + ("foo", "foo"), + ("/foo", "/foo"), + ("foo/", "foo"), + ("/foo/", "/foo"), + # Test cases with '.' + ("/foo/.", "/foo"), + ("foo/./bar", "foo/bar"), + ("./foo/bar", "foo/bar"), + ("foo/bar/.", "foo/bar"), + ("/./", "/"), + ("./", "."), + # Test cases with '..' + ("foo/../bar", "bar"), + ("/foo/../bar", "/bar"), + ("/foo/bar/../..", "/"), + ("../bar", "../bar"), + ("/a/b/../c/./..", "/a"), + ("a/b/../../c", "c"), + ("/a/b/../../c", "/c"), + ("a/../b", "b"), + ("../../a", "../../a"), + ("a/b/../c/../d", "a/d"), + # Other cases + ("foo//bar", "foo/bar"), + ("/foo/../...", "/..."), + ], + ) + def test_treepath_normalization(self, path_str, expected): + assert str(TreePath(path_str)) == expected + + @pytest.mark.parametrize( + "path_str", + [ + "/..", + "/../", + "/foo/../..", + "/foo/../../bar", + "/a/b/../../../c", + ], + ) + def test_treepath_valueerror(self, path_str): + with pytest.raises(ValueError, match="path accesses node before root"): + TreePath(path_str) + + def test_div(self): + actual = str(TreePath("/a/b") / "..") + expected = "/a" + assert expected == actual + + def test_parent(self): + actual = str(TreePath("/a").parent) + expected = "/" + assert expected == actual + + class TestFamilyTree: def test_lonely(self) -> None: root: TreeNode = TreeNode() @@ -140,10 +211,10 @@ def test_parents(self) -> None: class TestGetNodes: def test_get_child(self) -> None: - john: TreeNode = TreeNode( + john = NamedNode( children={ - "Mary": TreeNode( - children={"Sue": TreeNode(children={"Steven": TreeNode()})} + "Mary": NamedNode( + children={"Sue": NamedNode(children={"Steven": NamedNode()})} ) } ) @@ -169,9 +240,9 @@ def test_get_child(self) -> None: assert mary._get_item("Sue/Steven") is steven def test_get_upwards(self) -> None: - john: TreeNode = TreeNode( + john = NamedNode( children={ - "Mary": TreeNode(children={"Sue": TreeNode(), "Kate": TreeNode()}) + "Mary": NamedNode(children={"Sue": NamedNode(), "Kate": NamedNode()}) } ) mary = john.children["Mary"] @@ -185,9 +256,7 @@ def test_get_upwards(self) -> None: assert sue._get_item("../Kate") is kate def test_get_from_root(self) -> None: - john: TreeNode = TreeNode( - children={"Mary": TreeNode(children={"Sue": TreeNode()})} - ) + john = NamedNode(children={"Mary": NamedNode(children={"Sue": NamedNode()})}) mary = john.children["Mary"] sue = mary.children["Sue"] @@ -196,38 +265,38 @@ def test_get_from_root(self) -> None: class TestSetNodes: def test_set_child_node(self) -> None: - john: TreeNode = TreeNode() - mary: TreeNode = TreeNode() + john = NamedNode() + mary = NamedNode() john._set_item("Mary", mary) assert john.children["Mary"] is mary - assert isinstance(mary, TreeNode) + assert isinstance(mary, NamedNode) assert mary.children == {} assert mary.parent is john def test_child_already_exists(self) -> None: - mary: TreeNode = TreeNode() - john: TreeNode = TreeNode(children={"Mary": mary}) - mary_2: TreeNode = TreeNode() + mary = NamedNode() + john = NamedNode(children={"Mary": mary}) + mary_2 = NamedNode() with pytest.raises(KeyError): john._set_item("Mary", mary_2, allow_overwrite=False) def test_set_grandchild(self) -> None: - rose: TreeNode = TreeNode() - mary: TreeNode = TreeNode() - john: TreeNode = TreeNode() + rose = NamedNode() + mary = NamedNode() + john = NamedNode() john._set_item("Mary", mary) john._set_item("Mary/Rose", rose) assert john.children["Mary"] is mary - assert isinstance(mary, TreeNode) + assert isinstance(mary, NamedNode) assert "Rose" in mary.children assert rose.parent is mary def test_create_intermediate_child(self) -> None: - john: TreeNode = TreeNode() - rose: TreeNode = TreeNode() + john = NamedNode() + rose = NamedNode() # test intermediate children not allowed with pytest.raises(KeyError, match="Could not reach"): @@ -237,25 +306,25 @@ def test_create_intermediate_child(self) -> None: john._set_item("Mary/Rose", rose, new_nodes_along_path=True) assert "Mary" in john.children mary = john.children["Mary"] - assert isinstance(mary, TreeNode) + assert isinstance(mary, NamedNode) assert mary.children == {"Rose": rose} assert rose.parent == mary assert rose.parent == mary def test_overwrite_child(self) -> None: - john: TreeNode = TreeNode() - mary: TreeNode = TreeNode() + john = NamedNode() + mary = NamedNode() john._set_item("Mary", mary) # test overwriting not allowed - marys_evil_twin: TreeNode = TreeNode() + marys_evil_twin = NamedNode() with pytest.raises(KeyError, match="Already a node object"): john._set_item("Mary", marys_evil_twin, allow_overwrite=False) assert john.children["Mary"] is mary assert marys_evil_twin.parent is None # test overwriting allowed - marys_evil_twin = TreeNode() + marys_evil_twin = NamedNode() john._set_item("Mary", marys_evil_twin, allow_overwrite=True) assert john.children["Mary"] is marys_evil_twin assert marys_evil_twin.parent is john @@ -263,8 +332,8 @@ def test_overwrite_child(self) -> None: class TestPruning: def test_del_child(self) -> None: - john: TreeNode = TreeNode() - mary: TreeNode = TreeNode() + john = NamedNode() + mary = NamedNode() john._set_item("Mary", mary) del john["Mary"] @@ -285,15 +354,15 @@ def create_test_tree() -> tuple[NamedNode, NamedNode]: # └── c # └── h # └── i - a: NamedNode = NamedNode(name="a") - b: NamedNode = NamedNode() - c: NamedNode = NamedNode() - d: NamedNode = NamedNode() - e: NamedNode = NamedNode() - f: NamedNode = NamedNode() - g: NamedNode = NamedNode() - h: NamedNode = NamedNode() - i: NamedNode = NamedNode() + a = NamedNode(name="a") + b = NamedNode() + c = NamedNode() + d = NamedNode() + e = NamedNode() + f = NamedNode() + g = NamedNode() + h = NamedNode() + i = NamedNode() a.children = {"b": b, "c": c} b.children = {"d": d, "e": e} @@ -504,9 +573,3 @@ def test_render_nodetree(self) -> None: assert len(john_nodes) == len(expected_nodes) for expected_node, repr_node in zip(expected_nodes, john_nodes, strict=True): assert expected_node == repr_node - - -def test_nodepath(): - path = NodePath("/Mary") - assert path.root == "/" - assert path.stem == "Mary"