diff --git a/keras/src/ops/operation_utils.py b/keras/src/ops/operation_utils.py index b1ac2621de0a..5f533cba8ab5 100644 --- a/keras/src/ops/operation_utils.py +++ b/keras/src/ops/operation_utils.py @@ -68,9 +68,10 @@ def compute_expand_dims_output_shape(input_shape, axis): axis = to_tuple_or_list(axis) out_ndim = len(axis) + len(input_shape) axis = [canonicalize_axis(a, out_ndim) for a in axis] + axis_set = set(axis) shape_iter = iter(input_shape) new_shape = [ - 1 if ax in axis else next(shape_iter) for ax in range(out_ndim) + 1 if ax in axis_set else next(shape_iter) for ax in range(out_ndim) ] return tuple(new_shape)