Skip to content

Commit 74fba84

Browse files
authored
Implement ldexp function in keras.ops (#21863)
* Add ldexp initial version * Add numpy_test for ldexp * Update code by gemini reveiw * Update code by review * Add cast for tensorflow return value * merge master
1 parent f2c00fe commit 74fba84

File tree

12 files changed

+150
-0
lines changed

12 files changed

+150
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@
215215
from keras.src.ops.numpy import kaiser as kaiser
216216
from keras.src.ops.numpy import kron as kron
217217
from keras.src.ops.numpy import lcm as lcm
218+
from keras.src.ops.numpy import ldexp as ldexp
218219
from keras.src.ops.numpy import left_shift as left_shift
219220
from keras.src.ops.numpy import less as less
220221
from keras.src.ops.numpy import less_equal as less_equal

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
from keras.src.ops.numpy import kaiser as kaiser
102102
from keras.src.ops.numpy import kron as kron
103103
from keras.src.ops.numpy import lcm as lcm
104+
from keras.src.ops.numpy import ldexp as ldexp
104105
from keras.src.ops.numpy import left_shift as left_shift
105106
from keras.src.ops.numpy import less as less
106107
from keras.src.ops.numpy import less_equal as less_equal

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@
215215
from keras.src.ops.numpy import kaiser as kaiser
216216
from keras.src.ops.numpy import kron as kron
217217
from keras.src.ops.numpy import lcm as lcm
218+
from keras.src.ops.numpy import ldexp as ldexp
218219
from keras.src.ops.numpy import left_shift as left_shift
219220
from keras.src.ops.numpy import less as less
220221
from keras.src.ops.numpy import less_equal as less_equal

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
from keras.src.ops.numpy import kaiser as kaiser
102102
from keras.src.ops.numpy import kron as kron
103103
from keras.src.ops.numpy import lcm as lcm
104+
from keras.src.ops.numpy import ldexp as ldexp
104105
from keras.src.ops.numpy import left_shift as left_shift
105106
from keras.src.ops.numpy import less as less
106107
from keras.src.ops.numpy import less_equal as less_equal

keras/src/backend/jax/numpy.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,19 @@ def lcm(x1, x2):
845845
return jnp.lcm(x1, x2)
846846

847847

848+
def ldexp(x1, x2):
849+
x1 = convert_to_tensor(x1)
850+
x2 = convert_to_tensor(x2)
851+
852+
if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
853+
raise TypeError(
854+
f"ldexp exponent must be an integer type. "
855+
f"Received: x2 dtype={x2.dtype}"
856+
)
857+
858+
return jnp.ldexp(x1, x2)
859+
860+
848861
def less(x1, x2):
849862
x1 = convert_to_tensor(x1)
850863
x2 = convert_to_tensor(x2)

keras/src/backend/numpy/numpy.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,19 @@ def lcm(x1, x2):
773773
return np.lcm(x1, x2).astype(dtype)
774774

775775

776+
def ldexp(x1, x2):
777+
x1 = convert_to_tensor(x1)
778+
x2 = convert_to_tensor(x2)
779+
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
780+
781+
if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
782+
raise TypeError(
783+
f"ldexp exponent must be an integer type. "
784+
f"Received: x2 dtype={x2.dtype}"
785+
)
786+
return np.ldexp(x1, x2).astype(dtype)
787+
788+
776789
def less(x1, x2):
777790
return np.less(x1, x2)
778791

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ NumpyDtypeTest::test_isin
2929
NumpyDtypeTest::test_isreal
3030
NumpyDtypeTest::test_kron
3131
NumpyDtypeTest::test_lcm
32+
NumpyDtypeTest::test_ldexp
3233
NumpyDtypeTest::test_logaddexp2
3334
NumpyDtypeTest::test_matmul_
3435
NumpyDtypeTest::test_maximum_python_types
@@ -108,6 +109,7 @@ NumpyTwoInputOpsCorrectnessTest::test_inner
108109
NumpyTwoInputOpsCorrectnessTest::test_isin
109110
NumpyTwoInputOpsCorrectnessTest::test_kron
110111
NumpyTwoInputOpsCorrectnessTest::test_lcm
112+
NumpyTwoInputOpsCorrectnessTest::test_ldexp
111113
NumpyTwoInputOpsCorrectnessTest::test_quantile
112114
NumpyTwoInputOpsCorrectnessTest::test_tensordot
113115
NumpyTwoInputOpsCorrectnessTest::test_vdot
@@ -131,11 +133,13 @@ NumpyTwoInputOpsDynamicShapeTest::test_hypot
131133
NumpyTwoInputOpsDynamicShapeTest::test_isin
132134
NumpyTwoInputOpsDynamicShapeTest::test_kron
133135
NumpyTwoInputOpsDynamicShapeTest::test_lcm
136+
NumpyTwoInputOpsDynamicShapeTest::test_ldexp
134137
NumpyTwoInputOpsStaticShapeTest::test_gcd
135138
NumpyTwoInputOpsStaticShapeTest::test_hypot
136139
NumpyTwoInputOpsStaticShapeTest::test_isin
137140
NumpyTwoInputOpsStaticShapeTest::test_kron
138141
NumpyTwoInputOpsStaticShapeTest::test_lcm
142+
NumpyTwoInputOpsStaticShapeTest::test_ldexp
139143
CoreOpsBehaviorTests::test_associative_scan_invalid_arguments
140144
CoreOpsBehaviorTests::test_scan_invalid_arguments
141145
CoreOpsCallsTests::test_associative_scan_basic_call

keras/src/backend/openvino/numpy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,6 +1149,10 @@ def lcm(x1, x2):
11491149
raise NotImplementedError("`lcm` is not supported with openvino backend")
11501150

11511151

1152+
def ldexp(x1, x2):
1153+
raise NotImplementedError("`ldexp` is not supported with openvino backend")
1154+
1155+
11521156
def less(x1, x2):
11531157
element_type = None
11541158
if isinstance(x1, OpenVINOKerasTensor):

keras/src/backend/tensorflow/numpy.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1844,6 +1844,23 @@ def lcm(x1, x2):
18441844
return result
18451845

18461846

1847+
def ldexp(x1, x2):
1848+
x1 = convert_to_tensor(x1)
1849+
x2 = convert_to_tensor(x2)
1850+
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
1851+
1852+
if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
1853+
raise TypeError(
1854+
f"ldexp exponent must be an integer type. "
1855+
f"Received: x2 dtype={x2.dtype}"
1856+
)
1857+
1858+
x1 = tf.cast(x1, dtype)
1859+
x2 = tf.cast(x2, x1.dtype)
1860+
result = x1 * tf.pow(tf.constant(2.0, dtype=x1.dtype), x2)
1861+
return tf.cast(tf.where(tf.math.is_inf(x1) | (x1 == 0), x1, result), dtype)
1862+
1863+
18471864
def less(x1, x2):
18481865
x1 = convert_to_tensor(x1)
18491866
x2 = convert_to_tensor(x2)

keras/src/backend/torch/numpy.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -975,6 +975,20 @@ def lcm(x1, x2):
975975
return torch.lcm(x1, x2)
976976

977977

978+
def ldexp(x1, x2):
979+
x1 = convert_to_tensor(x1)
980+
x2 = convert_to_tensor(x2)
981+
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
982+
983+
if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
984+
raise TypeError(
985+
f"ldexp exponent must be an integer type. "
986+
f"Received: x2 dtype={x2.dtype}"
987+
)
988+
989+
return cast(torch.ldexp(x1, x2), dtype)
990+
991+
978992
def less(x1, x2):
979993
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
980994
return torch.less(x1, x2)

0 commit comments

Comments
 (0)