From f30ae9632a2cd9e10c1574381f99e83b2ebc4d39 Mon Sep 17 00:00:00 2001 From: AllenFarcas Date: Tue, 17 Feb 2026 11:02:37 -0600 Subject: [PATCH 1/4] [Fix] Added HSDP functionality --- examples/llama/README.md | 6 ++++++ examples/llama/train_llama2.sh | 15 +++++++++++++++ examples/llama/train_llama3.sh | 15 +++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/examples/llama/README.md b/examples/llama/README.md index 49416ad6633..ef23a1d08bb 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -221,6 +221,12 @@ follow these steps for 2 Node run with Node0 as master node : - **MEGATRON_FSDP:** `1` to enable Megatron-LM's custom FSDP with DTensor checkpointing (default: 0). It adds automatically `--use-megatron-fsdp --ckpt-format fsdp_dtensor` in the script. Of note, this disables `TP>1` automatically. +- **ENABLE_HSDP:** + `1` to enable Hybrid Sharded Data Parallel (HSDP) mode (default: 0). Requires `MEGATRON_FSDP=1`. + +- **HSDP_NUM_DIST_OPT_INSTANCES:** + Number of distributed optimizer instances used by HSDP (default: 2). This maps to `--num-distributed-optimizer-instances`. + - **FP8_PARAM_GATHER:** Controls FP8 primary weights vs FP8 caches when `TE_FP8=1` (default: 0). Set to `1` to add `--fp8-param-gather` (weights kept in FP8, smaller all-gathers). Set to `0` to skip the `--fp8-param-gather` flag (weights stay BF16, FP8 caches are used for compute; FP8 weight transpose cache is still kept). 
diff --git a/examples/llama/train_llama2.sh b/examples/llama/train_llama2.sh index 72bd1b12c25..fe36eb56d7c 100755 --- a/examples/llama/train_llama2.sh +++ b/examples/llama/train_llama2.sh @@ -87,6 +87,8 @@ DATA_CACHE_PATH="${DATA_CACHE_PATH:-/root/cache}" MEGATRON_FSDP="${MEGATRON_FSDP:-0}" FP8_PARAM_GATHER="${FP8_PARAM_GATHER:-0}" FP8_TRANSPOSE_CACHE="${FP8_TRANSPOSE_CACHE:-0}" +ENABLE_HSDP="${ENABLE_HSDP:-0}" +HSDP_NUM_DIST_OPT_INSTANCES="${HSDP_NUM_DIST_OPT_INSTANCES:-2}" TOKENIZER_TYPE="${TOKENIZER_TYPE:-HuggingFaceTokenizer}" if [ "$TOKENIZER_TYPE" == "Llama2Tokenizer" ]; then @@ -109,6 +111,11 @@ if [ "$FSDP" -eq 1 ] || [ "$MEGATRON_FSDP" -eq 1 ]; then fi fi +if [ "$ENABLE_HSDP" -eq 1 ] && [ "$MEGATRON_FSDP" -ne 1 ]; then + echo "Error: ENABLE_HSDP=1 requires MEGATRON_FSDP=1" + exit +fi + EXPERIMENT_DIR="experiment" mkdir -p $EXPERIMENT_DIR DEFAULT_LOG_DIR="${EXPERIMENT_DIR}/${NNODES}nodes_rank${NODE_RANK}_train_${MODEL_SIZE}B_mbs${MBS}_bs${BS}_tp${TP}_pp${PP}_cp${CP}_iter${TOTAL_ITERS}/TE_FP8_${TE_FP8}/${TIME_STAMP}" @@ -367,6 +374,14 @@ fi if [ "$MEGATRON_FSDP" -eq 1 ]; then EXTRA_ARGS="$EXTRA_ARGS --use-megatron-fsdp --ckpt-format fsdp_dtensor --data-parallel-sharding-strategy optim_grads_params --fsdp-double-buffer" + if [ "$ENABLE_HSDP" -eq 1 ]; then + if [ "$HSDP_NUM_DIST_OPT_INSTANCES" -le 1 ]; then + echo "Error: HSDP_NUM_DIST_OPT_INSTANCES should be greater than 1 when ENABLE_HSDP is set to 1" + exit + else + EXTRA_ARGS="$EXTRA_ARGS --num-distributed-optimizer-instances $HSDP_NUM_DIST_OPT_INSTANCES" + fi + fi fi if [ -n "${WANDB_API_KEY}" ]; then diff --git a/examples/llama/train_llama3.sh b/examples/llama/train_llama3.sh index 2ffd8dda62e..0ccf5c172cf 100755 --- a/examples/llama/train_llama3.sh +++ b/examples/llama/train_llama3.sh @@ -88,6 +88,8 @@ DATA_CACHE_PATH="${DATA_CACHE_PATH:-/root/cache}" MEGATRON_FSDP="${MEGATRON_FSDP:-0}" FP8_PARAM_GATHER="${FP8_PARAM_GATHER:-0}" FP8_TRANSPOSE_CACHE="${FP8_TRANSPOSE_CACHE:-0}" 
+ENABLE_HSDP="${ENABLE_HSDP:-0}" +HSDP_NUM_DIST_OPT_INSTANCES="${HSDP_NUM_DIST_OPT_INSTANCES:-2}" if [ "$FSDP" -eq 1 ] || [ "$MEGATRON_FSDP" -eq 1 ]; then unset CUDA_DEVICE_MAX_CONNECTIONS @@ -98,6 +100,11 @@ if [ "$FSDP" -eq 1 ] || [ "$MEGATRON_FSDP" -eq 1 ]; then fi fi +if [ "$ENABLE_HSDP" -eq 1 ] && [ "$MEGATRON_FSDP" -ne 1 ]; then + echo "Error: ENABLE_HSDP=1 requires MEGATRON_FSDP=1" + exit +fi + EXPERIMENT_DIR="experiment" mkdir -p $EXPERIMENT_DIR DEFAULT_LOG_DIR="${EXPERIMENT_DIR}/${NNODES}nodes_rank${NODE_RANK}_train_${MODEL_SIZE}B_mbs${MBS}_bs${BS}_tp${TP}_pp${PP}_cp${CP}_iter${TOTAL_ITERS}/TE_FP8_${TE_FP8}/${TIME_STAMP}" @@ -350,6 +357,14 @@ fi if [ "$MEGATRON_FSDP" -eq 1 ]; then EXTRA_ARGS="$EXTRA_ARGS --use-megatron-fsdp --ckpt-format fsdp_dtensor --data-parallel-sharding-strategy optim_grads_params --fsdp-double-buffer" + if [ "$ENABLE_HSDP" -eq 1 ]; then + if [ "$HSDP_NUM_DIST_OPT_INSTANCES" -le 1 ]; then + echo "Error: HSDP_NUM_DIST_OPT_INSTANCES should be greater than 1 when ENABLE_HSDP is set to 1" + exit + else + EXTRA_ARGS="$EXTRA_ARGS --num-distributed-optimizer-instances $HSDP_NUM_DIST_OPT_INSTANCES" + fi + fi fi run_cmd=" From 4f5c06ddd4d5cef5cb5624cc0bf0c864bfe679c6 Mon Sep 17 00:00:00 2001 From: AllenFarcas Date: Wed, 25 Feb 2026 11:04:08 -0600 Subject: [PATCH 2/4] [Fix] Simplified training scripts and HSDP enabling --- examples/llama/README.md | 5 +---- examples/llama/train_llama2.sh | 14 +++++--------- examples/llama/train_llama3.sh | 14 +++++--------- 3 files changed, 11 insertions(+), 22 deletions(-) diff --git a/examples/llama/README.md b/examples/llama/README.md index ef23a1d08bb..dd75c4da62d 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -221,11 +221,8 @@ follow these steps for 2 Node run with Node0 as master node : - **MEGATRON_FSDP:** `1` to enable Megatron-LM's custom FSDP with DTensor checkpointing (default: 0). It adds automatically `--use-megatron-fsdp --ckpt-format fsdp_dtensor` in the script. 
Of note, this disables `TP>1` automatically. -- **ENABLE_HSDP:** - `1` to enable Hybrid Sharded Data Parallel (HSDP) mode (default: 0). Requires `MEGATRON_FSDP=1`. - - **HSDP_NUM_DIST_OPT_INSTANCES:** - Number of distributed optimizer instances used by HSDP (default: 2). This maps to `--num-distributed-optimizer-instances`. + Number of distributed optimizer instances used by HSDP (default: 1). This maps to `--num-distributed-optimizer-instances`. If set to greater than `1` it enables Hybrid Sharded Data Parallel (HSDP) mode, which requires `MEGATRON_FSDP=1`. - **FP8_PARAM_GATHER:** Controls FP8 primary weights vs FP8 caches when `TE_FP8=1` (default: 0). Set to `1` to add `--fp8-param-gather` (weights kept in FP8, smaller all-gathers). Set to `0` to skip the `--fp8-param-gather` flag (weights stay BF16, FP8 caches are used for compute; FP8 weight transpose cache is still kept). diff --git a/examples/llama/train_llama2.sh b/examples/llama/train_llama2.sh index fe36eb56d7c..51b895ad810 100755 --- a/examples/llama/train_llama2.sh +++ b/examples/llama/train_llama2.sh @@ -87,8 +87,7 @@ DATA_CACHE_PATH="${DATA_CACHE_PATH:-/root/cache}" MEGATRON_FSDP="${MEGATRON_FSDP:-0}" FP8_PARAM_GATHER="${FP8_PARAM_GATHER:-0}" FP8_TRANSPOSE_CACHE="${FP8_TRANSPOSE_CACHE:-0}" -ENABLE_HSDP="${ENABLE_HSDP:-0}" -HSDP_NUM_DIST_OPT_INSTANCES="${HSDP_NUM_DIST_OPT_INSTANCES:-2}" +HSDP_NUM_DIST_OPT_INSTANCES="${HSDP_NUM_DIST_OPT_INSTANCES:-1}" TOKENIZER_TYPE="${TOKENIZER_TYPE:-HuggingFaceTokenizer}" if [ "$TOKENIZER_TYPE" == "Llama2Tokenizer" ]; then @@ -374,13 +373,10 @@ fi if [ "$MEGATRON_FSDP" -eq 1 ]; then EXTRA_ARGS="$EXTRA_ARGS --use-megatron-fsdp --ckpt-format fsdp_dtensor --data-parallel-sharding-strategy optim_grads_params --fsdp-double-buffer" - if [ "$ENABLE_HSDP" -eq 1 ]; then - if [ "$HSDP_NUM_DIST_OPT_INSTANCES" -le 1 ]; then - echo "Error: HSDP_NUM_DIST_OPT_INSTANCES should be greater than 1 when ENABLE_HSDP is set to 1" - exit - else - EXTRA_ARGS="$EXTRA_ARGS 
--num-distributed-optimizer-instances $HSDP_NUM_DIST_OPT_INSTANCES" - fi + + if [ "$HSDP_NUM_DIST_OPT_INSTANCES" -gt 1 ]; then + echo "HSDP is enabled with $HSDP_NUM_DIST_OPT_INSTANCES distributed optimizer instances" + EXTRA_ARGS="$EXTRA_ARGS --num-distributed-optimizer-instances $HSDP_NUM_DIST_OPT_INSTANCES" fi fi diff --git a/examples/llama/train_llama3.sh b/examples/llama/train_llama3.sh index 0ccf5c172cf..485428bbb5e 100755 --- a/examples/llama/train_llama3.sh +++ b/examples/llama/train_llama3.sh @@ -88,8 +88,7 @@ DATA_CACHE_PATH="${DATA_CACHE_PATH:-/root/cache}" MEGATRON_FSDP="${MEGATRON_FSDP:-0}" FP8_PARAM_GATHER="${FP8_PARAM_GATHER:-0}" FP8_TRANSPOSE_CACHE="${FP8_TRANSPOSE_CACHE:-0}" -ENABLE_HSDP="${ENABLE_HSDP:-0}" -HSDP_NUM_DIST_OPT_INSTANCES="${HSDP_NUM_DIST_OPT_INSTANCES:-2}" +HSDP_NUM_DIST_OPT_INSTANCES="${HSDP_NUM_DIST_OPT_INSTANCES:-1}" if [ "$FSDP" -eq 1 ] || [ "$MEGATRON_FSDP" -eq 1 ]; then unset CUDA_DEVICE_MAX_CONNECTIONS @@ -357,13 +356,10 @@ fi if [ "$MEGATRON_FSDP" -eq 1 ]; then EXTRA_ARGS="$EXTRA_ARGS --use-megatron-fsdp --ckpt-format fsdp_dtensor --data-parallel-sharding-strategy optim_grads_params --fsdp-double-buffer" - if [ "$ENABLE_HSDP" -eq 1 ]; then - if [ "$HSDP_NUM_DIST_OPT_INSTANCES" -le 1 ]; then - echo "Error: HSDP_NUM_DIST_OPT_INSTANCES should be greater than 1 when ENABLE_HSDP is set to 1" - exit - else - EXTRA_ARGS="$EXTRA_ARGS --num-distributed-optimizer-instances $HSDP_NUM_DIST_OPT_INSTANCES" - fi + + if [ "$HSDP_NUM_DIST_OPT_INSTANCES" -gt 1 ]; then + echo "HSDP is enabled with $HSDP_NUM_DIST_OPT_INSTANCES distributed optimizer instances" + EXTRA_ARGS="$EXTRA_ARGS --num-distributed-optimizer-instances $HSDP_NUM_DIST_OPT_INSTANCES" fi fi From 5a4b37deed8b6bcbffec3e092142967a8b867a83 Mon Sep 17 00:00:00 2001 From: AllenFarcas Date: Thu, 26 Feb 2026 09:35:28 -0600 Subject: [PATCH 3/4] [Fix] Removed the ENABLE_HSDP flag --- examples/llama/train_llama2.sh | 4 ++-- examples/llama/train_llama3.sh | 4 ++-- 2 files changed, 4 
insertions(+), 4 deletions(-) diff --git a/examples/llama/train_llama2.sh b/examples/llama/train_llama2.sh index 51b895ad810..8cec58fdfa5 100755 --- a/examples/llama/train_llama2.sh +++ b/examples/llama/train_llama2.sh @@ -110,8 +110,8 @@ if [ "$FSDP" -eq 1 ] || [ "$MEGATRON_FSDP" -eq 1 ]; then fi fi -if [ "$ENABLE_HSDP" -eq 1 ] && [ "$MEGATRON_FSDP" -ne 1 ]; then - echo "Error: ENABLE_HSDP=1 requires MEGATRON_FSDP=1" +if [ "$HSDP_NUM_DIST_OPT_INSTANCES" -gt 1 ] && [ "$MEGATRON_FSDP" -ne 1 ]; then + echo "Error: HSDP_NUM_DIST_OPT_INSTANCES>1 requires MEGATRON_FSDP=1" exit fi diff --git a/examples/llama/train_llama3.sh b/examples/llama/train_llama3.sh index 485428bbb5e..49b7e9e7c99 100755 --- a/examples/llama/train_llama3.sh +++ b/examples/llama/train_llama3.sh @@ -99,8 +99,8 @@ if [ "$FSDP" -eq 1 ] || [ "$MEGATRON_FSDP" -eq 1 ]; then fi fi -if [ "$ENABLE_HSDP" -eq 1 ] && [ "$MEGATRON_FSDP" -ne 1 ]; then - echo "Error: ENABLE_HSDP=1 requires MEGATRON_FSDP=1" +if [ "$HSDP_NUM_DIST_OPT_INSTANCES" -gt 1 ] && [ "$MEGATRON_FSDP" -ne 1 ]; then + echo "Error: HSDP_NUM_DIST_OPT_INSTANCES>1 requires MEGATRON_FSDP=1" exit fi From ee016244896f0ae0c7e8986dcfdedef4e3853696 Mon Sep 17 00:00:00 2001 From: AllenFarcas Date: Mon, 2 Mar 2026 10:16:00 -0600 Subject: [PATCH 4/4] [Fix] Clarified usage of HSDP and num replicas for HSDP --- examples/llama/README.md | 13 +++++++++++-- examples/llama/train_llama2.sh | 24 ++++++++++++++++-------- examples/llama/train_llama3.sh | 24 ++++++++++++++++-------- 3 files changed, 43 insertions(+), 18 deletions(-) diff --git a/examples/llama/README.md b/examples/llama/README.md index dd75c4da62d..81dc94b098d 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -162,6 +162,12 @@ TEE_OUTPUT=1 MBS=2 BS=16 TP=1 TE_FP8=0 FSDP=1 RECOMPUTE=1 SEQ_LENGTH=8192 MODEL_ And FSDP-v2 is not supported with pipeline parallelism, expert parallelism, MCore's distributed optimizer, gradient accumulation fusion and fp16. 
+To run with Megatron-LM HSDP enabled (Hybrid Sharded Data Parallel), use `ENABLE_HSDP=1` +and set the number of replicas with `HSDP_NUM_REPLICAS>=2`. For example: +```bash +TEE_OUTPUT=1 MBS=2 BS=16 TP=1 TE_FP8=0 MEGATRON_FSDP=1 ENABLE_HSDP=1 HSDP_NUM_REPLICAS=2 SEQ_LENGTH=8192 MODEL_SIZE=8 bash examples/llama/train_llama3.sh +``` + #### FP8 options with Megatron-LM FSDP (train_llama3.sh) - **FP8 primary weights (fp8_model_init, param gather):** `TE_FP8=1 MEGATRON_FSDP=1 FP8_PARAM_GATHER=1 bash examples/llama/train_llama3.sh` - **BF16 primary weights + FP8 caches (fp8_autocast):** `TE_FP8=1 MEGATRON_FSDP=1 FP8_PARAM_GATHER=0 bash examples/llama/train_llama3.sh` @@ -221,8 +227,11 @@ follow these steps for 2 Node run with Node0 as master node : - **MEGATRON_FSDP:** `1` to enable Megatron-LM's custom FSDP with DTensor checkpointing (default: 0). It adds automatically `--use-megatron-fsdp --ckpt-format fsdp_dtensor` in the script. Of note, this disables `TP>1` automatically. -- **HSDP_NUM_DIST_OPT_INSTANCES:** - Number of distributed optimizer instances used by HSDP (default: 1). This maps to `--num-distributed-optimizer-instances`. If set to greater than `1` it enables Hybrid Sharded Data Parallel (HSDP) mode, which requires `MEGATRON_FSDP=1`. +- **ENABLE_HSDP:** + `1` to enable Hybrid Sharded Data Parallel (HSDP) mode (default: 0). Requires `MEGATRON_FSDP=1`. + +- **HSDP_NUM_REPLICAS:** + Number of outer DP replicas used by HSDP (default: 2) when `ENABLE_HSDP=1`. This maps to `--num-distributed-optimizer-instances`. - **FP8_PARAM_GATHER:** Controls FP8 primary weights vs FP8 caches when `TE_FP8=1` (default: 0). Set to `1` to add `--fp8-param-gather` (weights kept in FP8, smaller all-gathers). Set to `0` to skip the `--fp8-param-gather` flag (weights stay BF16, FP8 caches are used for compute; FP8 weight transpose cache is still kept). 
diff --git a/examples/llama/train_llama2.sh b/examples/llama/train_llama2.sh index 8cec58fdfa5..f9936d131d8 100755 --- a/examples/llama/train_llama2.sh +++ b/examples/llama/train_llama2.sh @@ -1,6 +1,6 @@ #!/bin/bash ############################################################################### -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # # See LICENSE for license information. ################################################################################# @@ -87,7 +87,8 @@ DATA_CACHE_PATH="${DATA_CACHE_PATH:-/root/cache}" MEGATRON_FSDP="${MEGATRON_FSDP:-0}" FP8_PARAM_GATHER="${FP8_PARAM_GATHER:-0}" FP8_TRANSPOSE_CACHE="${FP8_TRANSPOSE_CACHE:-0}" -HSDP_NUM_DIST_OPT_INSTANCES="${HSDP_NUM_DIST_OPT_INSTANCES:-1}" +ENABLE_HSDP="${ENABLE_HSDP:-0}" +HSDP_NUM_REPLICAS="${HSDP_NUM_REPLICAS:-2}" TOKENIZER_TYPE="${TOKENIZER_TYPE:-HuggingFaceTokenizer}" if [ "$TOKENIZER_TYPE" == "Llama2Tokenizer" ]; then @@ -110,9 +111,16 @@ if [ "$FSDP" -eq 1 ] || [ "$MEGATRON_FSDP" -eq 1 ]; then fi fi -if [ "$HSDP_NUM_DIST_OPT_INSTANCES" -gt 1 ] && [ "$MEGATRON_FSDP" -ne 1 ]; then - echo "Error: HSDP_NUM_DIST_OPT_INSTANCES>1 requires MEGATRON_FSDP=1" - exit +if [ "$ENABLE_HSDP" -eq 1 ]; then + if [ "$MEGATRON_FSDP" -ne 1 ]; then + echo "Error: HSDP requires MEGATRON_FSDP=1" + exit 1 + fi + + if [ "$HSDP_NUM_REPLICAS" -lt 2 ]; then + echo "Error: HSDP_NUM_REPLICAS must be >= 2 when ENABLE_HSDP=1." 
+ exit 1 + fi fi EXPERIMENT_DIR="experiment" @@ -374,9 +382,9 @@ fi if [ "$MEGATRON_FSDP" -eq 1 ]; then EXTRA_ARGS="$EXTRA_ARGS --use-megatron-fsdp --ckpt-format fsdp_dtensor --data-parallel-sharding-strategy optim_grads_params --fsdp-double-buffer" - if [ "$HSDP_NUM_DIST_OPT_INSTANCES" -gt 1 ]; then - echo "HSDP is enabled with $HSDP_NUM_DIST_OPT_INSTANCES distributed optimizer instances" - EXTRA_ARGS="$EXTRA_ARGS --num-distributed-optimizer-instances $HSDP_NUM_DIST_OPT_INSTANCES" + if [ "$ENABLE_HSDP" -eq 1 ]; then + echo "Megatron HSDP is enabled with $HSDP_NUM_REPLICAS DP outer replicas" + EXTRA_ARGS="$EXTRA_ARGS --num-distributed-optimizer-instances $HSDP_NUM_REPLICAS" fi fi diff --git a/examples/llama/train_llama3.sh b/examples/llama/train_llama3.sh index 49b7e9e7c99..83fbc1390e2 100755 --- a/examples/llama/train_llama3.sh +++ b/examples/llama/train_llama3.sh @@ -1,6 +1,6 @@ #!/bin/bash ############################################################################### -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # # See LICENSE for license information. 
################################################################################# @@ -88,7 +88,8 @@ DATA_CACHE_PATH="${DATA_CACHE_PATH:-/root/cache}" MEGATRON_FSDP="${MEGATRON_FSDP:-0}" FP8_PARAM_GATHER="${FP8_PARAM_GATHER:-0}" FP8_TRANSPOSE_CACHE="${FP8_TRANSPOSE_CACHE:-0}" -HSDP_NUM_DIST_OPT_INSTANCES="${HSDP_NUM_DIST_OPT_INSTANCES:-1}" +ENABLE_HSDP="${ENABLE_HSDP:-0}" +HSDP_NUM_REPLICAS="${HSDP_NUM_REPLICAS:-2}" if [ "$FSDP" -eq 1 ] || [ "$MEGATRON_FSDP" -eq 1 ]; then unset CUDA_DEVICE_MAX_CONNECTIONS @@ -99,9 +100,16 @@ if [ "$FSDP" -eq 1 ] || [ "$MEGATRON_FSDP" -eq 1 ]; then fi fi -if [ "$HSDP_NUM_DIST_OPT_INSTANCES" -gt 1 ] && [ "$MEGATRON_FSDP" -ne 1 ]; then - echo "Error: HSDP_NUM_DIST_OPT_INSTANCES>1 requires MEGATRON_FSDP=1" - exit +if [ "$ENABLE_HSDP" -eq 1 ]; then + if [ "$MEGATRON_FSDP" -ne 1 ]; then + echo "Error: HSDP requires MEGATRON_FSDP=1" + exit 1 + fi + + if [ "$HSDP_NUM_REPLICAS" -lt 2 ]; then + echo "Error: HSDP_NUM_REPLICAS must be >= 2 when ENABLE_HSDP=1." + exit 1 + fi fi EXPERIMENT_DIR="experiment" @@ -357,9 +365,9 @@ fi if [ "$MEGATRON_FSDP" -eq 1 ]; then EXTRA_ARGS="$EXTRA_ARGS --use-megatron-fsdp --ckpt-format fsdp_dtensor --data-parallel-sharding-strategy optim_grads_params --fsdp-double-buffer" - if [ "$HSDP_NUM_DIST_OPT_INSTANCES" -gt 1 ]; then - echo "HSDP is enabled with $HSDP_NUM_DIST_OPT_INSTANCES distributed optimizer instances" - EXTRA_ARGS="$EXTRA_ARGS --num-distributed-optimizer-instances $HSDP_NUM_DIST_OPT_INSTANCES" + if [ "$ENABLE_HSDP" -eq 1 ]; then + echo "Megatron HSDP is enabled with $HSDP_NUM_REPLICAS DP outer replicas" + EXTRA_ARGS="$EXTRA_ARGS --num-distributed-optimizer-instances $HSDP_NUM_REPLICAS" fi fi