Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@
from keras.src.ops.numpy import kaiser as kaiser
from keras.src.ops.numpy import kron as kron
from keras.src.ops.numpy import lcm as lcm
from keras.src.ops.numpy import ldexp as ldexp
from keras.src.ops.numpy import left_shift as left_shift
from keras.src.ops.numpy import less as less
from keras.src.ops.numpy import less_equal as less_equal
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
from keras.src.ops.numpy import kaiser as kaiser
from keras.src.ops.numpy import kron as kron
from keras.src.ops.numpy import lcm as lcm
from keras.src.ops.numpy import ldexp as ldexp
from keras.src.ops.numpy import left_shift as left_shift
from keras.src.ops.numpy import less as less
from keras.src.ops.numpy import less_equal as less_equal
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@
from keras.src.ops.numpy import kaiser as kaiser
from keras.src.ops.numpy import kron as kron
from keras.src.ops.numpy import lcm as lcm
from keras.src.ops.numpy import ldexp as ldexp
from keras.src.ops.numpy import left_shift as left_shift
from keras.src.ops.numpy import less as less
from keras.src.ops.numpy import less_equal as less_equal
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
from keras.src.ops.numpy import kaiser as kaiser
from keras.src.ops.numpy import kron as kron
from keras.src.ops.numpy import lcm as lcm
from keras.src.ops.numpy import ldexp as ldexp
from keras.src.ops.numpy import left_shift as left_shift
from keras.src.ops.numpy import less as less
from keras.src.ops.numpy import less_equal as less_equal
Expand Down
13 changes: 13 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,19 @@ def lcm(x1, x2):
return jnp.lcm(x1, x2)


def ldexp(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)

if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
raise TypeError(
f"ldexp exponent must be an integer type. "
f"Received: x2 dtype={x2.dtype}"
)

return jnp.ldexp(x1, x2)


