diff --git a/examples/llama/README.md b/examples/llama/README.md index 49416ad6633..81dc94b098d 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -162,6 +162,12 @@ TEE_OUTPUT=1 MBS=2 BS=16 TP=1 TE_FP8=0 FSDP=1 RECOMPUTE=1 SEQ_LENGTH=8192 MODEL_ And FSDP-v2 is not supported with pipeline parallelism, expert parallelism, MCore's distributed optimizer, gradient accumulation fusion and fp16. +To run with Megatron-LM HSDP enabled (Hybrid Sharded Data Parallel), use `ENABLE_HSDP=1` together with `MEGATRON_FSDP=1`, +and set the number of replicas with `HSDP_NUM_REPLICAS>=2`. For example: +```bash +TEE_OUTPUT=1 MBS=2 BS=16 TP=1 TE_FP8=0 MEGATRON_FSDP=1 ENABLE_HSDP=1 HSDP_NUM_REPLICAS=2 SEQ_LENGTH=8192 MODEL_SIZE=8 bash examples/llama/train_llama3.sh +``` + #### FP8 options with Megatron-LM FSDP (train_llama3.sh) - **FP8 primary weights (fp8_model_init, param gather):** `TE_FP8=1 MEGATRON_FSDP=1 FP8_PARAM_GATHER=1 bash examples/llama/train_llama3.sh` - **BF16 primary weights + FP8 caches (fp8_autocast):** `TE_FP8=1 MEGATRON_FSDP=1 FP8_PARAM_GATHER=0 bash examples/llama/train_llama3.sh` @@ -221,6 +227,12 @@ follow these steps for 2 Node run with Node0 as master node : - **MEGATRON_FSDP:** `1` to enable Megatron-LM's custom FSDP with DTensor checkpointing (default: 0). It adds automatically `--use-megatron-fsdp --ckpt-format fsdp_dtensor` in the script. Of note, this disables `TP>1` automatically. +- **ENABLE_HSDP:** + `1` to enable Hybrid Sharded Data Parallel (HSDP) mode (default: 0). Requires `MEGATRON_FSDP=1`. + +- **HSDP_NUM_REPLICAS:** + Number of outer DP replicas used by HSDP when `ENABLE_HSDP=1` (default: 2). This maps to `--num-distributed-optimizer-instances`. + - **FP8_PARAM_GATHER:** Controls FP8 primary weights vs FP8 caches when `TE_FP8=1` (default: 0). Set to `1` to add `--fp8-param-gather` (weights kept in FP8, smaller all-gathers). Set to `0` to skip the `--fp8-param-gather` flag (weights stay BF16, FP8 caches are used for compute; FP8 weight transpose cache is still kept). 
diff --git a/examples/llama/train_llama2.sh b/examples/llama/train_llama2.sh index 72bd1b12c25..f9936d131d8 100755 --- a/examples/llama/train_llama2.sh +++ b/examples/llama/train_llama2.sh @@ -1,6 +1,6 @@ #!/bin/bash ############################################################################### -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # # See LICENSE for license information. ################################################################################# @@ -87,6 +87,8 @@ DATA_CACHE_PATH="${DATA_CACHE_PATH:-/root/cache}" MEGATRON_FSDP="${MEGATRON_FSDP:-0}" FP8_PARAM_GATHER="${FP8_PARAM_GATHER:-0}" FP8_TRANSPOSE_CACHE="${FP8_TRANSPOSE_CACHE:-0}" +ENABLE_HSDP="${ENABLE_HSDP:-0}" +HSDP_NUM_REPLICAS="${HSDP_NUM_REPLICAS:-2}" TOKENIZER_TYPE="${TOKENIZER_TYPE:-HuggingFaceTokenizer}" if [ "$TOKENIZER_TYPE" == "Llama2Tokenizer" ]; then @@ -109,6 +111,18 @@ if [ "$FSDP" -eq 1 ] || [ "$MEGATRON_FSDP" -eq 1 ]; then fi fi +if [ "$ENABLE_HSDP" -eq 1 ]; then + if [ "$MEGATRON_FSDP" -ne 1 ]; then + echo "Error: HSDP requires MEGATRON_FSDP=1" + exit 1 + fi + + if [ "$HSDP_NUM_REPLICAS" -lt 2 ]; then + echo "Error: HSDP_NUM_REPLICAS must be >= 2 when ENABLE_HSDP=1." 
+ exit 1 + fi +fi + EXPERIMENT_DIR="experiment" mkdir -p $EXPERIMENT_DIR DEFAULT_LOG_DIR="${EXPERIMENT_DIR}/${NNODES}nodes_rank${NODE_RANK}_train_${MODEL_SIZE}B_mbs${MBS}_bs${BS}_tp${TP}_pp${PP}_cp${CP}_iter${TOTAL_ITERS}/TE_FP8_${TE_FP8}/${TIME_STAMP}" @@ -367,6 +381,11 @@ fi if [ "$MEGATRON_FSDP" -eq 1 ]; then EXTRA_ARGS="$EXTRA_ARGS --use-megatron-fsdp --ckpt-format fsdp_dtensor --data-parallel-sharding-strategy optim_grads_params --fsdp-double-buffer" + + if [ "$ENABLE_HSDP" -eq 1 ]; then + echo "Megatron HSDP is enabled with $HSDP_NUM_REPLICAS DP outer replicas" + EXTRA_ARGS="$EXTRA_ARGS --num-distributed-optimizer-instances $HSDP_NUM_REPLICAS" + fi fi if [ -n "${WANDB_API_KEY}" ]; then diff --git a/examples/llama/train_llama3.sh b/examples/llama/train_llama3.sh index 2ffd8dda62e..83fbc1390e2 100755 --- a/examples/llama/train_llama3.sh +++ b/examples/llama/train_llama3.sh @@ -1,6 +1,6 @@ #!/bin/bash ############################################################################### -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # # See LICENSE for license information. ################################################################################# @@ -88,6 +88,8 @@ DATA_CACHE_PATH="${DATA_CACHE_PATH:-/root/cache}" MEGATRON_FSDP="${MEGATRON_FSDP:-0}" FP8_PARAM_GATHER="${FP8_PARAM_GATHER:-0}" FP8_TRANSPOSE_CACHE="${FP8_TRANSPOSE_CACHE:-0}" +ENABLE_HSDP="${ENABLE_HSDP:-0}" +HSDP_NUM_REPLICAS="${HSDP_NUM_REPLICAS:-2}" if [ "$FSDP" -eq 1 ] || [ "$MEGATRON_FSDP" -eq 1 ]; then unset CUDA_DEVICE_MAX_CONNECTIONS @@ -98,6 +100,18 @@ if [ "$FSDP" -eq 1 ] || [ "$MEGATRON_FSDP" -eq 1 ]; then fi fi +if [ "$ENABLE_HSDP" -eq 1 ]; then + if [ "$MEGATRON_FSDP" -ne 1 ]; then + echo "Error: HSDP requires MEGATRON_FSDP=1" + exit 1 + fi + + if [ "$HSDP_NUM_REPLICAS" -lt 2 ]; then + echo "Error: HSDP_NUM_REPLICAS must be >= 2 when ENABLE_HSDP=1." 
+ exit 1 + fi +fi + EXPERIMENT_DIR="experiment" mkdir -p $EXPERIMENT_DIR DEFAULT_LOG_DIR="${EXPERIMENT_DIR}/${NNODES}nodes_rank${NODE_RANK}_train_${MODEL_SIZE}B_mbs${MBS}_bs${BS}_tp${TP}_pp${PP}_cp${CP}_iter${TOTAL_ITERS}/TE_FP8_${TE_FP8}/${TIME_STAMP}" @@ -350,6 +364,11 @@ fi if [ "$MEGATRON_FSDP" -eq 1 ]; then EXTRA_ARGS="$EXTRA_ARGS --use-megatron-fsdp --ckpt-format fsdp_dtensor --data-parallel-sharding-strategy optim_grads_params --fsdp-double-buffer" + + if [ "$ENABLE_HSDP" -eq 1 ]; then + echo "Megatron HSDP is enabled with $HSDP_NUM_REPLICAS DP outer replicas" + EXTRA_ARGS="$EXTRA_ARGS --num-distributed-optimizer-instances $HSDP_NUM_REPLICAS" + fi fi run_cmd="