Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 42 additions & 36 deletions xarray/computation/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@

import functools
from collections import Counter
from collections.abc import (
Callable,
Hashable,
)
from collections.abc import Callable, Hashable
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import numpy as np
Expand All @@ -23,10 +20,7 @@
from xarray.core.duck_array_ops import datetime_to_numeric
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import Dims, T_DataArray
from xarray.core.utils import (
is_scalar,
parse_dims_as_set,
)
from xarray.core.utils import is_scalar, parse_dims_as_set
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
Expand Down Expand Up @@ -903,7 +897,9 @@ def _calc_idxminmax(
if not array.ndim:
raise ValueError("This function does not apply for scalars")

if dim is not None:
if dim is Ellipsis:
dim = array.dims
elif dim is not None:
pass # Use the dim if available
elif array.ndim == 1:
# it is okay to guess the dim if there is only 1
Expand All @@ -912,14 +908,19 @@ def _calc_idxminmax(
# The dim is not specified and ambiguous. Don't guess.
raise ValueError("Must supply 'dim' argument for multidimensional arrays")

if dim not in array.dims:
raise KeyError(
f"Dimension {dim!r} not found in array dimensions {array.dims!r}"
)
if dim not in array.coords:
raise KeyError(
f"Dimension {dim!r} is not one of the coordinates {tuple(array.coords.keys())}"
)
dim_is_str = isinstance(dim, str)
# Standardize to an iterable format
dims = [dim] if dim_is_str else dim

for _dim in dims:
if _dim not in array.dims:
raise KeyError(
f"Dimension {_dim!r} not found in array dimensions {array.dims!r}"
)
if _dim not in array.coords:
raise KeyError(
f"Dimension {_dim!r} is not one of the coordinates {tuple(array.coords.keys())}"
)

# These are dtypes with NaN values argmin and argmax can handle
na_dtypes = "cfO"
Expand All @@ -931,25 +932,30 @@ def _calc_idxminmax(

# This will run argmin or argmax.
index = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)
# Force dictionary format in case of single dim so that we can iterate over it in for loop below
if dim_is_str:
index = {dim: index}

res = {}
for _dim, _da_idx in zip(dims, index.values(), strict=False):
# Handle chunked arrays (e.g. dask).
coord = array[_dim]._variable.to_base_variable()
if is_chunked_array(array.data):
chunkmanager = get_chunked_array_type(array.data)
coord_array = chunkmanager.from_array(
array[_dim].data, chunks=((array.sizes[_dim],),)
)
coord = coord.copy(data=coord_array)
else:
coord = coord.copy(data=to_like_array(array[_dim].data, array.data))

# Handle chunked arrays (e.g. dask).
coord = array[dim]._variable.to_base_variable()
if is_chunked_array(array.data):
chunkmanager = get_chunked_array_type(array.data)
coord_array = chunkmanager.from_array(
array[dim].data, chunks=((array.sizes[dim],),)
)
coord = coord.copy(data=coord_array)
else:
coord = coord.copy(data=to_like_array(array[dim].data, array.data))

res = index._replace(coord[(index.variable,)]).rename(dim)

if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Put the NaN values back in after removing them
res = res.where(~allna, fill_value)

# Copy attributes from argmin/argmax, if any
res.attrs = index.attrs
_res = _da_idx._replace(coord[(_da_idx.variable,)]).rename(_dim)
if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Put the NaN values back in after removing them
_res = _res.where(~allna, fill_value)
_res.attrs = _da_idx.attrs
res[_dim] = _res

if dim_is_str:
return res[dim]
return res
Loading
Loading