def less(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
Expand Down
13 changes: 13 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,19 @@ def lcm(x1, x2):
return np.lcm(x1, x2).astype(dtype)


def ldexp(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)

if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
raise TypeError(
f"ldexp exponent must be an integer type. "
f"Received: x2 dtype={x2.dtype}"
)
Comment on lines +777 to +785
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed on NumPy? Isn't the type promotion already consistent with jax.numpy?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may not have fully understood the review at first,
but if the suggestion was to simply return np.ldexp(x1, x2),
I tried that approach and it caused dtype mismatches with JAX in DtypeTest. So I think the explicit dtype handling is still required.

return np.ldexp(x1, x2).astype(dtype)


def less(x1, x2):
return np.less(x1, x2)

Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ NumpyDtypeTest::test_isin
NumpyDtypeTest::test_isreal
NumpyDtypeTest::test_kron
NumpyDtypeTest::test_lcm
NumpyDtypeTest::test_ldexp
NumpyDtypeTest::test_logaddexp2
NumpyDtypeTest::test_matmul_
NumpyDtypeTest::test_maximum_python_types
Expand Down Expand Up @@ -108,6 +109,7 @@ NumpyTwoInputOpsCorrectnessTest::test_inner
NumpyTwoInputOpsCorrectnessTest::test_isin
NumpyTwoInputOpsCorrectnessTest::test_kron
NumpyTwoInputOpsCorrectnessTest::test_lcm
NumpyTwoInputOpsCorrectnessTest::test_ldexp
NumpyTwoInputOpsCorrectnessTest::test_quantile
NumpyTwoInputOpsCorrectnessTest::test_tensordot
NumpyTwoInputOpsCorrectnessTest::test_vdot
Expand All @@ -131,11 +133,13 @@ NumpyTwoInputOpsDynamicShapeTest::test_hypot
NumpyTwoInputOpsDynamicShapeTest::test_isin
NumpyTwoInputOpsDynamicShapeTest::test_kron
NumpyTwoInputOpsDynamicShapeTest::test_lcm
NumpyTwoInputOpsDynamicShapeTest::test_ldexp
NumpyTwoInputOpsStaticShapeTest::test_gcd
NumpyTwoInputOpsStaticShapeTest::test_hypot
NumpyTwoInputOpsStaticShapeTest::test_isin
NumpyTwoInputOpsStaticShapeTest::test_kron
NumpyTwoInputOpsStaticShapeTest::test_lcm
NumpyTwoInputOpsStaticShapeTest::test_ldexp
CoreOpsBehaviorTests::test_associative_scan_invalid_arguments
CoreOpsBehaviorTests::test_scan_invalid_arguments
CoreOpsCallsTests::test_associative_scan_basic_call
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,10 @@ def lcm(x1, x2):
raise NotImplementedError("`lcm` is not supported with openvino backend")


def ldexp(x1, x2):
raise NotImplementedError("`ldexp` is not supported with openvino backend")


def less(x1, x2):
element_type = None
if isinstance(x1, OpenVINOKerasTensor):
Expand Down
17 changes: 17 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1844,6 +1844,23 @@ def lcm(x1, x2):
return result


def ldexp(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)

if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
raise TypeError(
f"ldexp exponent must be an integer type. "
f"Received: x2 dtype={x2.dtype}"
)

x1 = tf.cast(x1, dtype)
x2 = tf.cast(x2, x1.dtype)
result = x1 * tf.pow(tf.constant(2.0, dtype=x1.dtype), x2)
return tf.cast(tf.where(tf.math.is_inf(x1) | (x1 == 0), x1, result), dtype)


def less(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
Expand Down
14 changes: 14 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,20 @@ def lcm(x1, x2):
return torch.lcm(x1, x2)


def ldexp(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)

if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
raise TypeError(
f"ldexp exponent must be an integer type. "
f"Received: x2 dtype={x2.dtype}"
)

return cast(torch.ldexp(x1, x2), dtype)


def less(x1, x2):
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
return torch.less(x1, x2)
Expand Down
40 changes: 40 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4059,6 +4059,46 @@ def lcm(x1, x2):
return backend.numpy.lcm(x1, x2)


class Ldexp(Operation):
def call(self, x1, x2):
return backend.numpy.ldexp(x1, x2)

def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)

x1_type = backend.standardize_dtype(getattr(x1, "dtype", type(x1)))
x2_type = backend.standardize_dtype(getattr(x2, "dtype", type(x2)))
dtype = dtypes.result_type(x1_type, x2_type, float)
return KerasTensor(output_shape, dtype=dtype)


@keras_export(["keras.ops.ldexp", "keras.ops.numpy.ldexp"])
def ldexp(x1, x2):
"""Multiply `x1` by 2 raised to the power of `x2`, element-wise.

This function computes:
ldexp(x1, x2) = x1 * 2**x2

Args:
x1: Float input tensor.
x2: Integer exponent tensor.

Returns:
Output tensor

Example:
>>> x1 = keras.ops.convert_to_tensor([0.75, 1.5])
>>> x2 = keras.ops.convert_to_tensor([1, 2])
>>> keras.ops.ldexp(x1, x2)
array([1.5, 6. ], dtype=float32)
"""
if any_symbolic_tensors((x1, x2)):
return Ldexp().symbolic_call(x1, x2)
return backend.numpy.ldexp(x1, x2)


class Less(Operation):
def call(self, x1, x2):
return backend.numpy.less(x1, x2)
Expand Down
41 changes: 41 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,11 @@ def test_lcm(self):
y = KerasTensor((2, None))
self.assertEqual(knp.lcm(x, y).shape, (2, 3))

def test_ldexp(self):
x = KerasTensor((None, 3))
y = KerasTensor((1, 3))
self.assertEqual(knp.ldexp(x, y).shape, (None, 3))

def test_less(self):
x = KerasTensor((None, 3))
y = KerasTensor((2, None))
Expand Down Expand Up @@ -837,6 +842,15 @@ def test_lcm(self):
y = KerasTensor((2, 3))
self.assertEqual(knp.lcm(x, y).shape, (2, 3))

def test_ldexp(self):
x = KerasTensor((2, 3))
y = KerasTensor((2, 3))
self.assertEqual(knp.ldexp(x, y).shape, (2, 3))

x = KerasTensor((2, 3))
y = KerasTensor((1, 3))
self.assertEqual(knp.ldexp(x, y).shape, (2, 3))

def test_less(self):
x = KerasTensor((2, 3))
y = KerasTensor((2, 3))
Expand Down Expand Up @@ -3114,6 +3128,12 @@ def test_lcm(self):
self.assertAllClose(knp.lcm(x, y), np.lcm(x, y))
self.assertAllClose(knp.Lcm()(x, y), np.lcm(x, y))

def test_ldexp(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
y = np.array([[4, 5, 6], [3, 2, 1]])
self.assertAllClose(knp.ldexp(x, y), np.ldexp(x, y))
self.assertAllClose(knp.Ldexp()(x, y), np.ldexp(x, y))

def test_less(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
y = np.array([[4, 5, 6], [3, 2, 1]])
Expand Down Expand Up @@ -7884,6 +7904,27 @@ def test_lcm(self, dtypes):
expected_dtype,
)

@parameterized.named_parameters(
named_product(dtypes=list(itertools.product(ALL_DTYPES, INT_DTYPES)))
)
def test_ldexp(self, dtypes):
import jax.numpy as jnp

dtype1, dtype2 = dtypes
x1 = knp.ones((), dtype=dtype1)
x2 = knp.ones((), dtype=dtype2)
x1_jax = jnp.ones((), dtype=dtype1)
x2_jax = jnp.ones((), dtype=dtype2)
expected_dtype = standardize_dtype(jnp.ldexp(x1_jax, x2_jax).dtype)

self.assertEqual(
standardize_dtype(knp.ldexp(x1, x2).dtype), expected_dtype
)
self.assertEqual(
standardize_dtype(knp.Ldexp().symbolic_call(x1, x2).dtype),
expected_dtype,
)

@parameterized.named_parameters(
named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
)
Expand Down