diff --git a/keras/src/ops/operation_utils.py b/keras/src/ops/operation_utils.py index b1ac2621de0a..dcd326e625b0 100644 --- a/keras/src/ops/operation_utils.py +++ b/keras/src/ops/operation_utils.py @@ -116,7 +116,6 @@ def compute_pooling_output_shape( """ strides = pool_size if strides is None else strides input_shape_origin = list(input_shape) - input_shape = np.array(input_shape) if data_format == "channels_last": spatial_shape = input_shape[1:-1] else: @@ -129,6 +128,8 @@ def compute_pooling_output_shape( spatial_shape[i] = -1 none_dims.append(i) pool_size = np.array(pool_size) + spatial_shape = np.array(spatial_shape) + strides = np.array(strides) if padding == "valid": output_spatial_shape = ( np.floor((spatial_shape - pool_size) / strides) + 1