Changes from all commits

46 commits
6db8716
Add k8s JAX-vLLM offloading example
aybchan Nov 24, 2025
b398510
Update gateway URL
aybchan Nov 24, 2025
e0a1b67
Add two-node manifest
aybchan Nov 24, 2025
485e2a3
Add 8:8 logs
aybchan Nov 24, 2025
fd9c38f
Add hard-coded 2x node jobset example
aybchan Nov 25, 2025
0000997
patch vLLM weight loader
yhtang Nov 25, 2025
82fe0ce
bump tunix version
yhtang Nov 25, 2025
a3f36ff
Add jax[k8s] extras to install
aybchan Nov 25, 2025
e069f1b
Organize deployment manifests
aybchan Nov 26, 2025
aa926cb
Set missing env. vars
aybchan Nov 26, 2025
5591ce4
address PR comments
yhtang Nov 26, 2025
d45fa3a
address PR comments
yhtang Nov 26, 2025
771f97d
Remove debug trace
aybchan Nov 26, 2025
db4861b
Add JAX-vLLM workflow
aybchan Nov 26, 2025
20802b8
Fix JobSet command
aybchan Nov 26, 2025
794ff86
Add xpk patch, update env file, patch composite action
aybchan Nov 27, 2025
bc9d877
Enable image pull secret set
aybchan Nov 27, 2025
b64763f
Set jobset dot env path
aybchan Nov 27, 2025
f8cd259
Refactor CI workflows
aybchan Nov 27, 2025
b75f09f
Fix workflow
aybchan Nov 27, 2025
ed9d8b0
Fix workflow
aybchan Nov 27, 2025
8c85fe7
Fix workflow
aybchan Nov 27, 2025
19391d6
Merge branch 'yhtang/vllm-0.11-bump' into aybchan/jax-vllm-offloading…
aybchan Nov 27, 2025
7e5e1ce
Add build to pipeline
aybchan Nov 27, 2025
2ab8e10
Set test build
aybchan Nov 27, 2025
77d6c82
Disable arm64 build for now
aybchan Nov 27, 2025
78800df
Update deprecated variable
aybchan Nov 28, 2025
a081257
revert change related to a vllm tokenizer load failure that is no lon…
yhtang Nov 28, 2025
8b3440f
remove tunix version pin
yhtang Nov 28, 2025
4372984
revert JAX version bump
yhtang Nov 28, 2025
b8c7b9c
Merge branch 'yhtang/vllm-0.11-bump' into aybchan/jax-vllm-offloading…
aybchan Nov 28, 2025
989cb24
Merge workflow definitions
aybchan Nov 28, 2025
e07922c
Set run 70B transfer
aybchan Nov 28, 2025
0cc6920
Set arm64 and amd64 build separately
aybchan Nov 28, 2025
bd49355
Fix artifact name collision
aybchan Nov 28, 2025
ffa6706
Update keyword name due to tunix 974da5
aybchan Nov 28, 2025
f111e2f
Make xpk composite action changes backwards compatible
aybchan Nov 28, 2025
2bd3206
Add working k8s GRPO recipe
aybchan Nov 28, 2025
18950d7
Update GRPO workflow
aybchan Nov 28, 2025
df32d3b
Fix workflow
aybchan Nov 28, 2025
bdaafd0
Fix inline command
aybchan Nov 28, 2025
aba0842
Remove debug logs
aybchan Nov 28, 2025
4f7c0ec
Set consume step output
aybchan Nov 28, 2025
a1f5a4b
Set workload name strictly lower case
aybchan Nov 28, 2025
9c4b2bd
Handle invalid jobset name
aybchan Nov 29, 2025
db46141
Remove unnecessary unbound variable
aybchan Nov 30, 2025
20 changes: 19 additions & 1 deletion .github/actions/gke-xpk/action.yml
@@ -57,6 +57,11 @@ inputs:
required: false
default: 'nvidia-smi; free -h;'
type: string
ENV_FILE:
description: 'Path to an environment variable file passed to xpk to set variables in the JobSet'
required: false
default: ''
type: string
EXIT_COMMAND:
description: 'Command to set exit code'
required: false
@@ -178,11 +183,24 @@ runs:
}

if version_greater "${{ inputs.XPK_VERSION }}" "v0.10.0"; then
args+=(
--docker-image-pull-secret=${{ inputs.IMAGE_PULL_SECRET_NAME }}
)

# --env is incompatible with --env-var in xpk
if [ -e "${{ inputs.ENV_FILE }}" ]; then
args+=(
--env-file="${{ inputs.ENV_FILE }}"
)

echo "Setting the following environment variables in the ${WORKLOAD_NAME} JobSet from the env. file at ${{ inputs.ENV_FILE }} "
cat ${{ inputs.ENV_FILE }}
else
args+=(
--docker-image-pull-secret=${{ inputs.IMAGE_PULL_SECRET_NAME }}
--env="JAX_COORDINATOR_PORT=3389"
--env="JAX_COORDINATOR_ADDRESS=\$(JOBSET_NAME)-\$(REPLICATED_JOB_NAME)-0-0.\$(JOBSET_NAME):3389"
)
fi
fi

python xpk.py workload create \
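
The new --env-file path is forwarded to xpk; a minimal sketch of such a file, assuming xpk consumes dotenv-style KEY=VALUE lines (the full file this PR ships is grpo/jobset.env below):

JAX_COORDINATOR_PORT=3389
GATEWAY_PORT=50051
MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct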
7 changes: 7 additions & 0 deletions .github/gke-workflow/jax-vllm-offloading/deploy-transfer.sh
@@ -0,0 +1,7 @@
kubectl apply -f transfer/deployment/gateway-pod.yml
kubectl apply -f transfer/deployment/gateway-svc.yml

kubectl apply -f huggingface-secret.yml

kubectl apply -f transfer/deployment/rollout.yml
kubectl apply -f transfer/deployment/trainer.yml
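
A matching teardown, assuming the same manifests (the delete order simply reverses the apply order; this script is not part of the PR):

kubectl delete -f transfer/deployment/trainer.yml -f transfer/deployment/rollout.yml
kubectl delete -f huggingface-secret.yml
kubectl delete -f transfer/deployment/gateway-svc.yml -f transfer/deployment/gateway-pod.yml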
22 changes: 22 additions & 0 deletions .github/gke-workflow/jax-vllm-offloading/grpo/jobset.env
@@ -0,0 +1,22 @@
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
CUDA_DEVICE_ORDER=PCI_BUS_ID
CUDA_DEVICE_MAX_CONNECTIONS=16
VLLM_ENFORCE_EAGER=1
VLLM_GPU_MEMORY_UTILIZATION=0.7
VLLM_TENSOR_PARALLEL_SIZE=8
VLLM_DISTRIBUTED_BACKEND=mp
VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1
VLLM_LOAD_FORMAT=dummy
NCCL_NET_PLUGIN=/opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so
NCCL_TUNER_PLUGIN=none
MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct
NCCL_CUMEM_ENABLE=0
NCCL_BUFFSIZE=16777216
XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL --xla_gpu_collective_permute_combine_threshold_bytes=8589934592 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_all_reduce_combine_threshold_bytes=8589934592
TRANSFER_MODE=grouped
USE_POLYMORPHIC_MESH=0
JAX_COORDINATOR_PORT=3389
JAX_COORDINATOR_ADDRESS=$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME):$(JAX_COORDINATOR_PORT)
GATEWAY_PORT=50051
GATEWAY_URL=$(JOBSET_NAME):$(GATEWAY_PORT)
OUTPUT_DIR=/opt/output
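
Note that the $(JOBSET_NAME)-style references above are not shell expansions but Kubernetes dependent environment variables, resolved by the kubelet when the container starts; a variable may only reference ones defined before it, which is why JAX_COORDINATOR_PORT precedes JAX_COORDINATOR_ADDRESS (this assumes xpk preserves the file order when building the container env). A worked expansion, using the JobSet below (name jax-vllm-grpo, replicated job slice-job):

JAX_COORDINATOR_ADDRESS=jax-vllm-grpo-slice-job-0-0.jax-vllm-grpo:3389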
280 changes: 280 additions & 0 deletions .github/gke-workflow/jax-vllm-offloading/grpo/jobset.yaml
@@ -0,0 +1,280 @@
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
annotations:
name: jax-vllm-grpo
namespace: default
spec:
network:
enableDNSHostnames: true
publishNotReadyAddresses: true
replicatedJobs:
- name: slice-job
replicas: 1
template:
metadata: {}
spec:
backoffLimit: 0
completionMode: Indexed
completions: 2
parallelism: 2
template:
metadata:
annotations:
devices.gke.io/container.tcpxo-daemon: |
- path: /dev/nvidia0
- path: /dev/nvidia1
- path: /dev/nvidia2
- path: /dev/nvidia3
- path: /dev/nvidia4
- path: /dev/nvidia5
- path: /dev/nvidia6
- path: /dev/nvidia7
- path: /dev/nvidiactl
- path: /dev/nvidia-uvm
- path: /dev/dmabuf_import_helper
networking.gke.io/default-interface: eth0
networking.gke.io/interfaces: |-
[
{"interfaceName":"eth0","network":"default"},
{"interfaceName":"eth1","network":"jtb-2025-10-07-gpunet-0-subnet"},
{"interfaceName":"eth2","network":"jtb-2025-10-07-gpunet-1-subnet"},
{"interfaceName":"eth3","network":"jtb-2025-10-07-gpunet-2-subnet"},
{"interfaceName":"eth4","network":"jtb-2025-10-07-gpunet-3-subnet"},
{"interfaceName":"eth5","network":"jtb-2025-10-07-gpunet-4-subnet"},
{"interfaceName":"eth6","network":"jtb-2025-10-07-gpunet-5-subnet"},
{"interfaceName":"eth7","network":"jtb-2025-10-07-gpunet-6-subnet"},
{"interfaceName":"eth8","network":"jtb-2025-10-07-gpunet-7-subnet"}
]
spec:
imagePullSecrets:
- name: jax-toolbox-ghcr
containers:
- name: gpu-image
image: ghcr.io/nvidia/jax-toolbox-internal:19751502075-jio-amd64
imagePullPolicy: Always
command:
- bash
- -c
- |
pip install "jax[k8s]"
python -c "
import jax
jax.distributed.initialize()
print(jax.devices())
print(jax.local_devices())
assert jax.process_count() > 1
assert len(jax.devices()) > len(jax.local_devices())"

PIDS=()
# hard-code split of vLLM-JAX on 1x node each on 2x slice jobset
if [ "${NODE_RANK}" = "0" ]; then
echo "Starting gateway"
cd /opt/jtbx/jax-inference-offloading
python jax_inference_offloading/controller/gateway.py 2>&1 | tee -a gateway.log &
PIDS+=($!)

echo "Starting rollout"
cd /opt/jtbx/jax-inference-offloading/examples
python rollout.py 2>&1 | tee -a rollout.log &
PIDS+=($!)
else
echo "Starting trainer"
export MODEL_PATH=$(python "download_model.py" --hub=hf --model=${MODEL_NAME} --ignore="*.pth")
python trainer_grpo.py 2>&1 | tee -a trainer_grpo.log &
PIDS+=($!)
fi

wait "${PIDS[@]}"
echo "All done"
env:
# jobset
- name: REPLICATED_JOB_NAME
valueFrom:
fieldRef:
fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
- name: JOBSET_NAME
valueFrom:
fieldRef:
fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
- name: NODE_RANK
valueFrom:
fieldRef:
fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
- name: USE_GPUDIRECT
value: tcpxo
- name: GPUS_PER_NODE
value: "8"

- name: LD_LIBRARY_PATH
value: "/usr/lib/x86_64-linux-gnu:/usr/local/cuda-12.9/compat/lib.real:/usr/local/nvidia/lib64"

# huggingface
- name: HF_TOKEN
valueFrom:
secretKeyRef:
name: hf-token-secret
key: token
- name: MODEL_NAME
value: "meta-llama/Llama-3.1-8B-Instruct"
- name: SCRATCHDIR
value: "/opt/scratch"

# gateway
- name: GATEWAY_PORT
value: "50051"
- name: GATEWAY_URL
value: "$(JOBSET_NAME):$(GATEWAY_PORT)"

# JAX
- name: JAX_COORDINATOR_PORT
value: "3389"
- name: JAX_COORDINATOR_ADDRESS
value: $(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME):3389

# CUDA
- name: CUDA_VISIBLE_DEVICES
value: "0,1,2,3,4,5,6,7"
- name: CUDA_DEVICE_ORDER
value: "PCI_BUS_ID"
- name: CUDA_DEVICE_MAX_CONNECTIONS
value: "16"

# vLLM
- name: VLLM_ENFORCE_EAGER
value: "1"
- name: VLLM_GPU_MEMORY_UTILIZATION
value: "0.7"
- name: VLLM_TENSOR_PARALLEL_SIZE
value: "8"
- name: VLLM_DISTRIBUTED_BACKEND
value: "mp"
- name: VLLM_ATTENTION_BACKEND
value: "TRITON_ATTN"
- name: VLLM_LOAD_FORMAT
value: "dummy"

# NCCL
- name: NCCL_NET_PLUGIN
value: "/opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so"
- name: NCCL_TUNER_PLUGIN
value: "none"
- name: NCCL_FASTRAK_LLCM_DEVICE_DIRECTORY
value: /dev/aperture_devices
- name: NCCL_CUMEM_ENABLE
value: "0" # https://docs.vllm.ai/en/v0.9.1/usage/troubleshooting.html#known-issues
- name: NCCL_BUFFSIZE
value: "16777216"

# XLA
- name: XLA_PYTHON_CLIENT_MEM_FRACTION
value: "0.95"
- name: XLA_FLAGS
value: "--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL
--xla_gpu_collective_permute_combine_threshold_bytes=8589934592
--xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592
--xla_gpu_all_gather_combine_threshold_bytes=8589934592
--xla_gpu_all_reduce_combine_threshold_bytes=8589934592"

# trainer
- name: TRANSFER_MODE
value: "grouped"
- name: USE_POLYMORPHIC_MESH
value: "0"
- name: JAX_COMPILATION_CACHE_DIR
value: /opt/jax-compilation
- name: JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS
value: "0.1"
- name: RUN_MODE
value: "timing"
- name: ROLLOUT_ENGINE
value: "vllm_gpu"
- name: GRPO_TRAIN_MICRO_BATCH_SIZE
value: "2"


ports:
- containerPort: 50051
protocol: TCP
- containerPort: 3389
protocol: TCP
resources:
limits:
nvidia.com/gpu: "8"
securityContext:
privileged: true
volumeMounts:
- mountPath: /dev/aperture_devices
name: aperture-devices
- mountPath: /usr/local/nvidia
name: libraries
- mountPath: /dev/shm
name: dshm
- mountPath: /opt/scratch
name: scratch
dnsPolicy: ClusterFirstWithHostNet
initContainers:
- args:
- |-
set -ex
chmod 755 /fts/entrypoint_rxdm_container.sh
/fts/entrypoint_rxdm_container.sh --num_hops=2 --num_nics=8 --uid= --alsologtostderr
command:
- /bin/sh
- -c
env:
- name: LD_LIBRARY_PATH
value: /usr/local/nvidia/lib64
image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/tcpgpudmarxd-dev:v1.0.12
imagePullPolicy: Always
name: tcpxo-daemon
resources: {}
restartPolicy: Always
securityContext:
capabilities:
add:
- NET_ADMIN
- NET_BIND_SERVICE
volumeMounts:
- mountPath: /usr/local/nvidia
name: libraries
- mountPath: /hostsysfs
name: sys
- mountPath: /hostprocsysfs
name: proc-sys
nodeSelector:
cloud.google.com/gke-accelerator: nvidia-h100-mega-80gb
priorityClassName: high
terminationGracePeriodSeconds: 30
tolerations:
- key: nvidia.com/gpu
operator: Exists
- effect: NoSchedule
key: user-workload
operator: Equal
value: "true"
volumes:
- hostPath:
path: /home/kubernetes/bin/nvidia
name: libraries
- hostPath:
path: /sys
name: sys
- hostPath:
path: /proc/sys
name: proc-sys
- hostPath:
path: /dev/aperture_devices
name: aperture-devices
- emptyDir:
medium: Memory
name: dshm
- emptyDir:
sizeLimit: 2Gi
name: scratch
startupPolicy:
startupPolicyOrder: AnyOrder
successPolicy:
operator: All
ttlSecondsAfterFinished: 100000
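
A minimal sketch of launching and watching this JobSet directly, assuming the JobSet controller and the jax-toolbox-ghcr pull secret are already installed on the cluster:

kubectl apply -f grpo/jobset.yaml
kubectl get pods -l jobset.sigs.k8s.io/jobset-name=jax-vllm-grpo -w
kubectl logs -f -l jobset.sigs.k8s.io/jobset-name=jax-vllm-grpo --prefix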
8 changes: 8 additions & 0 deletions .github/gke-workflow/jax-vllm-offloading/huggingface-secret.yml
@@ -0,0 +1,8 @@
apiVersion: v1
kind: Secret
metadata:
name: hf-token-secret
namespace: default
type: Opaque
stringData:
token: {{ HF_TOKEN }}
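
The {{ HF_TOKEN }} placeholder must be rendered before the manifest is applied; a minimal sketch, assuming a plain sed substitution (the PR does not show how CI renders this template):

sed "s|{{ HF_TOKEN }}|${HF_TOKEN}|" huggingface-secret.yml | kubectl apply -f -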
38 changes: 38 additions & 0 deletions .github/gke-workflow/jax-vllm-offloading/transfer/deployment/gateway-pod.yml
@@ -0,0 +1,38 @@
apiVersion: v1
kind: Pod
metadata:
name: jax-vllm-gateway
namespace: default
labels:
app: jax-vllm-gateway
spec:
imagePullSecrets:
- name: jax-toolbox-ghcr
containers:
- name: jax-vllm-gateway-server
image: ghcr.io/nvidia/jax-toolbox-internal:19461214142-jio-amd64
workingDir: /opt/jtbx/jax-inference-offloading
command: ["python", "jax_inference_offloading/controller/gateway.py"]
volumeMounts:
- mountPath: /dev/shm
name: shmem
env:
- name: GATEWAY_PORT
value: "50051"
ports:
- containerPort: 50051

volumes:
- name: output
emptyDir: {}
- name: shmem
emptyDir:
medium: Memory

# schedule on GPU node (but don't request GPU resource)
tolerations:
- key: "nvidia.com/gpu"
operator: "Exists"
effect: "NoSchedule"

