Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions keras/src/ops/operation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,16 +337,18 @@ def compute_transpose_output_shape(input_shape, axes):
Returns:
Tuple of ints: The output shape after the `transpose` operation.
"""
input_shape = list(input_shape)
# Convert input_shape to a tuple just once, for indexing and to avoid unnecessary copy
shape = tuple(input_shape)
if axes is None:
return tuple(input_shape[::-1])
return shape[::-1]

if len(axes) != len(input_shape):
if len(axes) != len(shape):
raise ValueError(
"axis must be a list of the same length as the input shape, "
f"expected {len(input_shape)}, but received {len(axes)}."
f"expected {len(shape)}, but received {len(axes)}."
)
return tuple(input_shape[ax] for ax in axes)
# Use tuple comprehension for fast construction
return tuple(shape[ax] for ax in axes)


def compute_take_along_axis_output_shape(input_shape, indices_shape, axis):
Expand Down