Skip to content

Conversation

@shashaka
Copy link
Collaborator

Adds keras.ops.ldexp, which multiplies a floating-point tensor by 2 raised to the power of an integer tensor, element-wise.
Supported across NumPy, TensorFlow, PyTorch, and JAX backends.
Not supported on OpenVINO.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @shashaka, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the ldexp function into keras.ops, enabling users to efficiently perform element-wise multiplication of a float tensor by a power of two. The implementation ensures consistent behavior across major Keras backends like NumPy, TensorFlow, PyTorch, and JAX, while clearly defining its unavailability for the OpenVINO backend. This addition enhances Keras's numerical operations with a specialized function for floating-point manipulation.

Highlights

  • New Functionality: Introduced keras.ops.ldexp, a new function that multiplies a floating-point tensor by 2 raised to the power of an integer tensor, element-wise.
  • Multi-Backend Support: The ldexp function is implemented and supported across NumPy, TensorFlow, PyTorch, and JAX backends, ensuring broad compatibility within the Keras ecosystem.
  • OpenVINO Exclusion: Explicitly marked ldexp as not supported for the OpenVINO backend, with corresponding test exclusions and NotImplementedError handling.
  • Input Validation: Added type checking to ensure that the exponent argument (x2) for ldexp is always an integer type, preventing potential runtime errors.
  • Comprehensive Testing: New test cases have been added to keras/src/ops/numpy_test.py to cover symbolic shape inference, concrete shape inference, correctness, and dtype inference for the ldexp function.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements keras.ops.ldexp for NumPy, TensorFlow, PyTorch, and JAX backends. The overall structure and API definition are good. However, I've found some issues in the backend implementations related to dtype handling and potential precision loss which could lead to incorrect results. I've provided suggestions to improve the correctness and robustness of the implementations for the NumPy, TensorFlow, and PyTorch backends. The tests are well-structured and cover various cases.

@codecov-commenter
Copy link

codecov-commenter commented Nov 20, 2025

Codecov Report

❌ Patch coverage is 80.39216% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.57%. Comparing base (f2c00fe) to head (e2d5218).

Files with missing lines Patch % Lines
keras/src/backend/jax/numpy.py 66.66% 1 Missing and 1 partial ⚠️
keras/src/backend/numpy/numpy.py 71.42% 1 Missing and 1 partial ⚠️
keras/src/backend/tensorflow/numpy.py 80.00% 1 Missing and 1 partial ⚠️
keras/src/backend/torch/numpy.py 71.42% 1 Missing and 1 partial ⚠️
keras/api/_tf_keras/keras/ops/__init__.py 0.00% 1 Missing ⚠️
keras/api/_tf_keras/keras/ops/numpy/__init__.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21863      +/-   ##
==========================================
- Coverage   82.57%   82.57%   -0.01%     
==========================================
  Files         577      577              
  Lines       59599    59650      +51     
  Branches     9351     9356       +5     
==========================================
+ Hits        49213    49254      +41     
- Misses       7978     7984       +6     
- Partials     2408     2412       +4     
Flag Coverage Δ
keras 82.38% <80.39%> (-0.01%) ⬇️
keras-jax 62.86% <49.01%> (-0.02%) ⬇️
keras-numpy 57.51% <50.98%> (-0.01%) ⬇️
keras-openvino 34.32% <19.60%> (-0.02%) ⬇️
keras-tensorflow 64.40% <56.86%> (-0.01%) ⬇️
keras-torch 63.56% <50.98%> (-0.02%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the addition!

Comment on lines +777 to +785
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)

if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
raise TypeError(
f"ldexp exponent must be an integer type. "
f"Received: x2 dtype={x2.dtype}"
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed on NumPy? Isn't the type promotion already consistent with jax.numpy?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may not have fully understood the review at first,
but if the suggestion was to simply return np.ldexp(x1, x2),
I tried that approach and it caused dtype mismatches with JAX in DtypeTest. So I think the explicit dtype handling is still required.

x1 = tf.cast(x1, tf.float32 if not x1.dtype.is_floating else x1.dtype)
x2 = tf.cast(x2, x1.dtype)
result = x1 * tf.pow(tf.constant(2.0, dtype=x1.dtype), x2)
return tf.cast(tf.where(tf.math.is_inf(x1) | (x1 == 0), x1, result), dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the tf.where needed here? What happens otherwise?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I referred to the JAX implementation for this and tested the behavior. The where ensures correct results in edge cases. Please see the examples below.

import tensorflow as tf

x1 = tf.constant([0], dtype=tf.float16)
x2 = tf.constant([999], dtype=tf.int32)

result_normal = x1 * (tf.constant(2.0, dtype=tf.float16) ** tf.cast(x2, tf.float16))
result_where = tf.where(tf.math.is_inf(x1) | (x1 == 0), x1, result_normal)

tf.print("Without _where:", result_normal)
tf.print("With _where:", result_where)

Without _where: [-nan]
With _where: [0]
import tensorflow as tf

x1 = tf.constant([np.inf], dtype=tf.float16)
x2 = tf.constant([-99999], dtype=tf.int32)

result_normal = x1 * (tf.constant(2.0, dtype=tf.float16) ** tf.cast(x2, tf.float16))
result_where = tf.where(tf.math.is_inf(x1) | (x1 == 0), x1, result_normal)

tf.print("Without _where:", result_normal)
tf.print("With _where:", result_where)

Without _where: [-nan]
With _where: [inf]

@shashaka
Copy link
Collaborator Author

/gemini review

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the keras.ops.ldexp function, providing implementations for NumPy, TensorFlow, PyTorch, and JAX backends. The changes are well-structured, with corresponding tests for shape inference, correctness, and data types. The implementations for JAX, NumPy, and PyTorch correctly use their respective native ldexp functions. My main feedback is to refactor the TensorFlow backend implementation to use the native tf.math.ldexp function, which would simplify the code and improve its robustness.

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Dec 1, 2025
@hertschuh hertschuh merged commit 74fba84 into keras-team:master Dec 1, 2025
11 of 12 checks passed
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase labels Dec 1, 2025
@shashaka shashaka deleted the ldexp branch December 2, 2025 04:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants