Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,22 +627,37 @@ def unified_dim_sizes(
) -> dict[Hashable, int]:
dim_sizes: dict[Hashable, int] = {}

exclude_dims_lookup = (
exclude_dims
if isinstance(exclude_dims, (set, frozenset))
else set(exclude_dims)
)

for var in variables:
if len(set(var.dims)) < len(var.dims):
dims = var.dims
shape = var.shape

# More efficient duplicate detection: skip set conversion for 0 or 1 dims
if len(dims) > 1 and len(set(dims)) < len(dims):
raise ValueError(
"broadcasting cannot handle duplicate "
f"dimensions on a variable: {list(var.dims)}"
f"dimensions on a variable: {list(dims)}"
)
for dim, size in zip(var.dims, var.shape):
if dim not in exclude_dims:
if dim not in dim_sizes:
dim_sizes[dim] = size
elif dim_sizes[dim] != size:
raise ValueError(
"operands cannot be broadcast together "
"with mismatched lengths for dimension "
f"{dim}: {dim_sizes[dim]} vs {size}"
)

for i in range(len(dims)):
dim = dims[i]
if dim in exclude_dims_lookup:
continue
size = shape[i]
existing_size = dim_sizes.get(dim)
if existing_size is None:
dim_sizes[dim] = size
elif existing_size != size:
raise ValueError(
"operands cannot be broadcast together "
"with mismatched lengths for dimension "
f"{dim}: {existing_size} vs {size}"
)
return dim_sizes


Expand Down