@aybchan (Member) commented Nov 24, 2025

JAX-vLLM offloading transfer and GRPO examples on Kubernetes

Benchmark results

Transfer

| Cluster | commit | job definition | Nodes | vLLM GPUs | JAX GPUs | Debug on | create_bridge (s) | load_model (s) | transfer (s) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| GKE (k8s) | e0a1b67 | deploy.sh | 1 | 2 | 2 | Yes | 1145.220 | 3.617 | 7.425 |
| GKE (k8s) | e0a1b67 | deploy.sh | 1 | 4 | 4 | Yes | 1058.517 | 6.129 | 2.442 |
| GKE (k8s) | e0a1b67 | deploy.sh | 2 | 8 | 8 | Yes | 1226.945 | 9.237 | 2.618 |
| GKE (k8s, TCPXO) | fd9c38f | jobset.yml | 2 | 8 | 8 | Yes | 958.822 | 9.423 | 1.298 |
| GKE (k8s, TCPXO) | 771f97d | jobset.yml | 2 | 8 | 8 | No | 70.989 | 9.317 | 1.202 |
| viking-prod (slurm) | 03c29c6 | example-transfer-multinode.sh | 2 | 8 | 8 | Yes | 83.273 | 8.279 | 10.496 |
| eos (slurm) | 03c29c6 | example-transfer-multinode.sh | 2 | 8 | 8 | No | 68.618 | 8.133 | 0.1876 |

GRPO

| Cluster | commit | job definition | Nodes | vLLM GPUs | JAX GPUs | handshake (s) | rollout (s) | training (s) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| eos (slurm) | e0a1b67 | example-grpo-multinode.sh | 1 | 4 | 4 | 9.7 | 3.6 | 65.2 |
| GKE (k8s, TCPXO) | 2bd3206 | jobset.yaml | 2 | 8 | 8 | 40.5 | 10.1 | 98.3 |
| eos (slurm) | e0a1b67 | example-grpo-multinode.sh | 2 | 8 | 8 | - | - | - |

n.b.

  • All devices are H100; the model used is meta-llama/Llama-3.1-8B-Instruct.
  • The 2-node Kubernetes examples use the multiprocessing backend instead of the ray backend for vLLM, since inference runs entirely on one node.
  • A note on GKE performance: our workflow for deploying the transfer and GRPO workloads is based on xpk, as with the existing MaxText and NCCL tests run on GKE. With this setup, NCCL tests run on the same cluster achieve the bus bandwidth stated in the GCP documentation for maximizing GPU network bandwidth (all-gather of 16GB at 189GB/s).
  • For GRPO, the train micro-batch size is reduced from 8 to 2 due to memory constraints on H100 GPUs.
  • #1799 ([jax-inference-offloading] Upgrade vLLM to 0.11.2) is merged into this branch to make use of the working image build.
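As a back-of-envelope check on the transfer column above: Llama-3.1-8B has roughly 8.0e9 parameters, and assuming the weights move as bf16 (2 bytes per parameter; the dtype is an assumption, not stated above), a full transfer moves about 16 GB. A quick sketch of the implied effective bandwidth:

```python
# Back-of-envelope effective bandwidth for the weight transfer.
# Assumptions (not from the PR): ~8.0e9 parameters, bf16 (2 bytes each).
PARAMS = 8.0e9
BYTES_PER_PARAM = 2
model_bytes = PARAMS * BYTES_PER_PARAM  # ~16 GB of weights

# Best per-run transfer times from the table above, in seconds.
for label, seconds in [
    ("GKE (k8s, TCPXO), 2-node 8:8", 1.202),
    ("eos (slurm), 2-node 8:8", 0.1876),
]:
    print(f"{label}: ~{model_bytes / seconds / 1e9:.1f} GB/s")
```

Under these assumptions the TCPXO run works out to roughly 13 GB/s and the eos run to roughly 85 GB/s, which is a useful sanity check against the NCCL bus-bandwidth numbers in the note above.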

CI workflow

The CI workflow added in this PR builds the amd64 and arm64 images, which are then used to run the transfer and GRPO k8s recipe workloads on GKE.

The workloads are created on the cluster with the xpk toolkit, which creates a JobSet resource. This is done via the xpk-gke composite action already used for the NCCL and MaxText workloads on GKE.

For the transfer recipe, meta-llama/Llama-3.1-8B-Instruct and meta-llama/Llama-3.1-70B-Instruct are run. For the GRPO recipe, only meta-llama/Llama-3.1-8B-Instruct runs due to memory constraints on the particular GPU in the GKE cluster.
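For reference, launching one of these workloads via xpk might look like the following. This is an illustrative sketch only: the cluster name, image, workload name, and entrypoint are hypothetical, and the flag spellings should be verified against `xpk workload create --help` for the xpk version pinned in CI.

```shell
# Illustrative only: constructing an xpk launch command for a transfer workload.
# All names below are hypothetical; flag spellings are assumptions.
CLUSTER=my-gke-cluster                        # hypothetical cluster name
IMAGE=us-docker.example.dev/jax-vllm:latest   # hypothetical image reference
xpk_cmd="xpk workload create \
  --cluster=${CLUSTER} \
  --workload=transfer-llama-3-1-8b \
  --num-nodes=2 \
  --docker-image=${IMAGE} \
  --command='python examples/transfer.py'"    # hypothetical entrypoint
echo "${xpk_cmd}"
```

In CI the equivalent invocation is wrapped by the xpk-gke composite action rather than issued by hand.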

example run

Appendix: Detailed results

Transfer

Single-node 2:2 (a3-megagpu-8g - H100)
Timer summary (tree view), seconds
-------------------------------------
create_bridge              : 1145.220
  handshake                : 1145.220
jax-distributed-initialize :    1.007
load_model                 :    3.617
transport                  :   37.124
  run0                     :    7.511
    to_named_parameters    :    0.000
    transfer               :    7.511
  run1                     :    7.358
    to_named_parameters    :    0.000
    transfer               :    7.358
  run2                     :    7.445
    to_named_parameters    :    0.000
    transfer               :    7.445
  run3                     :    7.375
    to_named_parameters    :    0.000
    transfer               :    7.375
  run4                     :    7.434
    to_named_parameters    :    0.000
    transfer               :    7.434
warmup                     :    7.861
  to_named_parameters      :    0.000
  transfer                 :    7.861

Single-node 4:4 (a3-megagpu-8g - H100)
Timer summary (tree view), seconds
-------------------------------------
create_bridge              : 1058.517
  handshake                : 1058.517
jax-distributed-initialize :    1.882
load_model                 :    6.129
transport                  :   12.210
  run0                     :    2.463
    to_named_parameters    :    0.000
    transfer               :    2.463
  run1                     :    2.406
    to_named_parameters    :    0.000
    transfer               :    2.406
  run2                     :    2.435
    to_named_parameters    :    0.000
    transfer               :    2.435
  run3                     :    2.456
    to_named_parameters    :    0.000
    transfer               :    2.456
  run4                     :    2.450
    to_named_parameters    :    0.000
    transfer               :    2.450
warmup                     :    4.895
  to_named_parameters      :    0.000
  transfer                 :    4.894
2-node 8:8 (a3-megagpu-8g - H100)

As at e0a1b67

Timer summary (tree view), seconds
-------------------------------------
create_bridge              : 1226.945
  handshake                : 1226.945
jax-distributed-initialize :    4.011
load_model                 :    9.237
transport                  :   13.091
  run0                     :    2.629
    to_named_parameters    :    0.000
    transfer               :    2.629
  run1                     :    2.621
    to_named_parameters    :    0.000
    transfer               :    2.620
  run2                     :    2.625
    to_named_parameters    :    0.000
    transfer               :    2.625
  run3                     :    2.602
    to_named_parameters    :    0.000
    transfer               :    2.602
  run4                     :    2.615
    to_named_parameters    :    0.000
    transfer               :    2.615
warmup                     :   10.731
  to_named_parameters      :    0.000
  transfer                 :   10.731
2-node 8:8 JobSet with TCPXO plugin enabled (a3-megagpu-8g - H100)

JobSet as at fd9c38f

Timer summary (tree view), seconds
------------------------------------
create_bridge              : 958.822
  handshake                : 958.822
jax-distributed-initialize :   1.914
load_model                 :   9.423
transport                  :   6.488
  run0                     :   1.325
    to_named_parameters    :   0.000
    transfer               :   1.325
  run1                     :   1.275
    to_named_parameters    :   0.000
    transfer               :   1.275
  run2                     :   1.316
    to_named_parameters    :   0.000
    transfer               :   1.316
  run3                     :   1.260
    to_named_parameters    :   0.000
    transfer               :   1.260
  run4                     :   1.312
    to_named_parameters    :   0.000
    transfer               :   1.312
warmup                     :   8.774
  to_named_parameters      :   0.000
  transfer                 :   8.774
2-node 8:8 JobSet with TCPXO plugin no debug (a3-megagpu-8g - H100)

JobSet as at 771f97d

Timer summary (tree view), seconds
-----------------------------------
create_bridge              : 70.989
  handshake                : 70.989
jax-distributed-initialize :  3.541
load_model                 :  9.317
transport                  :  6.011
  run0                     :  1.175
    to_named_parameters    :  0.000
    transfer               :  1.175
  run1                     :  1.186
    to_named_parameters    :  0.000
    transfer               :  1.186
  run2                     :  1.248
    to_named_parameters    :  0.000
    transfer               :  1.248
  run3                     :  1.155
    to_named_parameters    :  0.000
    transfer               :  1.155
  run4                     :  1.248
    to_named_parameters    :  0.000
    transfer               :  1.248
warmup                     : 44.594
  to_named_parameters      :  0.000
  transfer                 : 44.594
2-node 8:8 slurm (viking-prod - H100)
0: Timer summary (tree view), seconds
0: -----------------------------------
0: create_bridge              : 83.273
0:   handshake                : 83.273
0: jax-distributed-initialize :  5.697
0: load_model                 :  8.279
0: transport                  : 52.479
0:   run0                     : 10.495
0:     to_named_parameters    :  0.000
0:     transfer               : 10.495
0:   run1                     : 10.489
0:     to_named_parameters    :  0.000
0:     transfer               : 10.489
0:   run2                     : 10.503
0:     to_named_parameters    :  0.000
0:     transfer               : 10.503
0:   run3                     : 10.485
0:     to_named_parameters    :  0.000
0:     transfer               : 10.485
0:   run4                     : 10.507
0:     to_named_parameters    :  0.000
0:     transfer               : 10.507
0: warmup                     : 14.261
0:   to_named_parameters      :  0.000
0:   transfer                 : 14.261
2-node 8:8 slurm no debug (eos - H100)
0: Timer summary (tree view), seconds
0: -----------------------------------
0: create_bridge              : 68.618
0:   handshake                : 68.618
0: jax-distributed-initialize :  5.833
0: load_model                 :  8.133
0: transport                  :  0.938
0:   run0                     :  0.190
0:     to_named_parameters    :  0.000
0:     transfer               :  0.190
0:   run1                     :  0.187
0:     to_named_parameters    :  0.000
0:     transfer               :  0.187
0:   run2                     :  0.189
0:     to_named_parameters    :  0.000
0:     transfer               :  0.189
0:   run3                     :  0.185
0:     to_named_parameters    :  0.000
0:     transfer               :  0.185
0:   run4                     :  0.187
0:     to_named_parameters    :  0.000
0:     transfer               :  0.187

GRPO

2-node 8:8 JobSet with TCPXO plugin enabled (a3-megagpu-8g - H100)
Timer summary (tree view), seconds
--------------------------------
handshake                 : 40.5
load_checkpoint           : 15.3
  model                   : 15.1
  tokenizer               :  0.3
load_dataset              :  2.3
rollout                   : 10.1
  update_params           : 10.1
    to_named_parameters   :  0.0
    transfer              : 10.1
training                  : 98.3
  rollout                 : 48.9
    generate              : 42.7
      inference           : 42.7
      process_outputs     :  0.0
    update_params         :  6.3
      to_named_parameters :  0.0
      transfer            :  6.3
Single-node 4:4 slurm (eos - H100)
Timer summary (tree view), seconds
--------------------------------
handshake                 :  9.7
load_checkpoint           : 10.8
  model                   : 10.6
  tokenizer               :  0.2
load_dataset              :  2.0
rollout                   :  3.6
  update_params           :  3.6
    to_named_parameters   :  0.0
    transfer              :  3.6
training                  : 65.2
  rollout                 : 19.4
    generate              : 19.1
      inference           : 19.1
      process_outputs     :  0.0
    update_params         :  0.3
      to_named_parameters :  0.0
      transfer            :  0.3
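In these tree views, some parent entries are the sum of their children; for instance, in the eos single-node run above, rollout (19.4) = generate (19.1) + update_params (0.3). A minimal sketch of that roll-up, illustrative only (the PR's actual timing helper is not shown here, and these names are made up):

```python
# Illustrative roll-up of a nested timer tree like the summaries above.
# (Not the PR's actual timing code; structure and names are assumptions.)
rollout = {"generate": 19.1, "update_params": 0.3}

def total(node):
    """Recursively sum leaf durations in a nested timer dict."""
    if isinstance(node, dict):
        return sum(total(child) for child in node.values())
    return node

print(f"rollout: {total(rollout):.1f}s")  # matches the 19.4 reported above
```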

- bash
- -c
- |
  pip install jax[k8s]
Contributor commented:

Seems that this is meant to bring in the k8s extra?

Member (author) replied:

Yes: without it, the processes across the pods will not discover each other in `jax.distributed.initialize()`. I added the dependency to the Dockerfile.
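For reference, the fix amounts to installing the extra in the image. A minimal sketch of the Dockerfile line (the actual Dockerfile change in this PR is not shown here):

```dockerfile
# Bring in JAX's Kubernetes-aware coordinator discovery so that
# jax.distributed.initialize() can auto-detect peers across pods.
RUN pip install "jax[k8s]"
```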

@@ -0,0 +1,266 @@
apiVersion: jobset.x-k8s.io/v1alpha2
Contributor commented:

Why do we have both the jobset specs and the separate deployment specs? Are they meant for different test scenarios?

@aybchan (Member, author) replied Nov 26, 2025:

The manifests in transfer/deployment are from an older recipe; jobset.yml is the latest working recipe I am working from, which includes use of the GPUDirect RDMA plugin.

@aybchan aybchan changed the title JAX-vLLM Offloading k8s/GKE manifest JAX-vLLM Offloading k8s (GKE) Nov 26, 2025