@@ -82,10 +82,13 @@ def _bind_ffn(self):
             moe_mode = os.environ.get("MOE_MODE", "TP")
             if moe_mode == "EP":
                 self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_edp, self)
+                self._tpsp_ffn = self._tpsp_ffn_ep
             else:
                 self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, self)
+                self._tpsp_ffn = self._tpsp_ffn_tp
         else:
             self._ffn = partial(LlamaTransformerLayerInfer._ffn, self)
+            self._tpsp_ffn = self._tpsp_ffn_tp

     def _bind_attention(self):
         if "triton_fp8kv" in self.mode:
@@ -187,23 +190,34 @@ def _get_qkv(
     def _tpsp_get_qkv(
         self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:
-        if self.tp_world_size_ > 1:
-            sp_token_num, hidden_dim = input.shape
-            gather_input = self.alloc_tensor(
-                (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device
-            )
-            all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False)
-            input = gather_input[0 : len(infer_state.position_cos), :]
-
         input = input.view(-1, self.embed_dim_)
         if self.q_lora_rank is None:
+            # When q_lora_rank is None, the low-rank communication optimization is not supported yet.
+            if self.tp_world_size_ > 1:
+                sp_token_num, hidden_dim = input.shape
+                gather_input = self.alloc_tensor(
+                    (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device
+                )
+                all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False)
+                input = gather_input[0 : len(infer_state.position_cos), :]
+
+            input = input.view(-1, self.embed_dim_)
             q = layer_weight.q_weight_.mm(input)
             cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
             layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
         else:
-            q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split(
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
-            )
+            input = input.view(-1, self.embed_dim_)
+            qkv = layer_weight.qkv_a_proj_with_mqa_.mm(input)
+            # Doing the all-gather after the low-rank projection reduces the communication volume.
+            if self.tp_world_size_ > 1:
+                sp_token_num, qkv_dim = qkv.shape
+                gather_qkv = self.alloc_tensor(
+                    (sp_token_num * self.tp_world_size_, qkv_dim), dtype=qkv.dtype, device=qkv.device
+                )
+                all_gather_into_tensor(gather_qkv, qkv, group=infer_state.dist_group, async_op=False)
+                qkv = gather_qkv[0 : len(infer_state.position_cos), :]
+
+            q, cache_kv = qkv.split([self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1)
         q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_)
         q = layer_weight.q_b_proj_.mm(q)
         cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
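The "communicate after the low-rank projection" comment can be made concrete with a rough per-token payload count. The numbers below are only assumed, DeepSeek-V2-style example dimensions (they are not read from this repository); the point is that the all-gather now moves the compressed q/kv activations instead of the full hidden states:

```python
# Rough sketch of the all-gather payload per token, in fp16/bf16 bytes.
# All dimensions are assumed example values, not taken from the model config.
embed_dim = 5120          # hidden size gathered by the old code path
q_lora_rank = 1536        # compressed q dimension
kv_lora_rank = 512        # compressed kv dimension
qk_rope_head_dim = 64     # rope part carried alongside the compressed kv
bytes_per_elem = 2        # fp16/bf16

before = embed_dim * bytes_per_elem
after = (q_lora_rank + kv_lora_rank + qk_rope_head_dim) * bytes_per_elem

print(f"all-gather per token: {before} B -> {after} B ({after / before:.2%} of the original)")
# With these numbers: 10240 B -> 4224 B, i.e. roughly 0.41x the traffic.
```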
@@ -726,6 +740,43 @@ def _moe_ffn_edp(
         ep_output = ep_output.view(token_num, hidden_dim)
         return ep_output

+    def _tpsp_ffn(self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight):
+        raise Exception("need bind to real impl")
+
+    def _tpsp_ffn_tp(
+        self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+    ) -> torch.Tensor:
+        input = input.view(-1, self.embed_dim_)
+        if self.tp_world_size_ > 1:
+            sp_token_num, hidden_dim = input.shape
+            gather_input = self.alloc_tensor(
+                (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device
+            )
+            all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False)
+            input = gather_input
+
+        ffn2_out = self._ffn(input=input, infer_state=infer_state, layer_weight=layer_weight)
+
+        if self.tp_world_size_ > 1:
+            sp_token_num = ffn2_out.shape[0] // self.tp_world_size_
+            reduce_o_tensor = self.alloc_tensor(
+                (sp_token_num, self.embed_dim_), dtype=ffn2_out.dtype, device=ffn2_out.device
+            )
+            reduce_scatter_tensor(
+                reduce_o_tensor, ffn2_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False
+            )
+            ffn2_out = reduce_o_tensor
+        return ffn2_out
+
+    def _tpsp_ffn_ep(
+        self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+    ) -> torch.Tensor:
+        input = input.view(-1, self.embed_dim_)
+
+        ffn2_out = self._ffn(input=input, infer_state=infer_state, layer_weight=layer_weight)
+
+        return ffn2_out
+
     def overlap_tpsp_token_forward(
         self,
         input_embdings: torch.Tensor,
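`_tpsp_ffn_tp` above is the usual sequence-parallel sandwich: all-gather the per-rank token shards into the full sequence, run the tensor-parallel FFN (whose outputs are partial sums), then reduce-scatter so each rank leaves with only its own token shard again. A self-contained sketch of that pattern with plain `torch.distributed` collectives; the `ffn` callable and the already-initialized process group are assumptions, and this is not the repository's implementation:

```python
import torch
import torch.distributed as dist


def tpsp_ffn_sketch(local_tokens: torch.Tensor, ffn, group=None) -> torch.Tensor:
    """Sequence-parallel FFN sketch: all-gather -> compute -> reduce-scatter.

    local_tokens: (sp_token_num, hidden_dim) shard owned by this rank.
    ffn: any callable that preserves the (tokens, hidden_dim) shape and, under
         tensor parallelism, produces partial sums that must be reduced.
    Assumes dist.init_process_group(...) has already been called.
    """
    world_size = dist.get_world_size(group)
    sp_token_num, hidden_dim = local_tokens.shape

    # 1) Rebuild the full token dimension on every rank.
    gathered = torch.empty(
        (sp_token_num * world_size, hidden_dim),
        dtype=local_tokens.dtype,
        device=local_tokens.device,
    )
    dist.all_gather_into_tensor(gathered, local_tokens, group=group)

    # 2) Run the FFN on the full sequence; each rank holds a weight slice,
    #    so the result is only a partial contribution to the output.
    partial_out = ffn(gathered)

    # 3) Sum the partial results across ranks and hand each rank back its shard.
    local_out = torch.empty(
        (sp_token_num, hidden_dim),
        dtype=partial_out.dtype,
        device=partial_out.device,
    )
    dist.reduce_scatter_tensor(local_out, partial_out, op=dist.ReduceOp.SUM, group=group)
    return local_out
```

The EP variant skips both collectives because, in that configuration, the expert-parallel MoE path manages its own token dispatch and combine.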