Changes from all commits (2162 commits)
795bbfc
Add fast-path for non-concrete Tracers in is_constant_dim to lower tr…
sharadmv Jun 6, 2025
c1bb095
Reverts 5c33588b30edbae51d5b63b0bd7cc8d9058d7ccb
Jun 6, 2025
64ba9bc
[Mosaic GPU] Extract the type-related logic out of `reinterpret_smem_…
dimitar-asenov Jun 6, 2025
8478fe6
Port PartitionSpec to C++.
hawkinsp Jun 6, 2025
e4c8da1
Fix segfault if None is passed to PartitionSpec.__eq__.
hawkinsp Jun 6, 2025
09111c8
[Pallas][Mosaic GPU] Expose partitioned collective loads to copy_gmem…
justinjfu Jun 6, 2025
7dd721a
Optimize jaxpr equation pretty-printing.
hawkinsp Jun 6, 2025
58a1937
[JAX] Allow registering callbacks to be called when backends are cleared
hyeontaek Jun 6, 2025
ec61161
PRNGKeyArray doesn't have a format field so assign the layout to be N…
yashk2810 Jun 6, 2025
c9289ae
Reverts c1bb095c5ce5b0286dc5052abf3b597b6f23cea5
mattjj Jun 6, 2025
567d61e
Add committed property to MutableArray
yashk2810 Jun 6, 2025
3248d55
[Pallas TPU] Add custom_vjp_call lowering rule
sharadmv Jun 6, 2025
051386f
Small speedups to pretty-printing.
hawkinsp Jun 6, 2025
0d1edcc
Set `in_sharding` to UNSPECIFIED if a mutableArray is uncommitted whe…
yashk2810 Jun 6, 2025
6ae2614
[Pallas][Mosaic GPU] Add collective (CTA-pair) MMAs to blackwell matm…
justinjfu Jun 7, 2025
3699aa7
[mutable-arrays] make partial_eval_jaxpr forward input-residuals
mattjj May 30, 2025
0e9c0e4
Merge pull request #29311 from mattjj:mutable-array-custom-vjp
Google-ML-Automation Jun 7, 2025
58faffd
Clarify argument order for lax.associative_scan when reverse=True.
carlosgmartin Jun 7, 2025
1cb18ec
Update XLA dependency to use revision
Google-ML-Automation Jun 7, 2025
cb2b217
fix vestigial change that caused breakage
mattjj Jun 7, 2025
0d1b1ef
Update XLA dependency to use revision
Google-ML-Automation Jun 8, 2025
c2a5690
Port pretty-printer to C++.
hawkinsp Jun 8, 2025
d69086d
Fix typo in error message.
Google-ML-Automation Jun 8, 2025
880dd13
Minor fix to doc for random.orthogonal.
carlosgmartin Jun 8, 2025
d8317b5
[Pallas][Easy] Terser printing of GridMapping unless debug is set.
Google-ML-Automation Jun 9, 2025
2281455
Don't trigger debug_infs in ndtri unless an inf is returned.
dfm Jun 9, 2025
a6d95ee
[jax2tf] Refine the disabling of jax2tf_test, for versions <= 2.19.1
gnecula May 21, 2025
a572a99
Merge pull request #29321 from carlosgmartin:fix_random_orthogonal_doc
Google-ML-Automation Jun 9, 2025
74e362e
Merge pull request #29313 from carlosgmartin:document_associative_sca…
Google-ML-Automation Jun 9, 2025
6d729fe
Move jax/_src/api.py and associated files to their own BUILD rule
Jun 9, 2025
92d5fe8
Reverts b7833e94c1940ed475dae1f5e83e2a984cda5cea
bartchr808 Jun 9, 2025
88de1e6
jax.nn.standardize: improve documentation
jakevdp Jun 9, 2025
52704b7
Merge pull request #29330 from dfm:ndtri-infs
Google-ML-Automation Jun 9, 2025
3157265
Merge pull request #29332 from jakevdp:standardize-doc
Google-ML-Automation Jun 9, 2025
42ea2ac
Move jax/_src/custom_batching.py to its own build rule
Jun 9, 2025
89e0c7e
Move jax/_src/earray.py to its own build rule
Jun 9, 2025
9651e60
[array-api] pin array-api-tests to 2025.05.23
jakevdp Jun 9, 2025
7c35300
Fix spelling error in the name of the input variable.
Google-ML-Automation Jun 9, 2025
39de715
[Mosaic GPU] Error when causal masking is used on cuda versions known…
Rifur13 Jun 9, 2025
9a59162
Move jax/_src/ffi.py to its own build rule
Jun 9, 2025
bfd0744
Move jax/_src/custom_partitioning.py to its own build rule
Jun 9, 2025
a7c67e5
Merge pull request #29344 from jakevdp:array-api-test-pin
Google-ML-Automation Jun 9, 2025
df50cd7
Move jax/_src/buffer_callback.py to its own build rule
Jun 9, 2025
31e9998
[mutable-arrays] make custom_api_test.py pass with JAX_MUTABLE_ARRAY_…
mattjj Jun 9, 2025
5f2ce85
Merge pull request #29352 from mattjj:mutable-array-custom-vjp-fixes
Google-ML-Automation Jun 9, 2025
a58b27c
Move jax/_src/shard_alike.py to its own build rule
Jun 10, 2025
14aaa45
[Pallas TPU][NFC] Use register to track buffer slots in pipeline loop
Google-ML-Automation Jun 10, 2025
34cee96
add reference to pr-checklist
jenriver Jun 10, 2025
b22be86
Remove forward_compat check for alpha as it is past the support date.
Google-ML-Automation Jun 10, 2025
28c31b8
Update JAX test to not rely on ToString and instead check the Device …
toli-y Jun 10, 2025
65c14a2
[pallas] Fix shard_map + Megacore in TPU interpret mode.
jburnim Jun 9, 2025
c126ee3
Don't revisit shared subjaxprs in jaxpr_util.pprof_equation_profile.
hawkinsp Jun 10, 2025
cc971e3
Automated Code Change
Google-ML-Automation Jun 10, 2025
b59a97f
Update XLA dependency to use revision
Google-ML-Automation Jun 10, 2025
297ee7c
Merge pull request #28916 from gnecula:tf_version
Google-ML-Automation Jun 10, 2025
31ef2cf
Improve batching for lax.platform_dependent
gnecula Jun 10, 2025
be23dcf
Move jax/_src/public_test_util.py to its own build rule
Jun 10, 2025
f053dfe
[Pallas/Mosaic GPU] Fix the abstract eval rule for `load_p` in the pr…
bchetioui Jun 10, 2025
fdf7a2c
Merge pull request #29257 from benquike:main
Google-ML-Automation Jun 10, 2025
4846ed2
Pass source_info to custom_staging_rules and into jaxpr inlining.
hawkinsp Jun 10, 2025
5d64c39
[mutable-arrays] upgrade scan to work with partial_eval_jaxpr_fwd
mattjj Jun 7, 2025
56f3293
Merge pull request #29353 from mattjj:mutable-array-custom-vjp-scan2
Google-ML-Automation Jun 10, 2025
14a9b59
Merge pull request #29358 from jenriver:documentation
Google-ML-Automation Jun 10, 2025
b999fab
[jax2tf] fix jax2tf sharding tests for shardy
liepieshov Jun 10, 2025
9e0472c
[Mosaic GPU] Fix test after a previous PR changed the config params.
Rifur13 Jun 10, 2025
c054135
Skip NumPy's `isClose` test for NumPy 2.3.0
jakeharmon8 Jun 10, 2025
03b0152
fix for a downstream breakage from #29353
mattjj Jun 10, 2025
89fca52
* Add support for output and input memory space colors in tpu custom …
subhankarshah Jun 10, 2025
bdd635a
[JAX] Add `vma` to `ShapeDtypeStruct` constructor arguments.
yashk2810 Jun 10, 2025
160e59f
Add is_leaf_with_path predicate.
IvyZX Jun 10, 2025
423aafe
Pallas documentation fixes.
Google-ML-Automation Jun 10, 2025
ed03f38
Rollback of #29353 due to downstream failures
mattjj Jun 10, 2025
22ceb68
Initial commit to make unreduced + AD work.
yashk2810 Jun 10, 2025
d9e0244
[JAX] Move the fallback of `colocated_cpu_devices` logic from the col…
hyeontaek Jun 10, 2025
fd6d90a
Add basic mutable array tests with AOT
yashk2810 Jun 11, 2025
f211c6b
Save a jaxpr equation in pl.cdiv if the rhs is an int.
hawkinsp Jun 11, 2025
cd4a0c6
Update XLA dependency to use revision
Google-ML-Automation Jun 11, 2025
02cdc7b
Ensure that all attributes are restored after pickling in `NamedShard…
stjerngren Jun 11, 2025
9da048e
[Mosaic GPU] Use _slice_smem also for barriers.
dimitar-asenov Jun 11, 2025
de82f9f
[pallas:mosaic] A few more primitives now have lowerings for all kern…
superbobry Jun 11, 2025
827b855
[Mosaic GPU] Remove unneeded code.
dimitar-asenov Jun 11, 2025
0edfb44
Propagate source_info in more places:
hawkinsp Jun 11, 2025
1cd4992
Ensure that memory_kind is restored after pickling in SingleDeviceSha…
stjerngren Jun 11, 2025
cfebc49
Don't recompute np.iinfo in _scalar_type_to_dtype.
hawkinsp Jun 11, 2025
5f0e7e4
Set explicit dot precision in the sparse solver test.
mooskagh Jun 11, 2025
225f0c6
Migrated to mypy 1.16.0
superbobry Jun 11, 2025
f7f2ce5
Delete instantiate_const_abstracted.
hawkinsp Jun 11, 2025
f519ad7
Merge pull request #29350 from jburnim:jburnim_interpret_shard_map_pl…
Google-ML-Automation Jun 11, 2025
6c9bcfc
add doc comment to vma in ShapedArray
yashk2810 Jun 11, 2025
b87ea1c
Do not call update_weak_type on the result of get_aval().
hawkinsp Jun 11, 2025
772cde8
Move jax/_src/export to its own build rule
Jun 11, 2025
703cf91
Merge pull request #29294 from olupton:fix-scaled-matmul-stablehlo-test
Google-ML-Automation Jun 11, 2025
2618d9b
Move materialization of NDIndexer out of draw()
jakeharmon8 Jun 11, 2025
5c2a320
[JAX] Extend `colocated_cpu_devices` to accept `Mesh` besides devices
hyeontaek Jun 11, 2025
9f8be25
extend pallas paged_attention with kv scales
rdyro Jun 9, 2025
e61ee7e
Expose local/global `ExchangeTopologies` timeouts for PJRT CPU client.
Google-ML-Automation Jun 11, 2025
5b49b28
Merge pull request #29405 from superbobry:union-fix
Google-ML-Automation Jun 11, 2025
b40d79c
Merge pull request #29354 from rdyro:paged_attention_with_scales
Google-ML-Automation Jun 11, 2025
97ab9e0
[XProf] Change tensorboard-plugin-profile to new xprof package
Matt-Hurd May 30, 2025
004b548
Add alternative location of `CUDA_ROOT` for Bazel build/tests with he…
Google-ML-Automation Jun 11, 2025
d00b1ca
Merge pull request #29129 from Matt-Hurd:rename_to_xprof
Google-ML-Automation Jun 11, 2025
3ef4db4
Add a pytype disable around zstandard.
hawkinsp Jun 11, 2025
cc976cb
Add execution to unreduced tests now that it works end-to-end
yashk2810 Jun 11, 2025
45e61d8
Add nightly linux jax wheel tests for python 3.14.0b1
kanglant Jun 11, 2025
e7d252d
Fix GPU quantized paged attention tests for < sm89
rdyro Jun 12, 2025
2c467d6
Merge pull request #29423 from rdyro:gpu_paged_fix
Google-ML-Automation Jun 12, 2025
1886a9e
Merge pull request #29362 from gnecula:platform_index_linearize
Google-ML-Automation Jun 12, 2025
4ad4d4a
Update XLA dependency to use revision
Google-ML-Automation Jun 12, 2025
87ce7d1
add jax.nn module type hints (__init__.pyi)
DanisNone Jun 11, 2025
a288036
Move NamedSharding.__eq__ and NamedSharding.__hash__ into C++.
hawkinsp Jun 12, 2025
83c292b
[Mosaic GPU] Add conversion logic for `i4 -> f8e4m3fn`.
bchetioui Jun 12, 2025
ef10603
add missing dtypes to jax.numpy.__init__.pyi
DanisNone Jun 12, 2025
4e3bf29
Temporarily disable AVX512 in linalg_test_cpu.
hawkinsp Jun 12, 2025
294d86b
Add hermetic `nvshmem` dependencies to JAX targets.
Google-ML-Automation Jun 12, 2025
46b9ead
Merge pull request #29425 from DanisNone:miss-dtypes
Google-ML-Automation Jun 12, 2025
f81f258
[Pallas TPU] Support memory space constraints on pallas_call inputs.
sharadmv Jun 12, 2025
1228053
[JAX] Update the example to use jax.numpy rather than numpy.
Google-ML-Automation Jun 12, 2025
0719346
Reland the C++ safe_zip implementation.
hawkinsp Jun 12, 2025
6e2977d
Add `all_gather_invariant` to lax.
yashk2810 Jun 12, 2025
a9fdb76
[Pallas TPU] Small fix to memory space constraints on pallas_call inp…
Google-ML-Automation Jun 12, 2025
094b66f
[Mosaic GPU] Enable transpose tests in mosaic_gpu.
dimitar-asenov Jun 12, 2025
efbedc6
[doc] fix some inaccuracies in jnp.bincount docs
jakevdp Jun 12, 2025
abb756d
Add colorama back into test-requirements
jakeharmon8 Jun 12, 2025
109519a
Merge pull request #29441 from jakevdp:bincount-doc
Google-ML-Automation Jun 12, 2025
688c3d3
Use a frozenset for unconstrained_dims in sharding_constraint_p.
hawkinsp Jun 13, 2025
b72be57
[jaxlib] Change Traceback to be a raw CPython class rather than a nan…
hawkinsp Jun 13, 2025
e04cc28
Make the params of more jaxpr primitives hashable.
hawkinsp Jun 13, 2025
d840447
Remove unused internal optimization_barrier alias
Jun 13, 2025
382b3e0
fix-forward for pallas tpu memory spaces test
sharadmv Jun 13, 2025
c86fefb
Update XLA dependency to use revision
Google-ML-Automation Jun 13, 2025
142ace2
Move jax._src.callback to its own BUILD rule
Jun 13, 2025
604b604
[Mosaic GPU] Convert all memrefs with transforms to unrealized casts …
dimitar-asenov Jun 13, 2025
fc8192c
[Mosaic GPU] Add a Mosaic GPU op `with_transforms` for manually setti…
dimitar-asenov Jun 13, 2025
a4f0e40
[Mosaic GPU] Resolve different tile transforms using the largest comm…
dimitar-asenov Jun 13, 2025
70c90a9
[Mosaic GPU] Use warpgroup semantics for the ragged dot example kernel.
dimitar-asenov Jun 13, 2025
9b45aac
Disable `too_slow` in data.draw() for test_ndindexer
jakeharmon8 Jun 13, 2025
a3542f8
[Mosaic GPU] Reconcile the swizzle of the a and b operands for wgmma …
dimitar-asenov Jun 13, 2025
7b7a5d8
Add pjit_p to extend.core.primitives
j-towns Jun 13, 2025
e66a6dd
Fix return type annotation for tree_util.tree_broadcast.
jburnim Jun 13, 2025
193f11d
[Mosaic GPU] Parametrize the `test_subview` test.
dimitar-asenov Jun 13, 2025
670ae13
Make params of several pallas primitives hashable.
hawkinsp Jun 13, 2025
92aedb2
Improve reshape not supported error message
yashk2810 Jun 13, 2025
1b1e9f7
Internal refactor: move TPU lowering rules out of jax/_src/lax
jakevdp Jun 13, 2025
64c9574
Make params of assert_consumed_value_p hashable.
hawkinsp Jun 13, 2025
ec637fe
Merge pull request #29456 from j-towns:pjit-p-extend-primitives
Google-ML-Automation Jun 13, 2025
f41fc71
Merge pull request #28931 from olupton:cublas-cudnn
Google-ML-Automation Jun 13, 2025
29eb832
Merge pull request #29420 from jakevdp:tpu-linalg
Google-ML-Automation Jun 13, 2025
790540c
Make some remaining jaxpr equation params hashable.
hawkinsp Jun 13, 2025
b9cf0af
Implemented cross-host memory transfer on GPU.
mwhittaker Jun 13, 2025
080294c
Load CUDA libraries up front with cdll.LoadLibrary().
hawkinsp Jun 13, 2025
a849d21
Merge pull request #29462 from hawkinsp:nvrtc
Google-ML-Automation Jun 13, 2025
dfc9052
[Pallas] Add no_pipelining debugging option to emit_pipeline.
justinjfu Jun 13, 2025
8e346dd
Replace `with_partitions` and `with_unreduced` with `.update` on Part…
yashk2810 Jun 13, 2025
7904b86
Remove `with_spec` from NamedSharding and replace with `.update`
yashk2810 Jun 13, 2025
f06888f
[JAX] Fix the test names in colocated_python_test.py to following the…
hyeontaek Jun 14, 2025
e28d6ed
Update XLA dependency to use revision
Google-ML-Automation Jun 14, 2025
d065d2a
Make mosaic_gpu equation params hashable.
hawkinsp Jun 14, 2025
9678a76
Update XLA dependency to use revision
Google-ML-Automation Jun 15, 2025
b25655c
Update XLA dependency to use revision
Google-ML-Automation Jun 16, 2025
639216f
PR #28102: Add cudnn paged attention support in JAX cuDNN SDPA API
Cjkkkk Jun 16, 2025
53849a5
Add `update_vma` and `update_weak_type` override on AbstractTMEMRef s…
yashk2810 Jun 16, 2025
1362f7f
Add version guards to testAutoPgle
yashk2810 Jun 16, 2025
173574a
Bump the libtpu check to 6/20
berkinilbeyi Jun 16, 2025
0e7c96a
Removed unused `PyTreeDef::MakeFromNodeDataAndChildren` and its Pytho…
superbobry Jun 16, 2025
dbde6c4
[jax] Increase absolute test tolerance for lax_control_flow test
basioli-k Jun 16, 2025
49e52c0
Fix some more instances of unhashable jaxpr equation arguments.
hawkinsp Jun 16, 2025
e5af088
[export] Add back-compat test for tridiagonal solve on GPU
gnecula Jun 16, 2025
a78c6a7
Set heartbeat_timeout argument and flag.
mwhittaker Jun 16, 2025
09d903f
Install SciPy from its source (head) to test against Python 3.14.0b1
kanglant Jun 16, 2025
7aec14f
[doc] add missing axis_types documentation
jakevdp Jun 16, 2025
8865ee6
Remove legacy CPU custom calls.
dfm Jun 16, 2025
35ae958
Merge pull request #29497 from jakevdp:make-mesh-doc
Google-ML-Automation Jun 16, 2025
cb1cc37
Fix a missing bounds check in traceback code.
hawkinsp Jun 16, 2025
97e580c
Removed fixed suppressions
vfdev-5 Jun 16, 2025
748b39f
[Mosaic:TPU][NFC] Delete unused variable
tlongeri Jun 16, 2025
b622512
[JAX] Relax the return type of `colocated_python` decorator
hyeontaek Jun 16, 2025
2dfaced
Add custom-call ops to roofline.
zacmustin Jun 17, 2025
0f5cdba
Removing Tensorflow references from the document.
sannidhyachauhan Jun 17, 2025
acf99e5
Add test for programmatic tracing with options.
sannidhyachauhan Jun 17, 2025
4be1402
Update XLA dependency to use revision
Google-ML-Automation Jun 17, 2025
93453f5
Pass through the `use_shardy_partitioner` with `jax.config.jax_use_sh…
ZixuanJiang Jun 17, 2025
91651f8
[Mosaic] Use BF16 ops for math::PowF on TPUv6+.
WindQAQ Jun 17, 2025
077f2b6
Update Pallas debugging doc with TPU interpret mode + dynamic race de…
jburnim Jun 17, 2025
124c723
Prefer binaries in NVIDIA `nvcc` wheel over system CUDA installation …
Google-ML-Automation Jun 17, 2025
332aa35
Add an API to overwrite the current execution_stream_id and respect i…
cky9301 Jun 17, 2025
f22896a
jax.experimental.enable_x64: add warning to docstring
jakevdp Jun 17, 2025
0fd0821
[Pallas TPU] Add flag to enable using registers to keep track of slot…
Google-ML-Automation Jun 17, 2025
784be1f
add psend and precv to jax/lax/parallel
rosiezou May 30, 2025
c2cc9f9
[pallas] `AbstractMemoryRef` now implements all functional update met…
superbobry Jun 17, 2025
3d37b0d
Merge pull request #29504 from vfdev-5:tsan-ft-removed-fixed-suppression
Google-ML-Automation Jun 17, 2025
353e7fa
Merge pull request #29516 from jakevdp:enable-x64-warning
Google-ML-Automation Jun 17, 2025
dc9ef61
Merge pull request #29410 from DanisNone:nn-type
Google-ML-Automation Jun 17, 2025
02688e1
[Pallas][Mosaic GPU] Enable collective MMA from TMEM.
justinjfu Jun 17, 2025
e4de90e
Update XLA dependency to use revision
Google-ML-Automation Jun 17, 2025
8f81490
Prepare for JAX release 0.6.2
yashk2810 Jun 17, 2025
755bb67
Merge pull request #29135 from rosiezou:main
Google-ML-Automation Jun 17, 2025
7dd13d7
Rollback https://github.com/jax-ml/jax/pull/29410 due to downstream p…
Jun 17, 2025
c944c65
Add `cum{logsumexp, min, max, prod, sum}` to JAX roofline.
zacmustin Jun 17, 2025
19f34a0
[JAX] Remove sleeping from colocated Python execution tests
hyeontaek Jun 17, 2025
feab6f4
Postrelease (0.6.2) changes
yashk2810 Jun 17, 2025
8b88cc8
Merge pull request #29528 from jax-ml:postrelease
Google-ML-Automation Jun 17, 2025
34d88cc
Add `gather` to `roofline`.
zacmustin Jun 18, 2025
1d07747
jaxlib_extension_version == 355 after 0.6.2 release. So remove the co…
yashk2810 Jun 18, 2025
366a7df
[Mosaic:TPU] Byte-granularity dynamic gathers
tlongeri Jun 18, 2025
7c432e9
[mosaic] Added a `k` prefix to `TPU_MemorySpace` members
superbobry Jun 18, 2025
2ec9981
[Mosaic GPU] Rework the CUDA_ROOT detection once again
apaszke Jun 18, 2025
b6575e1
[Mosaic GPU] Add support for s8 matmuls on Blackwell
apaszke Jun 18, 2025
1cd076f
Drop Python 3.10 support.
hawkinsp Jun 18, 2025
f562884
[Mosaic GPU] Implement canonicalization for `TiledLayout`s.
bchetioui Jun 18, 2025
cad4ba7
Merge pull request #29543 from hawkinsp:py310
Google-ML-Automation Jun 18, 2025
5e2afe6
Remove `_allow_deprecated_jit_signature` now that 0.6.2 is out and ne…
yashk2810 Jun 18, 2025
9cf81e4
Add a cache around abstract_eval rules.
hawkinsp Jun 18, 2025
3625696
Finalize a number of deprecations for JAX v0.7.0
jakevdp Jun 18, 2025
a7fe11d
Merge pull request #29549 from jakevdp:finalize-deps
Google-ML-Automation Jun 18, 2025
ffabd3e
[Mosaic TPU] Make the backward-compatibility libtpu condition stricter
Google-ML-Automation Jun 18, 2025
4f437a3
Remove some dangling references from the docs.
hawkinsp Jun 18, 2025
364f004
[Mosaic GPU] Fix minor error in matmul test.
justinjfu Jun 18, 2025
9859157
Merge pull request #29558 from hawkinsp:docs
Google-ML-Automation Jun 18, 2025
59034e8
Run pyupgrade --py311-plus.
hawkinsp Jun 18, 2025
4aa1db4
Remove PositionalSharding from JAX now that 0.6.2 is out and next rel…
yashk2810 Jun 18, 2025
4f871b0
Skip pytest for Python 3.14 during the JAX release process
kanglant Jun 18, 2025
7c9613a
Fix rare error with Literal in DynamicJaxprTracer.full_lower.
jburnim Jun 18, 2025
511bf2f
Move jax._src.lax to its own BUILD rule
Jun 18, 2025
4c54c02
Fix bugs in the double_buffered_pipeline example
dubstack Jun 18, 2025
0912143
Merge pull request #29550 from hawkinsp:py311
Google-ML-Automation Jun 18, 2025
cb2315a
Update jax requirements lock files after 0.6.2 release
kanglant Jun 18, 2025
8cbd915
Merge pull request #29554 from dubstack:patch-1
Google-ML-Automation Jun 18, 2025
6567192
[doc] fix build error
jakevdp Jun 18, 2025
8337fe5
Merge pull request #29563 from jakevdp:fix-doc
Google-ML-Automation Jun 18, 2025
5e290dd
[Pallas][Mosaic GPU] Add GPU pipelining docs
justinjfu Apr 19, 2025
d3f0871
Merge pull request #29560 from kanglant:update_lock_files
Google-ML-Automation Jun 18, 2025
a47ae57
Add wrap_negative_indices paramter to jnp.ndarray.at[]
jakevdp Jun 18, 2025
9d1b01e
[JAX] Skip failing tpu tests until June 30th.
subhankarshah Jun 18, 2025
d188366
add cudnn sdpa mla support
Cjkkkk Jun 18, 2025
0085221
Merge pull request #29434 from jakevdp:normalize-indices
Google-ML-Automation Jun 18, 2025
4d2b14a
Merge pull request #28135 from justinjfu:gpu_pipe_docs
Google-ML-Automation Jun 19, 2025
4efa56d
Merge pull request #28872 from Cjkkkk:jax_cudnn_sdpa_mla
Google-ML-Automation Jun 19, 2025
1e4a0f7
Pass shardy option through jax config.
ZixuanJiang Jun 19, 2025
0b54a1e
Reenable AVX512 after LLVM fix upstream.
WillFroom Jun 19, 2025
e55f55f
[Mosaic GPU] Delete dead code in `layout_inference.py`.
bchetioui Jun 19, 2025
b99d004
Create a test suite for Pallas mosaic GPU tests.
dimitar-asenov Jun 19, 2025
f99d2b4
[Pallas:MGPU] Add docs for pl.core_map and plgpu.kernel
apaszke Jun 19, 2025
84066b7
[Mosaic GPU] Change layout inference tests to rely on explicit `layou…
bchetioui Jun 19, 2025
e818940
[Mosaic GPU][NFC] Add `checkInLayouts` and `checkOutLayouts` utils to…
bchetioui Jun 19, 2025
bfc07e2
Remove `Layout`, `.layout`, `.input_layouts` and `.output_layouts` an…
yashk2810 Jun 19, 2025
3fdc97b
[Pallas/Mosaic GPU] Propagate transforms on the accumulator in `tcgen…
bchetioui Jun 19, 2025
346ce85
[pallas:mosaic] Fixed a typo in the distributed tutorial
superbobry Jun 19, 2025
f3370cb
[mosaic] `MemRef{Slice,Squeeze}` verifiers now support strided layouts
superbobry Jun 19, 2025
d46202e
Add some tracemes in py_array to make slow device put debugging easier.
Google-ML-Automation Jun 20, 2025
71ea45b
Add traceme to `PythonRefManager::CollectGarbage`
junwhanahn Jun 20, 2025
39 changes: 28 additions & 11 deletions .bazelrc
@@ -31,6 +31,9 @@ build -c opt
build --output_filter=DONT_MATCH_ANYTHING

build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
build --copt=-DNB_DOMAIN=jax

build --legacy_external_runfiles=false

# #############################################################################
# Platform Specific configs below. These are automatically picked up by Bazel
@@ -97,6 +100,7 @@ build:windows --incompatible_strict_action_env=true
# #############################################################################
build:nonccl --define=no_nccl_support=true

build --repo_env USE_PYWRAP_RULES=1
build:posix --copt=-fvisibility=hidden
build:posix --copt=-Wno-sign-compare
build:posix --cxxopt=-std=c++17
@@ -130,23 +134,27 @@ build:clang --copt=-Wno-gnu-offsetof-extensions
build:clang --copt=-Qunused-arguments
# Error on struct/class mismatches, since this causes link failures on Windows.
build:clang --copt=-Werror=mismatched-tags
# Required when building with clang>=19, see jax-ml/jax#27091
build:clang --copt=-Wno-error=c23-extensions

# Configs for CUDA
build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --repo_env TF_NCCL_USE_STUB=1
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120"
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --@local_config_cuda//:enable_cuda

# Default hermetic CUDA and CUDNN versions.
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
# Default hermetic CUDA, CUDNN and NVSHMEM versions.
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.0"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0"
build:cuda --repo_env=HERMETIC_NVSHMEM_VERSION="3.2.5"
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true

# This config is used for building targets with CUDA libraries from stubs.
# This config is used for building targets with CUDA/NVSHMEM libraries from stubs.
build:cuda_libraries_from_stubs --@local_config_cuda//cuda:include_cuda_libs=false
build:cuda_libraries_from_stubs --@local_config_nvshmem//:include_nvshmem_libs=false

# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
@@ -238,6 +246,9 @@ build:ci_linux_aarch64_base --config=clang --verbose_failures=true
build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10"
build:ci_linux_aarch64_base --color=yes

# This appears to help avoid a timeout in CI for linalg_test.
build:ci_linux_aarch64_base --test_env=OMP_NUM_THREADS=8

build:ci_linux_aarch64 --config=ci_linux_aarch64_base
build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain"
build:ci_linux_aarch64 --crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain"
@@ -260,8 +271,8 @@ build:ci_darwin_arm64 --color=yes
# Windows x86 CI configs
build:ci_windows_amd64 --config=avx_windows
build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true
build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain"
build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl"
build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win2022/20241118:toolchain"
build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl"
build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE
build:ci_windows_amd64 --color=yes

@@ -321,6 +332,9 @@ build:rbe_linux_x86_64 --config=ci_linux_x86_64
build:rbe_linux_x86_64_cuda --config=rbe_linux_x86_64_base
build:rbe_linux_x86_64_cuda --config=ci_linux_x86_64_cuda
build:rbe_linux_x86_64_cuda --repo_env=REMOTE_GPU_TESTING=1
# Speed up CUDA repos creation by downloading ".tar" dists from the mirror.
build:rbe_linux_x86_64_cuda --repo_env=USE_CUDA_TAR_ARCHIVE_FILES=1
build:rbe_linux_x86_64_cuda --repo_env=USE_NVSHMEM_TAR_ARCHIVE_FILES=1

# RBE configs for Windows
# Set the remote worker pool
@@ -329,9 +343,9 @@ common:rbe_windows_amd64 --remote_instance_name=projects/tensorflow-testing/inst
build:rbe_windows_amd64 --config=rbe

# Set the host, execution, and target platform
build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win:x64_windows-clang-cl"
build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win:x64_windows-clang-cl"
build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win:x64_windows-clang-cl"
build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"
build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"
build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"

build:rbe_windows_amd64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
build:rbe_windows_amd64 --enable_runfiles
@@ -371,6 +385,9 @@ build:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/
build:rbe_cross_compile_linux_aarch64 --config=cross_compile_linux_aarch64
build:rbe_cross_compile_linux_aarch64 --config=rbe_cross_compile_base

# Avoids a timeout in linalg_test on ARM.
build:rbe_cross_compile_linux_aarch64 --test_env=OMP_NUM_THREADS=8

# Mac x86
build:cross_compile_darwin_x86_64 --config=cross_compile_base
build:cross_compile_darwin_x86_64 --config=nonccl
@@ -410,7 +427,7 @@ build:rbe_cross_compile_darwin_x86_64 --config=rbe_cross_compile_base
#############################################################################

build:debug_symbols --strip=never --per_file_copt="xla/pjrt|xla/python@-g3"
build:debug --config debug_symbols -c fastbuild
build:debug --config=debug_symbols -c fastbuild

# Load `.jax_configure.bazelrc` file written by build.py
try-import %workspace%/.jax_configure.bazelrc
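The `HERMETIC_CUDA_COMPUTE_CAPABILITIES` change above distinguishes `sm_*` entries (cubin only, forward compatible within a GPU generation) from `compute_*` entries (cubin plus PTX, forward compatible to future generations). As an illustrative sketch only (the real parsing lives in XLA's hermetic CUDA Bazel rules, not in this helper), the capability string splits like this:

```python
# Split a HERMETIC_CUDA_COMPUTE_CAPABILITIES string into the two flavors
# described in the .bazelrc comment: "sm_XX" emits only cubin, while
# "compute_XX" emits cubin + PTX. This helper is hypothetical, for
# illustration; it is not part of the actual build rules.
def split_capabilities(caps: str) -> tuple[list[str], list[str]]:
    entries = [c.strip() for c in caps.split(",") if c.strip()]
    cubin_only = [c for c in entries if c.startswith("sm_")]
    with_ptx = [c for c in entries if c.startswith("compute_")]
    return cubin_only, with_ptx

cubin_only, with_ptx = split_capabilities(
    "sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120"
)
print(cubin_only)  # ['sm_50', 'sm_60', 'sm_70', 'sm_80', 'sm_90', 'sm_100']
print(with_ptx)    # ['compute_120']
```

The updated default adds `sm_90`/`sm_100` cubins and moves the PTX-carrying entry up to `compute_120`, trading a larger binary for coverage of newer GPU generations.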
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug-report.yml
@@ -24,7 +24,7 @@ body:

[issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues

[Raw report]: http://github.com/jax-ml/jax/issues/new
[Raw report]: https://github.com/jax-ml/jax/issues/new?template=none
- type: textarea
attributes:
label: Description
20 changes: 20 additions & 0 deletions .github/actionlint.yaml
@@ -0,0 +1,20 @@
# Configuration related to self-hosted runner.
self-hosted-runner:
labels:
- "linux-x86-n2-32" # Linux X86 runner using the 32 vcpu n2-standard-32 machine.
- "linux-x86-n2-64" # Linux X86 runner using the 64 vcpu n2-standard-64 machine.
- "linux-x86-g2-16-l4-1gpu" # Linux X86 GPU runner using g2-standard-16 machine with 1 NVIDIA L4 GPU attached.
- "linux-x86-g2-48-l4-4gpu" # Linux X86 GPU runner using g2-standard-48 machine with 4 NVIDIA L4 GPUs attached.
- "linux-x86-ct5lp-224-8tpu" # Linux X86 TPU runner using ct5lp-hightpu-8t machine with 2x4 topology.
- "linux-arm64-c4a-16" # Linux ARM64 CPU Runner using the 16 vcpu c4a-standard-16 machine.
- "linux-arm64-c4a-64" # Linux ARM64 CPU Runner using the 64 vcpu c4a-standard-64 machine.
- "windows-x86-n2-16" # Windows X86 runner using n2-standard-16 machine.
- "windows-x86-n2-64" # Windows X86 runner using n2-standard-64 machine.
- "linux-x86-a4-224-b200-1gpu" # Linux X86 GPU runner using 1 B200 GPU and 1/8 the resources of a a4-highgpu-8g machine
- "linux-x86-a3-8g-h100-8gpu" # Linux X86 GPU runner using a3-highgpu-8g machine with 8 NVIDIA H100 GPUs attached.
- "linux-x86-ct6e-180-8tpu" # Linux X86 TPU runner using ct6e-hightpu-8t machine with 2x4 topology.
- "linux-x86-ct6e-180-4tpu" # Linux X86 TPU runner using ct6e-hightpu-4t machine with 2x2 topology.
- "linux-x86-ct4p-240-4tpu" # Linux X86 TPU runner using ct4p-hightpu-4t machine with 2x2x1 topology.
- "linux-x86-n2-128" # Linux X86 runner using the 128 vcpu n2-standard-128 machine.
- "linux-x86-n2-16" # Linux X86 runner using the 16 vcpu n2-standard-16 machine.
- "linux-x86_64-cirrascale-64-8gpu-amd-mi250" # AMD runner
4 changes: 3 additions & 1 deletion .github/workflows/asan.yaml
@@ -13,7 +13,7 @@ on:
- main
paths:
- '**/workflows/asan.yaml'

permissions: {}
jobs:
asan:
# Don't execute in fork due to runner type
@@ -41,11 +41,13 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
path: jax
persist-credentials: false
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: python/cpython
path: cpython
ref: v3.13.0
persist-credentials: false
- name: Build CPython with ASAN enabled
env:
ASAN_OPTIONS: detect_leaks=0
60 changes: 60 additions & 0 deletions .github/workflows/bazel_cpu_py_import_rbe.yml
@@ -0,0 +1,60 @@
# CI - Bazel CPU tests with py_import (RBE)
#
# This workflow runs the Bazel CPU tests with py_import dependency. It can only be triggered by
# other workflows via `workflow_call`. It is used by the `CI - Wheel Tests (Continuous)` workflows
# to run the Bazel CPU tests.
#
# It consists of the following job:
# run-tests:
# - Executes the `run_bazel_test_cpu_py_import_rbe.sh` script, which performs the following actions:
# - Runs the Bazel CPU tests with py_import dependency.
name: CI - Bazel CPU tests with py_import (RBE)
permissions: {}
on:
workflow_call:
inputs:
runner:
description: "Which runner should the workflow run on?"
type: string
default: "linux-x86-n2-16"
python:
description: "Which python version to test?"
type: string
default: "3.12"
enable-x64:
description: "Should x64 mode be enabled?"
type: string
default: "0"
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: string
default: 'no'

jobs:
run-tests:
defaults:
run:
# Explicitly set the shell to bash
shell: bash
runs-on: ${{ inputs.runner }}
container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') ||
(contains(inputs.runner, 'linux-arm64') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest') }}
env:
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}

name: "${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') ||
(contains(inputs.runner, 'linux-arm64') && 'linux arm64') }}, py ${{ inputs.python }}, x64=${{ inputs.enable-x64 }}"

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Bazel CPU tests with py_import (RBE)
timeout-minutes: 60
run: ./ci/run_bazel_test_cpu_py_import_rbe.sh
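The `container:` expression above picks an image based on the runner label via chained `contains(...)` ternaries. A minimal sketch of that selection logic, with `pick_container` as a hypothetical helper name (the image URIs are the ones from the workflow):

```python
from typing import Optional

X86_IMAGE = "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest"
ARM64_IMAGE = "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest"

def pick_container(runner: str) -> Optional[str]:
    # Mirrors the GitHub Actions expression: x86 runners get the x86 image,
    # arm64 runners get the arm64 image, anything else gets no container.
    if "linux-x86" in runner:
        return X86_IMAGE
    if "linux-arm64" in runner:
        return ARM64_IMAGE
    return None
```

The same pattern drives the job's display `name:`, so a single reusable workflow serves both architectures without duplicating the job definition.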
23 changes: 15 additions & 8 deletions .github/workflows/bazel_cpu_rbe.yml
@@ -18,7 +18,7 @@ on:
branches:
- main
- 'release/**'

permissions: {}
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
# Don't cancel in-progress jobs for main/release branches.
@@ -28,31 +28,38 @@ jobs:
run_tests:
if: github.event.repository.fork == false
runs-on: ${{ matrix.runner }}
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }}
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') ||
(contains(matrix.runner, 'linux-arm64') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest') }}
env:
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
# Begin Presubmit Naming Check - name modification requires internal check to be updated
strategy:
matrix:
python: ["3.10", "3.13"]
python: ["3.11", "3.13"]
runner: ["linux-x86-n2-16", "linux-arm64-c4a-16"]
enable-x_64: [1, 0]
exclude:
# Exclude x64=1 on the oldest Python and x64=0 on the newest Python. As long as we have
# coverage for one of each, we don't need to run both.
- python: "3.10"
- python: "3.11"
enable-x_64: 1
- python: "3.13"
enable-x_64: 0
name: "Bazel CPU tests (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})"
# Keep only a single Python version on Arm64, since Arm64 jobs are build-only (no tests are run).
- python: "3.11"
runner: "linux-arm64-c4a-16"
name: "Bazel CPU ${{ (contains(matrix.runner, 'linux-arm64') && 'build only' || 'tests') }} (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})"
# End Presubmit Naming Check github-cpu-presubmits
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Bazel CPU Tests with RBE
# Since we do not have a Linux Arm64 RBE pool, we do not run the tests on Arm64. Instead, we
# cross-compile the tests on the Linux x86 RBE pool.
- name: ${{ (contains(matrix.runner, 'linux-arm64') && 'Build' || 'Run') }} Bazel CPU Tests with RBE
run: ./ci/run_bazel_test_cpu_rbe.sh
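The matrix above prunes combinations with `exclude:` entries, where an entry drops any combination matching all of its keys. A small sketch of how that pruning plays out for this workflow's matrix (the combination count is an illustration, not part of the workflow file):

```python
from itertools import product

pythons = ["3.11", "3.13"]
runners = ["linux-x86-n2-16", "linux-arm64-c4a-16"]
x64_modes = [1, 0]

# The three exclude entries from the workflow.
excludes = [
    {"python": "3.11", "enable-x_64": 1},
    {"python": "3.13", "enable-x_64": 0},
    {"python": "3.11", "runner": "linux-arm64-c4a-16"},
]

def excluded(combo: dict) -> bool:
    # An exclude entry matches when every key/value it specifies matches the combo.
    return any(all(combo[k] == v for k, v in e.items()) for e in excludes)

combos = [
    c
    for p, r, x in product(pythons, runners, x64_modes)
    if not excluded(c := {"python": p, "runner": r, "enable-x_64": x})
]
```

Of the 8 raw combinations, only 3 survive: oldest Python with x64=0 on x86, newest Python with x64=1 on x86, and newest Python with x64=1 on arm64, which matches the "one of each" coverage goal stated in the comments.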
48 changes: 31 additions & 17 deletions .github/workflows/bazel_cuda_non_rbe.yml
@@ -17,55 +17,59 @@ on:
runner:
description: "Which runner should the workflow run on?"
type: string
required: true
default: "linux-x86-n2-16"
python:
description: "Which python version to test?"
type: string
required: true
default: "3.12"
enable-x64:
description: "Should x64 mode be enabled?"
type: string
required: true
default: "0"
jaxlib-version:
description: "Which jaxlib version to test? (head/pypi_latest)"
type: string
default: "head"
gcs_download_uri:
description: "GCS location URI from where the artifacts should be downloaded"
required: true
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
type: string
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: boolean
required: false
default: false

type: string
default: 'no'
permissions: {}
jobs:
run-tests:
defaults:
run:
# Explicitly set the shell to bash
shell: bash
runs-on: ${{ inputs.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest"
container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-cuda12.8-cudnn9.8:latest"

env:
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
# Enable writing to the Bazel remote cache bucket.
JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: "1"

name: "Bazel single accelerator and multi-accelerator CUDA tests (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
name: "jaxlib=${{ inputs.jaxlib-version }},
${{ (contains(inputs.runner, 'h100') && 'h100') ||
(contains(inputs.runner, 'b200') && 'b200') ||
(contains(inputs.runner, 'l4') && 'l4') }}, py ${{ inputs.python }}, x64=${{ inputs.enable-x64 }}"

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: Set env vars for use in artifact download URL
run: |
os=$(uname -s | awk '{print tolower($0)}')
arch=$(uname -m)

# Get the major and minor version of Python.
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
# E.g. if JAXCI_HERMETIC_PYTHON_VERSION=3.11, then python_major_minor=311
python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.')

echo "OS=${os}" >> $GITHUB_ENV
@@ -77,11 +81,21 @@ jobs:
# fails. Instead, we verify the outcome in the next step so that we can print a more
# informative error message.
continue-on-error: true
run: >-
mkdir -p $(pwd)/dist &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
run: |
mkdir -p $(pwd)/dist
gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/

if [[ ${{ inputs.jaxlib-version }} == "head" ]]; then
gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
elif [[ ${{ inputs.jaxlib-version }} == "pypi_latest" ]]; then
PYTHON=python${{ inputs.python }}
$PYTHON -m pip download jaxlib jax-cuda12-pjrt jax-cuda12-plugin --dest $(pwd)/dist/
else
echo "Invalid jaxlib version: ${{ inputs.jaxlib-version }}"
exit 1
fi
- name: Skip the test run if the wheel artifacts were not downloaded successfully
if: steps.download-wheel-artifacts.outcome == 'failure'
run: |
@@ -91,7 +105,7 @@ jobs:
exit 1
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Bazel CUDA tests (Non-RBE)
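The download step above strips the dot from the Python version (`tr -d '.'`) to build wheel glob patterns, then branches on `jaxlib-version`. A sketch of the pattern construction under that scheme, with `wheel_globs` as a hypothetical helper name:

```python
from typing import List

def wheel_globs(jaxlib_version: str, python_version: str,
                os_name: str, arch: str) -> List[str]:
    # "3.11" -> "311", matching `tr -d '.'` in the workflow step.
    pmm = python_version.replace(".", "")
    if jaxlib_version == "head":
        # Head wheels are fetched from the GCS staging bucket.
        return [
            f"jaxlib*{pmm}*{os_name}*{arch}*.whl",
            f"jax*cuda*plugin*{pmm}*{os_name}*{arch}*.whl",
            f"jax*cuda*pjrt*{os_name}*{arch}*.whl",
        ]
    if jaxlib_version == "pypi_latest":
        # Nothing to glob from GCS: wheels come from `pip download` instead.
        return []
    raise ValueError(f"Invalid jaxlib version: {jaxlib_version}")
```

Note the pjrt wheel pattern carries no Python-version component, consistent with the workflow: the PJRT plugin wheel is Python-version independent, while jaxlib and the CUDA plugin wheels are per-Python-version.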