diff --git a/keras/src/ops/operation_utils.py b/keras/src/ops/operation_utils.py index b1ac2621de0a..ccbb151a62ed 100644 --- a/keras/src/ops/operation_utils.py +++ b/keras/src/ops/operation_utils.py @@ -337,16 +337,18 @@ def compute_transpose_output_shape(input_shape, axes): Returns: Tuple of ints: The output shape after the `transpose` operation. """ - input_shape = list(input_shape) + # Convert input_shape to a tuple just once, for indexing and to avoid unnecessary copy + shape = tuple(input_shape) if axes is None: - return tuple(input_shape[::-1]) + return shape[::-1] - if len(axes) != len(input_shape): + if len(axes) != len(shape): raise ValueError( "axis must be a list of the same length as the input shape, " - f"expected {len(input_shape)}, but received {len(axes)}." + f"expected {len(shape)}, but received {len(axes)}." ) - return tuple(input_shape[ax] for ax in axes) + # Use tuple comprehension for fast construction + return tuple(shape[ax] for ax in axes) def compute_take_along_axis_output_shape(input_shape, indices_shape, axis):