Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 36 additions & 16 deletions keras/src/ops/operation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,52 @@ def broadcast_shapes(shape1, shape2):
>>> broadcast_shapes((5, 3), (1, 3))
[5, 3]
"""
shape1 = list(shape1)
shape2 = list(shape2)
# Use tuple to avoid extra copy until needed
len1 = len(shape1)
len2 = len(shape2)
origin_shape1 = shape1
origin_shape2 = shape2

if len(shape1) > len(shape2):
shape2 = [1] * (len(shape1) - len(shape2)) + shape2
if len(shape1) < len(shape2):
shape1 = [1] * (len(shape2) - len(shape1)) + shape1
output_shape = list(shape1)
for i in range(len(shape1)):
if shape1[i] == 1:
output_shape[i] = shape2[i]
elif shape1[i] is None:
output_shape[i] = None if shape2[i] == 1 else shape2[i]
# Precompute the length difference and pad as necessary
if len1 > len2:
pad = (1,) * (len1 - len2)
shape2 = (*pad, *shape2)
len_ = len1
elif len1 < len2:
pad = (1,) * (len2 - len1)
shape1 = (*pad, *shape1)
len_ = len2
else:
len_ = len1

# Avoid making a list copy until output
output_shape = [None] * len_

# Pull to local for speed in loop
s1 = shape1
s2 = shape2
out = output_shape

# Localize exception msg for fast path
for i in range(len_):
a = s1[i]
b = s2[i]
if a == 1:
out[i] = b
elif a is None:
# None is arbitrary unless b==1 means None, else b
out[i] = None if b == 1 else b
else:
if shape2[i] == 1 or shape2[i] is None or shape2[i] == shape1[i]:
output_shape[i] = shape1[i]
if b == 1 or b is None or b == a:
out[i] = a
else:
raise ValueError(
"Cannot broadcast shape, the failure dim has value "
f"{shape1[i]}, which cannot be broadcasted to {shape2[i]}. "
f"{a}, which cannot be broadcasted to {b}. "
f"Input shapes are: {origin_shape1} and {origin_shape2}."
)

return output_shape
return out


def compute_expand_dims_output_shape(input_shape, axis):
Expand Down