Skip to content

Commit 9fc2e96

Browse files
lukebaumanncopybara-github
authored andcommitted
Avoid using relatively new jax.extend package until all modules that we use are part of a stable JAX release.
PiperOrigin-RevId: 676890866
1 parent 0b5c336 commit 9fc2e96

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

pathwaysutils/plugin_executable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import jax
2121
from jax._src.interpreters import pxla
22-
from jax.extend.ifrt_programs import ifrt_programs
22+
from jax._src.lib.xla_extension import ifrt_programs
2323

2424

2525
class PluginExecutable:

pathwaysutils/proxy_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
"""Register the IFRT Proxy as a backend for JAX."""
1515

1616
import jax
17-
from jax.extend import backend
17+
from jax._src import xla_bridge
1818
from jax.lib.xla_extension import ifrt_proxy
1919

2020

2121
def register_backend_factory():
22-
backend.register_backend_factory(
22+
xla_bridge.register_backend_factory(
2323
"proxy",
2424
lambda: ifrt_proxy.get_client(
2525
jax.config.read("jax_backend_target"),

0 commit comments

Comments
 (0)