diff --git a/keras/src/ops/operation_utils.py b/keras/src/ops/operation_utils.py index b1ac2621de0a..a078c318943a 100644 --- a/keras/src/ops/operation_utils.py +++ b/keras/src/ops/operation_utils.py @@ -174,23 +174,29 @@ def compute_conv_output_shape( dilation_rate=1, ): """Compute the output shape of conv ops.""" + # Avoid repeated attribute access and enable faster tuple ops + ndim = len(input_shape) if data_format == "channels_last": spatial_shape = input_shape[1:-1] kernel_shape = kernel_size + (input_shape[-1], filters) else: spatial_shape = input_shape[2:] kernel_shape = kernel_size + (input_shape[1], filters) - if len(kernel_shape) != len(input_shape): + + if len(kernel_shape) != ndim: raise ValueError( "Kernel shape must have the same length as input, but received " f"kernel of shape {kernel_shape} and " f"input of shape {input_shape}." ) + + # Convert strides/dilation_rate only if not already tuple + spatial_ndim = len(spatial_shape) if isinstance(dilation_rate, int): - dilation_rate = (dilation_rate,) * len(spatial_shape) + dilation_rate = (dilation_rate,) * spatial_ndim if isinstance(strides, int): - strides = (strides,) * len(spatial_shape) - if len(dilation_rate) != len(spatial_shape): + strides = (strides,) * spatial_ndim + if len(dilation_rate) != spatial_ndim: raise ValueError( "Dilation must be None, scalar or tuple/list of length of " "inputs' spatial shape, but received " @@ -198,25 +204,30 @@ def compute_conv_output_shape( f"input of shape {input_shape}." ) none_dims = [] - spatial_shape = np.array(spatial_shape) - for i in range(len(spatial_shape)): - if spatial_shape[i] is None: - # Set `None` shape to a manual value so that we can run numpy - # computation on `spatial_shape`. - spatial_shape[i] = -1 + tmp_spatial_shape = list(spatial_shape) # Use list for mutability and performance on small n + + for i in range(spatial_ndim): + # This "is None" is as fast as possible + if tmp_spatial_shape[i] is None: + tmp_spatial_shape[i] = -1 # Use -1 to retain behavior none_dims.append(i) - kernel_spatial_shape = np.array(kernel_shape[:-2]) - dilation_rate = np.array(dilation_rate) + # Convert what is needed to ndarray for vectorized math (still fast for small d) + spatial_arr = np.fromiter(tmp_spatial_shape, dtype=int, count=spatial_ndim) + kernel_spatial_shape = np.fromiter(kernel_shape[:-2], dtype=int, count=spatial_ndim) + dilation_rate_arr = np.fromiter(dilation_rate, dtype=int, count=spatial_ndim) + strides_arr = np.fromiter(strides, dtype=int, count=spatial_ndim) + if padding == "valid": - output_spatial_shape = ( + # ((S - Kd - 1) // stride) + 1, with Kd = dilation * (k-1) + output_spatial_shape_arr = ( np.floor( - (spatial_shape - dilation_rate * (kernel_spatial_shape - 1) - 1) - / strides - ) - + 1 + (spatial_arr - dilation_rate_arr * (kernel_spatial_shape - 1) - 1) + / strides_arr + ) + 1 ) - for i in range(len(output_spatial_shape)): + output_spatial_shape = output_spatial_shape_arr.tolist() + for i in range(spatial_ndim): if i not in none_dims and output_spatial_shape[i] < 0: raise ValueError( "Computed output size would be negative. Received " @@ -225,22 +236,28 @@ def compute_conv_output_shape( f"`dilation_rate={dilation_rate}`." ) elif padding == "same" or padding == "causal": - output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1 + output_spatial_shape = ( + np.floor((spatial_arr - 1) / strides_arr) + 1 + ).tolist() else: raise ValueError( "`padding` must be either `'valid'` or `'same'`. Received " f"{padding}." ) - output_spatial_shape = [int(i) for i in output_spatial_shape] - for i in none_dims: - output_spatial_shape[i] = None - output_spatial_shape = tuple(output_spatial_shape) + + # Convert float to int eagerly + output_spatial_shape_int = [int(x) for x in output_spatial_shape] + for idx in none_dims: + output_spatial_shape_int[idx] = None + + output_spatial_shape_tuple = tuple(output_spatial_shape_int) if data_format == "channels_last": output_shape = ( - (input_shape[0],) + output_spatial_shape + (kernel_shape[-1],) + (input_shape[0],) + output_spatial_shape_tuple + (kernel_shape[-1],) ) else: - output_shape = (input_shape[0], kernel_shape[-1]) + output_spatial_shape + output_shape = (input_shape[0], kernel_shape[-1]) + output_spatial_shape_tuple + return output_shape