From e7e847ade51c109a2d59e5bd4358a70a7ea63254 Mon Sep 17 00:00:00 2001 From: AratiGanesh Date: Tue, 29 Jul 2025 20:36:39 +0000 Subject: [PATCH] Fix to trim the input to be divisible by num gpus in test shard UT --- tests/ffi_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/ffi_test.py b/tests/ffi_test.py index d6270d14e5d6..28415e4d5db4 100644 --- a/tests/ffi_test.py +++ b/tests/ffi_test.py @@ -276,10 +276,12 @@ def test_invalid_result_type(self): @jtu.run_on_devices("gpu", "cpu") def test_shard_map(self): - if jtu.is_device_rocm: - self.skipTest("Skip on ROCm: tests/ffi_test.py::FfiTest::test_shard_map") + # if jtu.is_device_rocm: + # self.skipTest("Skip on ROCm: tests/ffi_test.py::FfiTest::test_shard_map") mesh = jtu.create_mesh((len(jax.devices()),), ("i",)) x = self.rng().randn(8, 4, 5).astype(np.float32) + n = len(jax.devices()) + x = x[:(x.shape[0] // n) * n] @partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i")) def f(x):