JAX-vLLM Offloading k8s (GKE) #1797
base: main
Conversation
Force-pushed b45cada → e0a1b67
Force-pushed 1b95e99 → a3f36ff
Resolved review threads:
- .github/gke-workflow/jax-vllm-offloading/transfer/deployment/rollout.yml (3 threads)
- .github/gke-workflow/jax-vllm-offloading/transfer/deployment/trainer.yml
```yaml
- bash
- -c
- |
  pip install jax[k8s]
```
Seems that this is meant to bring in the k8s extra?
Yes; without it, the processes across the pods will not discover each other in `jax.distributed.initialize()`. I added the dependency to the Dockerfile.
```diff
@@ -0,0 +1,266 @@
+apiVersion: jobset.x-k8s.io/v1alpha2
```
Why do we have both the jobset specs and the separate deployment specs? Are they meant for different test scenarios?
The manifests in transfer/deployment are from an older recipe; jobset.yml is the latest working recipe I'm working from, which includes the GPUDirect RDMA plugin.
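For readers following this thread, a rough skeleton of the JobSet shape under discussion (all names, images, and counts below are illustrative placeholders; the actual jobset.yml in this PR differs and additionally configures the GPUDirect RDMA plugin):

```yaml
# Illustrative skeleton only, not the manifest from this PR.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: jax-vllm-offloading   # hypothetical name
spec:
  replicatedJobs:
    - name: workers
      replicas: 1
      template:
        spec:
          parallelism: 2       # one pod per node
          completions: 2
          template:
            spec:
              containers:
                - name: worker
                  image: example.com/transfer:latest  # placeholder image
                  resources:
                    limits:
                      nvidia.com/gpu: 8
```

A JobSet (as opposed to a plain Deployment) gives the pods stable, indexed identities and a shared headless service, which is what cross-pod discovery in `jax.distributed.initialize()` relies on.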
Force-pushed 64c2abc → aeb70b0
Force-pushed 4742877 → 4730a76
Force-pushed 4730a76 → b75f09f
Force-pushed 4560bd6 → 7e5e1ce
Force-pushed 9bcc434 → 18950d7
JAX-vLLM offloading transfer and GRPO examples on Kubernetes
Benchmark results
Transfer
Key files and stages: `create_bridge`, `load_model`, `transfer`, `deploy.sh`, `jobset.yml`, `example-transfer-multinode.sh`
GRPO
Key files and stages: `handshake`, `rollout`, `training`, `jobset.yaml`, `example-grpo-multinode.sh` (n.b. `meta-llama/Llama-3.1-8B-Instruct`)
CI workflow
The CI workflow added in this PR handles the building of `amd64` and `arm64` images, which are then used to run the transfer and GRPO k8s recipe workloads on GKE. The workloads are created on the cluster with the xpk toolkit, which creates a JobSet resource. This is done using the xpk-gke composite action that is already used for NCCL and MaxText workloads on GKE.
For the transfer recipe, `meta-llama/Llama-3.1-8B-Instruct` and `meta-llama/Llama-3.1-70B-Instruct` are run. For the GRPO recipe, only `meta-llama/Llama-3.1-8B-Instruct` runs due to memory constraints on the particular GPU in the GKE cluster.
Example run
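To illustrate how a workload of this shape might be submitted with xpk, a hypothetical invocation is sketched below. The cluster, workload, and image names are placeholders, and the exact flags may differ by xpk version and from what the xpk-gke composite action actually passes:

```shell
# Hypothetical xpk invocation: every name here is a placeholder, not
# the values used by the CI workflow in this PR.
xpk workload create \
  --cluster my-gke-cluster \
  --workload jax-vllm-transfer \
  --docker-image us-docker.pkg.dev/my-project/my-repo/transfer:latest \
  --device-type h100-mega-80gb-8 \
  --num-nodes 2 \
  --command "bash example-transfer-multinode.sh"
```

Under the hood, xpk translates this into a JobSet resource on the target cluster, which is why the recipe's jobset.yml and the CI-submitted workloads end up with the same pod topology.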
Appendix: Detailed results
Transfer
Single-node 2:2 (a3-megagpu-8g - H100)
Single-node 4:4 (a3-megagpu-8g - H100)
2-node 8:8 (a3-megagpu-8g - H100)
As at e0a1b67
2-node 8:8 JobSet with TCPXO plugin enabled (a3-megagpu-8g - H100)
JobSet as at fd9c38f
2-node 8:8 JobSet with TCPXO plugin no debug (a3-megagpu-8g - H100)
JobSet as at fd9c38f
2-node 8:8 slurm (viking-prod - H100)
2-node 8:8 slurm no debug (eos - H100)
GRPO
2-node 8:8 JobSet with TCPXO plugin enabled (a3-megagpu-8g - H100)
Single-node 4:4 slurm (eos - H100)