22import torch
33import numpy
44import torch .distributed as dist
5+
6+ from enum import Enum
57from dlinfer .utils .type_annotation import DlinferDistContext
68from dlinfer .vendor .ascend .utils import SocVersion , get_world_size_accros_dp
7- from dlinfer .framework .lmdeploy_ext .device .ascend import get_pad_size
9+ from dlinfer .framework .lmdeploy_ext .device .ascend import (
10+ get_max_tokens_accros_dp ,
11+ get_pad_size ,
12+ )
813from dlinfer .framework .lmdeploy_ext .cudagraph .ascend_cudagraph import get_graph_params
914
1015
class MoEType(Enum):
    """Token-dispatch strategies for MoE layers across expert-parallel ranks.

    Replaces the old string-constant class (which also carried the
    misspelled ``ALLGAHER`` member) with a proper :class:`enum.Enum`.
    """

    ALLGATHER = 0
    MC2 = 1
    ALL2ALL = 2
    NAIVE_MULTICAST = 3
1722
1823def mc2_tokens_capacity (dist_ctx : DlinferDistContext ) -> int :
@@ -23,7 +28,8 @@ def inner(tp_size: int) -> int:
2328 max_num_tokens = max (graph_params .handles .keys ())
2429 else :
2530 # NOTE: To save memory, we cap the max number of tokens to 512.
26- max_num_tokens = 512
31+ max_num_tokens = min (get_max_tokens_accros_dp (), 512 )
32+ max_num_tokens = min (max_num_tokens , 512 )
2733 num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1 ) // tp_size
2834 return num_tokens_per_tp_rank * tp_size
2935
@@ -32,23 +38,25 @@ def inner(tp_size: int) -> int:
3238
def select_moe_type(num_tokens: int, dist_ctx: DlinferDistContext) -> "MoEType":
    """Pick the MoE token-dispatch strategy for the current batch.

    Args:
        num_tokens: Number of tokens in the current forward pass.
        dist_ctx: Distributed context; ``ep_size`` and DP world-size info
            are read from it.

    Returns:
        The selected ``MoEType`` member.  (Annotation fixed: the function
        returns enum members, not ``str``.)

    Raises:
        ValueError: If the SoC version is neither A2 nor A3.
    """
    # No expert parallelism -> nothing to shuffle across ranks, allgather path.
    if dist_ctx.ep_size <= 1:
        moe_type = MoEType.ALLGATHER
    elif SocVersion.is_A2():
        # On A2, MC2 is only used for small batches AND a large enough
        # cluster (>= 16 ranks across DP); otherwise fall back to allgather.
        if (
            num_tokens <= mc2_tokens_capacity(dist_ctx)
            and get_world_size_accros_dp(dist_ctx) >= 16
        ):
            moe_type = MoEType.MC2
        else:
            moe_type = MoEType.ALLGATHER
    elif SocVersion.is_A3():
        # On A3 the fallback beyond MC2 capacity is all2all, not allgather.
        if num_tokens <= mc2_tokens_capacity(dist_ctx):
            moe_type = MoEType.MC2
        else:
            moe_type = MoEType.ALL2ALL
    else:
        raise ValueError(f"Unsupported soc_version: {SocVersion.soc_version()}")

    return moe_type
5260
5361def apply_mlp (
5462 hidden_states : torch .Tensor ,
0 commit comments