def broadcast_shapes(shape1, shape2):
    """Broadcast two shapes into a single output shape.

    The shorter shape is left-padded with 1s so both shapes have the same
    rank. Each pair of aligned dimensions is then merged:

    - If the first dimension is 1, the second dimension is used.
    - If the first dimension is `None`, the result is `None` when the second
      dimension is 1, otherwise the second dimension.
    - Otherwise the first dimension is used, provided the second dimension
      is 1, `None`, or equal to it.

    Args:
        shape1: Sized sequence of dimensions (integers or `None`).
        shape2: Sized sequence of dimensions (integers or `None`).

    Returns:
        A new list holding the broadcasted shape. The inputs are not
        modified.

    Raises:
        ValueError: If a pair of aligned dimensions cannot be merged by the
            rules above.

    Example:

    >>> broadcast_shapes((5, 3), (1, 3))
    [5, 3]
    """
    rank1 = len(shape1)
    rank2 = len(shape2)
    rank = max(rank1, rank2)

    # Left-pad the shorter shape with 1s. `shape1` / `shape2` themselves are
    # never rebound, so the error message below reports the caller's inputs
    # exactly as they were passed in.
    padded1 = (1,) * (rank - rank1) + tuple(shape1)
    padded2 = (1,) * (rank - rank2) + tuple(shape2)

    output_shape = []
    for dim1, dim2 in zip(padded1, padded2):
        if dim1 == 1:
            output_shape.append(dim2)
        elif dim1 is None:
            # An unknown dim stays unknown against 1; otherwise the other
            # side (a concrete size or `None`) decides the result.
            output_shape.append(None if dim2 == 1 else dim2)
        elif dim2 == 1 or dim2 is None or dim2 == dim1:
            output_shape.append(dim1)
        else:
            raise ValueError(
                "Cannot broadcast shape, the failure dim has value "
                f"{dim1}, which cannot be broadcasted to {dim2}. "
                f"Input shapes are: {shape1} and {shape2}."
            )
    return output_shape