Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 42 additions & 25 deletions keras/src/ops/operation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,49 +174,60 @@ def compute_conv_output_shape(
dilation_rate=1,
):
"""Compute the output shape of conv ops."""
# Compute the input rank once; it is compared against the kernel rank below.
ndim = len(input_shape)
if data_format == "channels_last":
spatial_shape = input_shape[1:-1]
kernel_shape = kernel_size + (input_shape[-1], filters)
else:
spatial_shape = input_shape[2:]
kernel_shape = kernel_size + (input_shape[1], filters)
if len(kernel_shape) != len(input_shape):

if len(kernel_shape) != ndim:
raise ValueError(
"Kernel shape must have the same length as input, but received "
f"kernel of shape {kernel_shape} and "
f"input of shape {input_shape}."
)

# Convert strides/dilation_rate only if not already tuple
spatial_ndim = len(spatial_shape)
if isinstance(dilation_rate, int):
dilation_rate = (dilation_rate,) * len(spatial_shape)
dilation_rate = (dilation_rate,) * spatial_ndim
if isinstance(strides, int):
strides = (strides,) * len(spatial_shape)
if len(dilation_rate) != len(spatial_shape):
strides = (strides,) * spatial_ndim
if len(dilation_rate) != spatial_ndim:
raise ValueError(
"Dilation must be None, scalar or tuple/list of length of "
"inputs' spatial shape, but received "
f"`dilation_rate={dilation_rate}` and "
f"input of shape {input_shape}."
)
none_dims = []
spatial_shape = np.array(spatial_shape)
for i in range(len(spatial_shape)):
if spatial_shape[i] is None:
# Set `None` shape to a manual value so that we can run numpy
# computation on `spatial_shape`.
spatial_shape[i] = -1
tmp_spatial_shape = list(spatial_shape) # Use list for mutability and performance on small n

for i in range(spatial_ndim):
# A None entry cannot take part in the numeric computation, so flag it here.
if tmp_spatial_shape[i] is None:
tmp_spatial_shape[i] = -1 # Use -1 to retain behavior
none_dims.append(i)

kernel_spatial_shape = np.array(kernel_shape[:-2])
dilation_rate = np.array(dilation_rate)
# Convert what is needed to ndarray for vectorized math (still fast for small d)
spatial_arr = np.fromiter(tmp_spatial_shape, dtype=int, count=spatial_ndim)
kernel_spatial_shape = np.fromiter(kernel_shape[:-2], dtype=int, count=spatial_ndim)
dilation_rate_arr = np.fromiter(dilation_rate, dtype=int, count=spatial_ndim)
strides_arr = np.fromiter(strides, dtype=int, count=spatial_ndim)

if padding == "valid":
output_spatial_shape = (
# ((S - Kd - 1) // stride) + 1, with Kd = dilation * (k-1)
output_spatial_shape_arr = (
np.floor(
(spatial_shape - dilation_rate * (kernel_spatial_shape - 1) - 1)
/ strides
)
+ 1
(spatial_arr - dilation_rate_arr * (kernel_spatial_shape - 1) - 1)
/ strides_arr
) + 1
)
for i in range(len(output_spatial_shape)):
output_spatial_shape = output_spatial_shape_arr.tolist()
for i in range(spatial_ndim):
if i not in none_dims and output_spatial_shape[i] < 0:
raise ValueError(
"Computed output size would be negative. Received "
Expand All @@ -225,22 +236,28 @@ def compute_conv_output_shape(
f"`dilation_rate={dilation_rate}`."
)
elif padding == "same" or padding == "causal":
output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
output_spatial_shape = (
np.floor((spatial_arr - 1) / strides_arr) + 1
).tolist()
else:
raise ValueError(
"`padding` must be either `'valid'` or `'same'`. Received "
f"{padding}."
)
output_spatial_shape = [int(i) for i in output_spatial_shape]
for i in none_dims:
output_spatial_shape[i] = None
output_spatial_shape = tuple(output_spatial_shape)

# Convert float to int eagerly
output_spatial_shape_int = [int(x) for x in output_spatial_shape]
for idx in none_dims:
output_spatial_shape_int[idx] = None

output_spatial_shape_tuple = tuple(output_spatial_shape_int)
if data_format == "channels_last":
output_shape = (
(input_shape[0],) + output_spatial_shape + (kernel_shape[-1],)
(input_shape[0],) + output_spatial_shape_tuple + (kernel_shape[-1],)
)
else:
output_shape = (input_shape[0], kernel_shape[-1]) + output_spatial_shape
output_shape = (input_shape[0], kernel_shape[-1]) + output_spatial_shape_tuple

return output_shape


Expand Down