Commit 4b79f9e

[mxfp8 moe training] update benchmarks to force load balancing (#3193)
1 parent bb3f03b commit 4b79f9e

2 files changed: 8 additions & 4 deletions
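In both benchmarks the change amounts to constructing MoEArgs with torchtitan's debug flag that forces balanced routing, so every expert receives the same number of tokens and timings are not skewed by router imbalance. A minimal sketch of the new configuration (only the three fields appear in the commit; the comment on the flag's behavior is inferred from its name and the commit message, not verified here):

from torchtitan.models.moe import MoEArgs

model_args = MoEArgs(
    num_experts=16,
    num_shared_experts=1,
    # Debug-only flag: forces the router to spread tokens evenly across
    # experts (behavior inferred from the name and commit message).
    _debug_force_load_balance=True,
)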

benchmarks/prototype/moe_training/bench_moe_layer.py

Lines changed: 4 additions & 2 deletions
@@ -23,10 +23,10 @@
 
 # this benchmark requires torchtitan
 try:
-    from torchtitan.distributed.expert_parallel import (
+    from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import (
         set_token_group_alignment_size_m,
     )
-    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     logging.warning(
         "please pip install torchtitan to run this benchmark: https://github.com/pytorch/torchtitan"
 
@@ -77,6 +77,8 @@ def bench_moe_training_fsdp(args: argparse.Namespace):
     target_fqns = ["experts"]
     model_args = MoEArgs(
         num_experts=local_num_experts,
+        num_shared_experts=1,
+        _debug_force_load_balance=True,
     )
     init_std = 0.02
     device = torch.device("cuda")

benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py

Lines changed: 4 additions & 2 deletions
@@ -32,10 +32,10 @@
 
 # this benchmark requires torchtitan
 try:
-    from torchtitan.distributed.expert_parallel import (
+    from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import (
         set_token_group_alignment_size_m,
     )
-    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
 
@@ -71,6 +71,8 @@ def bench_moe_training_fsdp(recipe_name: str, enable_profile: bool, use_compile:
     target_fqns = ["experts"]
     model_args = MoEArgs(
         num_experts=16,
+        num_shared_experts=1,
+        _debug_force_load_balance=True,
     )
     init_std = 0.02
     device = torch.device("cuda")
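Note that both files keep calling set_token_group_alignment_size_m, now imported from torchtitan.models.moe.utils rather than torchtitan.distributed.expert_parallel. A hedged usage sketch, assuming the function takes the per-expert token-group alignment as a single int; the value 32 matches the mxfp8 scaling block size but is an assumption, not part of this commit:

from torchtitan.models.moe.utils import set_token_group_alignment_size_m

# Pad each expert's token group to a multiple of 32 tokens (assumed value:
# mxfp8 scales along blocks of 32, so group sizes should align to 32).
set_token_group_alignment_size_m(32)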
