
Commit db1b64c

hiworldwzj and wangzaijun authored
deepseek tpsp lora rank qkv all gather. (#1078)
Co-authored-by: wangzaijun <[email protected]>
1 parent 5128334 commit db1b64c

2 files changed: +97 −25 lines changed

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 62 additions & 11 deletions
@@ -82,10 +82,13 @@ def _bind_ffn(self):
             moe_mode = os.environ.get("MOE_MODE", "TP")
             if moe_mode == "EP":
                 self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_edp, self)
+                self._tpsp_ffn = self._tpsp_ffn_ep
             else:
                 self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, self)
+                self._tpsp_ffn = self._tpsp_ffn_tp
         else:
             self._ffn = partial(LlamaTransformerLayerInfer._ffn, self)
+            self._tpsp_ffn = self._tpsp_ffn_tp

     def _bind_attention(self):
         if "triton_fp8kv" in self.mode:
@@ -187,23 +190,34 @@ def _get_qkv(
     def _tpsp_get_qkv(
         self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:
-        if self.tp_world_size_ > 1:
-            sp_token_num, hidden_dim = input.shape
-            gather_input = self.alloc_tensor(
-                (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device
-            )
-            all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False)
-            input = gather_input[0 : len(infer_state.position_cos), :]
-
         input = input.view(-1, self.embed_dim_)
         if self.q_lora_rank is None:
+            # when q_lora_rank is None, the low-rank communication optimization is not supported yet.
+            if self.tp_world_size_ > 1:
+                sp_token_num, hidden_dim = input.shape
+                gather_input = self.alloc_tensor(
+                    (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device
+                )
+                all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False)
+                input = gather_input[0 : len(infer_state.position_cos), :]
+
+            input = input.view(-1, self.embed_dim_)
             q = layer_weight.q_weight_.mm(input)
             cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
             layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
         else:
-            q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split(
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
-            )
+            input = input.view(-1, self.embed_dim_)
+            qkv = layer_weight.qkv_a_proj_with_mqa_.mm(input)
+            # doing the all-gather after the lora-rank projection reduces the communication volume.
+            if self.tp_world_size_ > 1:
+                sp_token_num, qkv_dim = qkv.shape
+                gather_qkv = self.alloc_tensor(
+                    (sp_token_num * self.tp_world_size_, qkv_dim), dtype=qkv.dtype, device=qkv.device
+                )
+                all_gather_into_tensor(gather_qkv, qkv, group=infer_state.dist_group, async_op=False)
+                qkv = gather_qkv[0 : len(infer_state.position_cos), :]
+
+            q, cache_kv = qkv.split([self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1)
         q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_)
         q = layer_weight.q_b_proj_.mm(q)
         cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
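
The point of moving the all-gather after qkv_a_proj_with_mqa_ (only possible when q_lora_rank is not None, as the comment in the diff notes) is that each gathered row then has width q_lora_rank + kv_lora_rank + qk_rope_head_dim instead of the full hidden size. A rough back-of-the-envelope check, assuming DeepSeek-V2-like dimensions; the exact numbers come from the model's config.json and are an assumption here:

# Rough per-token all-gather width before vs. after this change,
# assuming DeepSeek-V2-like dimensions (verify against the actual model config).
hidden_size = 5120
q_lora_rank = 1536
kv_lora_rank = 512
qk_rope_head_dim = 64

before = hidden_size                                   # gathering raw hidden states
after = q_lora_rank + kv_lora_rank + qk_rope_head_dim  # gathering low-rank qkv activations

print(before, after, round(before / after, 2))         # 5120 2112 2.42

Under these assumed dimensions the all-gather moves roughly 2.4x less data per token; models with a larger hidden size would see a bigger ratio.
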
@@ -726,6 +740,43 @@ def _moe_ffn_edp(
         ep_output = ep_output.view(token_num, hidden_dim)
         return ep_output

+    def _tpsp_ffn(self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight):
+        raise Exception("need bind to real impl")
+
+    def _tpsp_ffn_tp(
+        self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+    ) -> torch.Tensor:
+        input = input.view(-1, self.embed_dim_)
+        if self.tp_world_size_ > 1:
+            sp_token_num, hidden_dim = input.shape
+            gather_input = self.alloc_tensor(
+                (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device
+            )
+            all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False)
+            input = gather_input
+
+        ffn2_out = self._ffn(input=input, infer_state=infer_state, layer_weight=layer_weight)
+
+        if self.tp_world_size_ > 1:
+            sp_token_num = ffn2_out.shape[0] // self.tp_world_size_
+            reduce_o_tensor = self.alloc_tensor(
+                (sp_token_num, self.embed_dim_), dtype=ffn2_out.dtype, device=ffn2_out.device
+            )
+            reduce_scatter_tensor(
+                reduce_o_tensor, ffn2_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False
+            )
+            ffn2_out = reduce_o_tensor
+        return ffn2_out
+
+    def _tpsp_ffn_ep(
+        self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+    ) -> torch.Tensor:
+        input = input.view(-1, self.embed_dim_)
+
+        ffn2_out = self._ffn(input=input, infer_state=infer_state, layer_weight=layer_weight)
+
+        return ffn2_out
+
     def overlap_tpsp_token_forward(
         self,
         input_embdings: torch.Tensor,
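
_tpsp_ffn_tp follows the usual tensor-parallel + sequence-parallel sequence: all-gather the per-rank token slices, run the TP FFN over the full token batch, then reduce-scatter the partial results back to per-rank slices. A minimal self-contained sketch of that communication pattern with plain torch.distributed calls (toy dimensions, an identity function standing in for the real FFN, NCCL backend with one GPU per rank; this is not the LightLLM code path):

# Sketch of the all-gather -> FFN -> reduce-scatter pattern used above.
# Launch with: torchrun --nproc_per_node=2 tpsp_sketch.py
import torch
import torch.distributed as dist


def tpsp_ffn(x: torch.Tensor, ffn) -> torch.Tensor:
    world_size = dist.get_world_size()
    sp_token_num, hidden_dim = x.shape

    # 1) all-gather the per-rank sequence slices into the full token batch
    gathered = torch.empty(sp_token_num * world_size, hidden_dim, dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(gathered, x)

    # 2) run the FFN on all tokens (in real TP this yields per-rank partial sums)
    partial_out = ffn(gathered)

    # 3) reduce-scatter: sum partials across ranks, keep only this rank's token slice
    out = torch.empty(sp_token_num, hidden_dim, dtype=x.dtype, device=x.device)
    dist.reduce_scatter_tensor(out, partial_out, op=dist.ReduceOp.SUM)
    return out


if __name__ == "__main__":
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    x = torch.randn(4, 16, device="cuda")  # 4 tokens per rank, hidden_dim 16
    y = tpsp_ffn(x, ffn=lambda t: t)       # identity stand-in, so y == world_size * x here
    print(rank, y.shape)                   # each rank keeps its own (4, 16) slice
    dist.destroy_process_group()

The EP variant (_tpsp_ffn_ep) skips both collectives because the expert-parallel FFN path already handles token dispatch and combine internally.
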

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 35 additions & 14 deletions
@@ -209,20 +209,41 @@ def _init_qkvo(self):
         )

     def _load_mlp(self, mlp_prefix):
-        self.gate_up_proj = MultiROWMMWeight(
-            weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"],
-            data_type=self.data_type_,
-            quant_cfg=self.quant_cfg,
-            layer_num=self.layer_num_,
-            name="gate_up_proj",
-        )
-        self.down_proj = COLMMWeight(
-            weight_name=f"{mlp_prefix}.down_proj.weight",
-            data_type=self.data_type_,
-            quant_cfg=self.quant_cfg,
-            layer_num=self.layer_num_,
-            name="down_proj",
-        )
+        moe_mode = os.getenv("MOE_MODE", "TP")
+        if self.is_moe and moe_mode == "EP":
+            self.gate_up_proj = MultiROWMMWeight(
+                weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"],
+                data_type=self.data_type_,
+                quant_cfg=self.quant_cfg,
+                layer_num=self.layer_num_,
+                name="gate_up_proj",
+                tp_rank=0,
+                tp_world_size=1,
+            )
+            self.down_proj = COLMMWeight(
+                weight_name=f"{mlp_prefix}.down_proj.weight",
+                data_type=self.data_type_,
+                quant_cfg=self.quant_cfg,
+                layer_num=self.layer_num_,
+                name="down_proj",
+                tp_rank=0,
+                tp_world_size=1,
+            )
+        else:
+            self.gate_up_proj = MultiROWMMWeight(
+                weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"],
+                data_type=self.data_type_,
+                quant_cfg=self.quant_cfg,
+                layer_num=self.layer_num_,
+                name="gate_up_proj",
+            )
+            self.down_proj = COLMMWeight(
+                weight_name=f"{mlp_prefix}.down_proj.weight",
+                data_type=self.data_type_,
+                quant_cfg=self.quant_cfg,
+                layer_num=self.layer_num_,
+                name="down_proj",
+            )

     def _init_moe(self):
         moe_intermediate_size = self.network_config_["moe_intermediate_size"]
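
In EP mode the _load_mlp weights are constructed with tp_rank=0 and tp_world_size=1, so each rank loads the full, unsharded gate_up_proj and down_proj; that matches _tpsp_ffn_ep above, which neither gathers inputs nor reduces outputs. In TP mode the defaults keep the usual row/column sharding. A condensed sketch of that decision (the helper below is illustrative only; the real code simply duplicates the constructor calls as shown in the diff):

import os


def mlp_tp_kwargs(is_moe: bool) -> dict:
    # Mirror of the branch in _load_mlp: EP loads full weights, TP keeps sharding defaults.
    moe_mode = os.getenv("MOE_MODE", "TP")
    if is_moe and moe_mode == "EP":
        return {"tp_rank": 0, "tp_world_size": 1}
    return {}


print(mlp_tp_kwargs(is_moe=True))  # {'tp_rank': 0, 'tp_world_size': 1} when MOE_MODE=EP, else {}
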
