
Commit fab8954

before judge moe_type

1 parent 6c95469 commit fab8954

2 files changed: 24 additions & 15 deletions


dlinfer/vendor/ascend/moe.py

Lines changed: 20 additions & 12 deletions
@@ -2,17 +2,22 @@
 import torch
 import numpy
 import torch.distributed as dist
+
+from enum import Enum
 from dlinfer.utils.type_annotation import DlinferDistContext
 from dlinfer.vendor.ascend.utils import SocVersion, get_world_size_accros_dp
-from dlinfer.framework.lmdeploy_ext.device.ascend import get_pad_size
+from dlinfer.framework.lmdeploy_ext.device.ascend import (
+    get_max_tokens_accros_dp,
+    get_pad_size,
+)
 from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_graph_params


-class MoEType:
-    ALL2ALL: str = "all2all"
-    MC2: str = "mc2"
-    ALLGAHER: str = "allgaher"
-    UNDEFINED: str = "undefined"
+class MoEType(Enum):
+    ALLGATHER = 0
+    MC2 = 1
+    ALL2ALL = 2
+    NAIVE_MULTICAST = 3


 def mc2_tokens_capacity(dist_ctx: DlinferDistContext) -> int:
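This hunk replaces the old class of string constants (whose ALLGAHER member carried a typo in both name and value) with a proper Enum. A minimal standalone sketch of why the Enum form is safer: member names are checked at attribute lookup, so a misspelling now raises instead of silently comparing unequal strings.

from enum import Enum

class MoEType(Enum):
    ALLGATHER = 0
    MC2 = 1
    ALL2ALL = 2
    NAIVE_MULTICAST = 3

# Members compare by identity; a misspelled attribute such as
# MoEType.ALLGAHER now raises AttributeError instead of matching nothing.
assert MoEType.ALLGATHER is MoEType(0)
assert MoEType.MC2 != MoEType.ALL2ALL
print(MoEType.MC2.name, MoEType.MC2.value)  # -> MC2 1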
@@ -23,7 +28,8 @@ def inner(tp_size: int) -> int:
             max_num_tokens = max(graph_params.handles.keys())
         else:
             # NOTE: To save memory, we cap the max number of tokens to 512.
-            max_num_tokens = 512
+            max_num_tokens = min(get_max_tokens_accros_dp(), 512)
+            max_num_tokens = min(max_num_tokens, 512)
         num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
         return num_tokens_per_tp_rank * tp_size
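The capacity helper caps the (now DP-aware) token count at 512, then rounds it up to a multiple of tp_size via ceiling division. A self-contained sketch of that rounding with illustrative numbers; the helper name here is ours, not from the source:

def round_up_to_tp(max_num_tokens: int, tp_size: int) -> int:
    # Ceiling division, then scale back up: the smallest multiple of
    # tp_size that is >= max_num_tokens.
    num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
    return num_tokens_per_tp_rank * tp_size

assert round_up_to_tp(512, 8) == 512  # already a multiple of 8
assert round_up_to_tp(500, 8) == 504  # ceil(500 / 8) = 63; 63 * 8 = 504
assert round_up_to_tp(1, 4) == 4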

@@ -32,23 +38,25 @@ def inner(tp_size: int) -> int:


 def select_moe_type(num_tokens: int, dist_ctx: DlinferDistContext) -> str:
     if dist_ctx.ep_size <= 1:
-        return MoEType.ALLGAHER
+        moe_type = MoEType.ALLGATHER
     elif SocVersion.is_A2():
         if (
             num_tokens <= mc2_tokens_capacity(dist_ctx)
             and get_world_size_accros_dp(dist_ctx) >= 16
         ):
-            return MoEType.MC2
+            moe_type = MoEType.MC2
         else:
-            return MoEType.ALLGAHER
+            moe_type = MoEType.ALLGATHER
     elif SocVersion.is_A3():
         if num_tokens <= mc2_tokens_capacity(dist_ctx):
-            return MoEType.MC2
+            moe_type = MoEType.MC2
         else:
-            return MoEType.ALLGAHER
+            moe_type = MoEType.ALL2ALL
     else:
         raise ValueError(f"Unsupported soc_version: {SocVersion.soc_version()}")

+    return moe_type
+

 def apply_mlp(
     hidden_states: torch.Tensor,
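select_moe_type now funnels every branch into a single moe_type variable with one exit, and the A3 fallback changes from ALLGATHER to ALL2ALL (note the -> str annotation is now stale, since MoEType is an Enum). A hedged, dependency-free mirror of the new decision logic; the boolean and capacity parameters stand in for the real SocVersion and dist_ctx queries and are assumptions for illustration:

def select(num_tokens, ep_size, is_a2, is_a3, world_size_across_dp, capacity):
    # Mirrors the branch structure of select_moe_type with plain values.
    if ep_size <= 1:
        moe_type = "ALLGATHER"
    elif is_a2:
        if num_tokens <= capacity and world_size_across_dp >= 16:
            moe_type = "MC2"
        else:
            moe_type = "ALLGATHER"
    elif is_a3:
        moe_type = "MC2" if num_tokens <= capacity else "ALL2ALL"
    else:
        raise ValueError("unsupported SoC")
    return moe_type

assert select(256, 8, True, False, 16, 512) == "MC2"
assert select(4096, 8, False, True, 8, 512) == "ALL2ALL"  # the new A3 fallback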

dlinfer/vendor/ascend/torch_npu_ops.py

Lines changed: 4 additions & 3 deletions
@@ -607,7 +607,8 @@ def fused_moe(
     gate_up_weights = gate_up_weights.transpose(1, 2)
     down_weights = down_weights.transpose(1, 2)

-    # if moe.select_moe_type(num_tokens, dist_ctx) == moe.MoEType.ALLGAHER:
+    # moe_type = moe.select_moe_type(num_tokens, dist_ctx)
+    # if moe_type == moe.MoEType.ALLGATHER:
     if dist_ctx.ep_size <= 1:
         moe_output = moe.fused_moe_allgaher(
             hidden_states,
@@ -618,7 +619,7 @@ def fused_moe(
         topk,
         renormalize,
     )
-    # elif moe.select_moe_type(num_tokens, dist_ctx) == moe.MoEType.MC2:
+    # elif moe_type == moe.MoEType.MC2:
     elif AscendOpsBackend.max_tokens_accros_dp <= dist_ctx.tp_size * 512:
         moe_output = moe.fused_moe_mc2(
             hidden_states,
@@ -632,7 +633,7 @@ def fused_moe(
         moe_group_name,
         x_active_mask,
     )
-    # elif moe.select_moe_type(num_tokens, dist_ctx) == moe.MoEType.ALL2ALL:
+    # elif moe_type == moe.MoEType.ALL2ALL:
     else:
         moe_output = moe.fused_moe_all2all(
             hidden_states,
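The commented-out lines, together with the commit message ("before judge moe_type"), suggest the hard-coded branch conditions in fused_moe are about to be replaced by dispatching on select_moe_type. A sketch of that intended shape, with stub kernels standing in for fused_moe_allgaher / fused_moe_mc2 / fused_moe_all2all; this is an assumed refactor target, not the committed code:

from enum import Enum

class MoEType(Enum):
    ALLGATHER = 0
    MC2 = 1
    ALL2ALL = 2

# Stub kernels; the real ones live in dlinfer/vendor/ascend/moe.py and
# take the full argument lists shown in the diff above.
def fused_moe_allgather(hidden_states): return ("allgather", hidden_states)
def fused_moe_mc2(hidden_states): return ("mc2", hidden_states)
def fused_moe_all2all(hidden_states): return ("all2all", hidden_states)

DISPATCH = {
    MoEType.ALLGATHER: fused_moe_allgather,
    MoEType.MC2: fused_moe_mc2,
    MoEType.ALL2ALL: fused_moe_all2all,
}

def fused_moe(hidden_states, moe_type: MoEType):
    # One selector decision replaces three ad-hoc branch conditions.
    return DISPATCH[moe_type](hidden_states)

assert fused_moe("x", MoEType.MC2) == ("mc2", "x")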
