diff --git a/keras/src/backend/common/backend_utils.py b/keras/src/backend/common/backend_utils.py index fb809c2cc7b2..6184e569043b 100644 --- a/keras/src/backend/common/backend_utils.py +++ b/keras/src/backend/common/backend_utils.py @@ -1,5 +1,4 @@ import functools -import operator import re import warnings @@ -262,14 +261,18 @@ def compute_conv_transpose_output_shape( def canonicalize_axis(axis, num_dims): """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" - axis = operator.index(axis) - if not -num_dims <= axis < num_dims: + # Faster than operator.index() as we avoid function call overhead + try: + axis = axis.__index__() + except AttributeError: + raise TypeError(f"axis must be an integer, got {type(axis)}") + if axis < -num_dims or axis >= num_dims: raise ValueError( f"axis {axis} is out of bounds for an array with dimension " f"{num_dims}." ) if axis < 0: - axis = axis + num_dims + axis += num_dims return axis diff --git a/keras/src/ops/operation_utils.py b/keras/src/ops/operation_utils.py index b1ac2621de0a..fb776db8c5d0 100644 --- a/keras/src/ops/operation_utils.py +++ b/keras/src/ops/operation_utils.py @@ -370,22 +370,25 @@ def compute_take_along_axis_output_shape(input_shape, indices_shape, axis): def reduce_shape(shape, axis=None, keepdims=False): shape = list(shape) + n = len(shape) if axis is None: if keepdims: - return tuple([1 for _ in shape]) + return (1,) * n else: - return tuple([]) + return () elif isinstance(axis, int): axis = (axis,) - axis = tuple(canonicalize_axis(a, len(shape)) for a in axis) + axis_tuple = tuple(canonicalize_axis(a, n) for a in axis) + if keepdims: - for ax in axis: + for ax in axis_tuple: shape[ax] = 1 return tuple(shape) else: - for ax in sorted(axis, reverse=True): + # Pre-sort axis indices once for deletion in reverse order + for ax in sorted(axis_tuple, reverse=True): del shape[ax] return tuple(shape)