Skip to content

Commit 9bcc434

Browse files
committed
Update GRPO workflow
1 parent 2bd3206 commit 9bcc434

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

.github/actions/gke-xpk/action.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,9 @@ runs:
192192
args+=(
193193
--env-file="${{ inputs.ENV_FILE }}"
194194
)
195+
196+
echo "Setting the following environment variables in the ${JOBSET_NAME} JobSet from ${{ inputs.ENV_FILE }}"
197+
cat ${{ inputs.ENV_FILE }}
195198
else
196199
args+=(
197200
--env="JAX_COORDINATOR_PORT=3389"

.github/workflows/jax-vllm-offloading-gke-grpo.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ jobs:
1414
runs-on: gke-a3mega
1515
strategy:
1616
matrix:
17-
model: ["meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.1-70B-Instruct"]
17+
model: ["meta-llama/Llama-3.1-8B-Instruct"]
1818
env:
1919
WORKLOAD_NAME_PREPREFIX: gke-jax-vllm-grpo
2020
JAX_VLLM_OFFLOADING_IMAGE: ${{ inputs.JAX_VLLM_OFFLOADING_IMAGE }}
2121

2222
NUM_NODES: 2
23-
ENV_FILE: ../../.github/gke-workflow/jax-vllm-offloading/transfer/jobset.env
23+
ENV_FILE: ../../.github/gke-workflow/jax-vllm-offloading/grpo/jobset.env
2424

2525
steps:
2626
- uses: actions/checkout@v4
@@ -69,8 +69,10 @@ jobs:
6969
python rollout.py 2>&1 | tee -a rollout.log &
7070
PIDS+=(\$!);
7171
else
72-
echo Starting trainer;
73-
python trainer.py 2>&1 | tee -a trainer.log &
72+
export MODEL_PATH=\$(python download_model.py --hub=hf --model=\${MODEL_NAME} --ignore='*.pth');
73+
74+
echo Starting GRPO trainer;
75+
python trainer_grpo.py 2>&1 | tee -a trainer_grpo.log &
7476
PIDS+=(\$!);
7577
fi;
7678

0 commit comments

Comments
 (0)