Skip to content

Commit 160b555

Browse files
yiming0416 authored and githubsgi committed
Minor fixes in simple_fsdp experiments (pytorch#1853)
pytorch#1850 removed `name` field in `TrainSpec`. The experiments in simple_fsdp should also be updated. Otherwise it won't run. pytorch#1776 added `use_flex_attn` field to `apply_non_moe_tp()`, which is missing in simple_fsdp experiments ``` NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name simple_fsdp.llama3 --compile.enable ``` ``` NGPU=8 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name simple_fsdp.deepseek_v3 --compile.enable ```
1 parent f63037f commit 160b555

File tree

3 files changed

+2
-2
lines changed

3 files changed

+2
-2
lines changed

torchtitan/experiments/simple_fsdp/deepseek_v3/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
def get_train_spec() -> TrainSpec:
2323
return TrainSpec(
24-
name="simple_fsdp.deepseek_v3",
2524
model_cls=SimpleFSDPDeepSeekV3Model,
2625
model_args=deepseekv3_configs,
2726
parallelize_fn=parallelize_deepseekv3,

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,13 @@ def parallelize_deepseekv3(
5555
"Currently, float8 tensorwise TP is not tested for deepseekv3"
5656
)
5757

58+
use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
5859
apply_non_moe_tp(
5960
model,
6061
world_mesh["tp"],
6162
loss_parallel=not job_config.parallelism.disable_loss_parallel,
6263
enable_float8_tensorwise_tp=False,
64+
use_flex_attn=use_flex_attn,
6365
)
6466
maybe_enable_async_tp(job_config, world_mesh["tp"])
6567

torchtitan/experiments/simple_fsdp/llama3/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
def get_train_spec() -> TrainSpec:
2222
return TrainSpec(
23-
name="simple_fsdp.llama3",
2423
model_cls=SimpleFSDPTransformer,
2524
model_args=llama3_configs,
2625
parallelize_fn=parallelize_llama,

0 commit comments

Comments (0)