Skip to content

Commit 14144cb

Browse files
Ensure keras.ops.eye behavior is consistent across backends. (#21738)
* ensure eye behavior is consistent across backends * Update keras/src/ops/numpy_test.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * simplify per pr review * pre-commit * fix test for torch backend + add comments * update implementation to raise TypeError for consistency * add case for M being the onl float * improve naming of inner function for type check --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent c3f2e93 commit 14144cb

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

keras/src/ops/numpy.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7228,6 +7228,19 @@ def eye(N, M=None, k=0, dtype=None):
72287228
Returns:
72297229
Tensor with ones on the k-th diagonal and zeros elsewhere.
72307230
"""
7231+
7232+
def is_floating_type(v):
7233+
return (
7234+
isinstance(v, float)
7235+
or getattr(v, "dtype", None) in dtypes.FLOAT_TYPES
7236+
)
7237+
7238+
if is_floating_type(N):
7239+
raise TypeError("Argument `N` must be an integer or an integer tensor.")
7240+
if is_floating_type(M):
7241+
raise TypeError(
7242+
"Argument `M` must be an integer, an integer tensor, or `None`."
7243+
)
72317244
return backend.numpy.eye(N, M=M, k=k, dtype=dtype)
72327245

72337246

keras/src/ops/numpy_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5230,6 +5230,19 @@ def test_eye(self):
52305230
# Test k < 0 and M < N and M - k > N
52315231
self.assertAllClose(knp.eye(4, 3, k=-2), np.eye(4, 3, k=-2))
52325232

5233+
def test_eye_raises_error_with_floats(self):
5234+
with self.assertRaises(TypeError):
5235+
knp.eye(3.0)
5236+
with self.assertRaises(TypeError):
5237+
knp.eye(3.0, 2.0)
5238+
with self.assertRaises(TypeError):
5239+
knp.eye(3, 2.0)
5240+
with self.assertRaises(TypeError):
5241+
v = knp.max(knp.arange(4.0))
5242+
knp.eye(v)
5243+
with self.assertRaises(TypeError):
5244+
knp.eye(knp.array(3, dtype="bfloat16"))
5245+
52335246
def test_arange(self):
52345247
self.assertAllClose(knp.arange(3), np.arange(3))
52355248
self.assertAllClose(knp.arange(3, 7), np.arange(3, 7))

0 commit comments

Comments
 (0)