Skip to content

Commit f3a8540

Browse files
authored
Merge branch 'keras-team:master' into master
2 parents d38ddca + 74fba84 commit f3a8540

File tree

31 files changed

+844
-230
lines changed

31 files changed

+844
-230
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@
215215
from keras.src.ops.numpy import kaiser as kaiser
216216
from keras.src.ops.numpy import kron as kron
217217
from keras.src.ops.numpy import lcm as lcm
218+
from keras.src.ops.numpy import ldexp as ldexp
218219
from keras.src.ops.numpy import left_shift as left_shift
219220
from keras.src.ops.numpy import less as less
220221
from keras.src.ops.numpy import less_equal as less_equal

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
from keras.src.ops.numpy import kaiser as kaiser
102102
from keras.src.ops.numpy import kron as kron
103103
from keras.src.ops.numpy import lcm as lcm
104+
from keras.src.ops.numpy import ldexp as ldexp
104105
from keras.src.ops.numpy import left_shift as left_shift
105106
from keras.src.ops.numpy import less as less
106107
from keras.src.ops.numpy import less_equal as less_equal

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@
215215
from keras.src.ops.numpy import kaiser as kaiser
216216
from keras.src.ops.numpy import kron as kron
217217
from keras.src.ops.numpy import lcm as lcm
218+
from keras.src.ops.numpy import ldexp as ldexp
218219
from keras.src.ops.numpy import left_shift as left_shift
219220
from keras.src.ops.numpy import less as less
220221
from keras.src.ops.numpy import less_equal as less_equal

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
from keras.src.ops.numpy import kaiser as kaiser
102102
from keras.src.ops.numpy import kron as kron
103103
from keras.src.ops.numpy import lcm as lcm
104+
from keras.src.ops.numpy import ldexp as ldexp
104105
from keras.src.ops.numpy import left_shift as left_shift
105106
from keras.src.ops.numpy import less as less
106107
from keras.src.ops.numpy import less_equal as less_equal

keras/src/backend/common/backend_utils.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,13 @@ def _convert_conv_transpose_padding_args_from_keras_to_torch(
9696
)
9797

9898
if torch_output_padding >= stride:
99-
raise ValueError(
100-
f"The padding arguments (padding={padding}) and "
101-
f"output_padding={output_padding}) lead to a Torch "
102-
f"output_padding ({torch_output_padding}) that is greater than "
103-
f"strides ({stride}). This is not supported. You can change the "
104-
f"padding arguments, kernel or stride, or run on another backend. "
99+
warnings.warn(
100+
f"Torch backend requires output_padding < stride. "
101+
f"Clamping output_padding {torch_output_padding} -> {stride - 1} "
102+
f"for stride {stride}.",
103+
UserWarning,
105104
)
105+
torch_output_padding = stride - 1
106106

107107
return torch_padding, torch_output_padding
108108

@@ -184,6 +184,22 @@ def compute_conv_transpose_padding_args_for_torch(
184184
torch_paddings.append(torch_padding)
185185
torch_output_paddings.append(torch_output_padding)
186186

187+
# --- FIX FOR TORCH CONSTRAINT: output_padding < stride ---
188+
corrected_output_paddings = []
189+
for s, op in zip(
190+
strides
191+
if isinstance(strides, (list, tuple))
192+
else [strides] * num_spatial_dims,
193+
torch_output_paddings,
194+
):
195+
max_allowed = max(0, s - 1)
196+
if op > max_allowed:
197+
corrected_output_paddings.append(max_allowed)
198+
else:
199+
corrected_output_paddings.append(op)
200+
201+
torch_output_paddings = corrected_output_paddings
202+
187203
return torch_paddings, torch_output_paddings
188204

189205

keras/src/backend/common/backend_utils_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,25 @@ def test_valid_padding_with_output_padding(self):
170170
self.assertEqual(torch_padding, 0)
171171
self.assertEqual(torch_output_padding, 1)
172172

173+
def test_output_padding_clamped_for_torch_constraint(self):
174+
"""Test that output_padding is clamped
175+
when >= stride (Torch constraint).
176+
"""
177+
(
178+
torch_paddings,
179+
torch_output_paddings,
180+
) = compute_conv_transpose_padding_args_for_torch(
181+
input_shape=(1, 8, 8, 8, 16), # any shape
182+
kernel_shape=(2, 2, 2, 16, 32), # Keras kernel shape
183+
strides=1,
184+
padding="same",
185+
output_padding=1, # Keras wants this
186+
dilation_rate=1,
187+
)
188+
# Torch expects output_padding < stride (1)
189+
# so output_padding should be clamped to 0
190+
self.assertEqual(torch_output_paddings, [0, 0, 0])
191+
173192

174193
class GetOutputShapeGivenTFPaddingTest(test_case.TestCase):
175194
def test_valid_padding_without_output_padding(self):

keras/src/backend/common/variables.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def value(self):
276276
return self._maybe_autocast(self._value)
277277

278278
def assign(self, value):
279-
value = self._convert_to_tensor(value, dtype=self.dtype)
279+
value = self._convert_to_tensor(value, dtype=self._dtype)
280280
if not shape_equal(value.shape, self.shape):
281281
raise ValueError(
282282
"The shape of the target variable and "
@@ -599,7 +599,6 @@ def standardize_shape(shape):
599599
# `tf.TensorShape` may contain `Dimension` objects.
600600
# We need to convert the items in it to either int or `None`
601601
shape = shape.as_list()
602-
shape = tuple(shape)
603602

604603
if config.backend() == "jax":
605604
# Replace `_DimExpr` (dimension expression) with None
@@ -609,25 +608,37 @@ def standardize_shape(shape):
609608
None if jax_export.is_symbolic_dim(d) else d for d in shape
610609
)
611610

612-
if config.backend() == "torch":
613-
# `shape` might be `torch.Size`. We need to convert the items in it to
614-
# either int or `None`
615-
shape = tuple(map(lambda x: int(x) if x is not None else None, shape))
616-
617-
for e in shape:
618-
if e is None:
611+
# Handle dimensions that are not ints and not None, verify they're >= 0.
612+
standardized_shape = []
613+
for d in shape:
614+
if d is None:
615+
standardized_shape.append(d)
619616
continue
620-
if not is_int_dtype(type(e)):
617+
618+
# Reject these even if they can be cast to int successfully.
619+
if isinstance(d, (str, float)):
621620
raise ValueError(
622621
f"Cannot convert '{shape}' to a shape. "
623-
f"Found invalid entry '{e}' of type '{type(e)}'. "
622+
f"Found invalid dimension '{d}' of type '{type(d)}'. "
624623
)
625-
if e < 0:
624+
625+
try:
626+
# Cast numpy scalars, tf constant tensors, etc.
627+
d = int(d)
628+
except Exception as e:
629+
raise ValueError(
630+
f"Cannot convert '{shape}' to a shape. "
631+
f"Found invalid dimension '{d}' of type '{type(d)}'. "
632+
) from e
633+
if d < 0:
626634
raise ValueError(
627635
f"Cannot convert '{shape}' to a shape. "
628636
"Negative dimensions are not allowed."
629637
)
630-
return shape
638+
standardized_shape.append(d)
639+
640+
# This also turns subclasses of `tuple` (e.g. `torch.Size`) to plain tuple.
641+
return tuple(standardized_shape)
631642

632643

633644
def shape_equal(a_shape, b_shape):

0 commit comments

Comments (0)