diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 4421ea0da..ffb1cb75a 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -908,7 +908,9 @@ def _init_padded_req(self):
 
     def _gen_special_model_input(self, token_num: int):
         special_model_input = {}
-        is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__)
+        is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str(
+            self.__class__
+        )
         if is_deepseekv3_mtp_draft_model:
             special_model_input["deepseekv3_mtp_draft_input_hiddens"] = torch.randn(
                 token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda"
diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py
index c6e7aa560..4cfd72e81 100644
--- a/lightllm/models/llama/flashattention_infer_struct.py
+++ b/lightllm/models/llama/flashattention_infer_struct.py
@@ -38,22 +38,29 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor):
         self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
         self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
         max_seq_len_k = self.max_kv_seq_len
+        args_mtp_step = get_env_start_args().mtp_step
+        att_batch_size = self.batch_size // (args_mtp_step + 1)
         if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
             page_buffer = FlashAttentionStateInfo.get_page_table_buffer(
                 model.graph_max_batch_size, model.graph_max_len_in_batch
             )
             self.page_table = page_buffer[self.microbatch_index][
-                : self.batch_size * model.graph_max_len_in_batch
-            ].reshape(self.batch_size, model.graph_max_len_in_batch)
+                : att_batch_size * model.graph_max_len_in_batch
+            ].reshape(att_batch_size, model.graph_max_len_in_batch)
         else:
             self.page_table = torch.empty(
-                (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
+                (att_batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
             )
+
         page_table_copy(
             page_table=self.page_table[:, :max_seq_len_k],
             req_to_token_indexs=model.req_manager.req_to_token_indexs,
-            b_req_idx=self.b_req_idx,
+            b_req_idx=self.b_req_idx[args_mtp_step :: (args_mtp_step + 1)],
         )
+        if args_mtp_step > 0:
+            self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()
+        else:
+            self.b_att_seq_len = self.b_seq_len
 
         if "offline_calibration_fp8kv" in model.mode:
             if self.is_prefill:
diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
index 8c6015677..ea44fe2e5 100755
--- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -883,10 +883,10 @@ def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionS
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=infer_state.page_table,
-            cache_seqlens=infer_state.b_seq_len,
+            cache_seqlens=infer_state.b_att_seq_len,
             cu_seqlens_q=infer_state.cu_seqlens_q,
             cu_seqlens_k_new=infer_state.cu_seqlens_k,
-            max_seqlen_q=1,
+            max_seqlen_q=infer_state.max_q_seq_len,
             softmax_scale=sm_scale,
             causal=True,
             window_size=(-1, -1),
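With MTP enabled, the decode batch carries (mtp_step + 1) positions per request, so the FlashAttention state only needs one page-table row and one kv length per request: the hunks above take every (mtp_step + 1)-th entry of b_req_idx / b_seq_len and pass the resulting b_att_seq_len and max_q_seq_len to the attention kernel. A minimal sketch of that slicing, assuming the flattened batch interleaves the positions of each request in order (the helper name and the example tensors below are illustrative, not lightllm code):

```python
import torch

def select_per_request(b_req_idx: torch.Tensor, b_seq_len: torch.Tensor, mtp_step: int):
    """Keep one entry per request: the last of each (mtp_step + 1)-sized group."""
    if mtp_step == 0:
        return b_req_idx, b_seq_len
    stride = mtp_step + 1
    return b_req_idx[mtp_step::stride], b_seq_len[mtp_step::stride].contiguous()

# Two requests, mtp_step=1 -> two positions per request in the flattened batch.
b_req_idx = torch.tensor([7, 7, 9, 9], dtype=torch.int32)
b_seq_len = torch.tensor([12, 13, 30, 31], dtype=torch.int32)
req_idx, att_seq_len = select_per_request(b_req_idx, b_seq_len, mtp_step=1)
# req_idx -> tensor([7, 9]), att_seq_len -> tensor([13, 31])
```

Taking the last position of each group presumably keeps the largest per-request kv length, which also covers the earlier draft positions of that request.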
diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py
index e3d8de461..7e2f1a302 100644
--- a/lightllm/models/qwen2/model.py
+++ b/lightllm/models/qwen2/model.py
@@ -3,6 +3,7 @@
 from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight
 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.mem_utils import select_mem_manager_class
+from lightllm.utils.envs_utils import get_env_start_args
 
 
 @ModelRegistry("qwen2")
@@ -41,12 +42,20 @@ def _init_mem_manager(self):
         head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"]
         head_dim_ = self.config.get("head_dim", head_dim_)
         tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)
+
+        # In mtp mode, the mem manager must be extended with the layers used by the draft model
+        added_mtp_layer_num = 0
+        if get_env_start_args().mtp_mode == "deepseekv3_eagle":
+            added_mtp_layer_num += 1
+        elif get_env_start_args().mtp_mode == "deepseekv3_vanilla":
+            added_mtp_layer_num += get_env_start_args().mtp_step
+
         self.mem_manager = select_mem_manager_class(self.mode)(
             self.max_total_token_num,
             dtype=self.data_type,
             head_num=tp_k_head_num_,
             head_dim=head_dim_,
-            layer_num=self.config["num_hidden_layers"],
+            layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num,
             mem_fraction=self.mem_fraction,
         )
         return
diff --git a/lightllm/models/qwen3_moe_mtp/__init__.py b/lightllm/models/qwen3_moe_mtp/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/__init__.py b/lightllm/models/qwen3_moe_mtp/layer_weights/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 000000000..57d98eec9
--- /dev/null
+++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,27 @@
+import numpy as np
+from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
+
+
+class Qwen3MOEMTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
+    def __init__(self, data_type, network_config, mode):
+        super().__init__(data_type, network_config, mode)
+        # shared with the Qwen3MOE main model
+        self.wte_weight_ = None
+        self.lm_head_weight_ = None
+        return
+
+    def load_hf_weights(self, weights):
+        if "model.layers.0.proj.weight" in weights:
+            self.eh_proj_weight_ = self._cuda(weights["model.layers.0.proj.weight"]).t()
+        if "model.layers.0.norm_after_embedding.weight" in weights:
+            self.enorm_weight_ = self._cuda(weights["model.layers.0.norm_after_embedding.weight"])
+        if "model.layers.0.norm_before_output.weight" in weights:
+            self.hnorm_weight_ = self._cuda(weights["model.layers.0.norm_before_output.weight"])
+        return
+
+    def verify_load(self):
+        errors = "weights load not ok"
+        weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_]
+        for i in range(len(weights)):
+            assert weights[i] is not None, "index:" + str(i) + " " + errors
+        return
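The qwen2/model.py hunk above reserves extra KV-cache layers for the draft model, and the Qwen3MOEMTPModel in the next file maps its own layers onto those extra slots via mem_layer_start. A rough sketch of the intended layout, under the assumption that mem_layer_start equals the main model's num_hidden_layers (helper names and the example numbers are illustrative):

```python
def added_mtp_layer_num(mtp_mode: str, mtp_step: int) -> int:
    # Mirrors the branch added in qwen2/model.py's _init_mem_manager.
    if mtp_mode == "deepseekv3_eagle":
        return 1          # one shared draft layer
    if mtp_mode == "deepseekv3_vanilla":
        return mtp_step   # one draft layer per speculative step
    return 0

def draft_kv_layer_index(draft_layer_num: int, mem_layer_start: int) -> int:
    # Draft layer i writes KV into slot mem_layer_start + i of the shared mem manager.
    return mem_layer_start + draft_layer_num

# Assumed example: a 48-layer main model in deepseekv3_vanilla mode with mtp_step=2
# sizes the cache for 48 + 2 layers; draft layer 0 then uses slot 48, layer 1 slot 49.
assert added_mtp_layer_num("deepseekv3_vanilla", mtp_step=2) == 2
assert draft_kv_layer_index(0, mem_layer_start=48) == 48
```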
diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py
new file mode 100644
index 000000000..ba6c82804
--- /dev/null
+++ b/lightllm/models/qwen3_moe_mtp/model.py
@@ -0,0 +1,47 @@
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+from lightllm.models.qwen3_moe_mtp.layer_weights.pre_and_post_layer_weight import Qwen3MOEMTPPreAndPostLayerWeight
+from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer
+from lightllm.common.basemodel import TpPartBaseModel
+
+
+class Qwen3MOEMTPModel(Qwen3MOEModel):
+
+    pre_and_post_weight_class = Qwen3MOEMTPPreAndPostLayerWeight
+    pre_layer_infer_class = Deepseek3MTPPreLayerInfer
+
+    def __init__(self, kvargs: dict):
+        self._pre_init(kvargs)
+        super().__init__(kvargs)
+        return
+
+    def _pre_init(self, kvargs: dict):
+        self.main_model: TpPartBaseModel = kvargs.pop("main_model")
+        self.mem_layer_start = kvargs.pop("mem_layer_start", 0)
+        return
+
+    def _init_custom(self):
+        self._cos_cached = self.main_model._cos_cached
+        self._sin_cached = self.main_model._sin_cached
+        return
+
+    def _init_req_manager(self):
+        self.req_manager = self.main_model.req_manager
+        return
+
+    def _init_mem_manager(self):
+        self.mem_manager = self.main_model.mem_manager
+        return
+
+    def _init_weights(self):
+        super()._init_weights()
+        self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_
+        self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_
+        self.pre_post_weight.final_norm_weight_ = self.main_model.pre_post_weight.final_norm_weight_
+        return
+
+    def _init_infer_layer(self):
+        super()._init_infer_layer()
+        # offset the layer_num_ of self.layers_infer by mem_layer_start
+        for layer in self.layers_infer:
+            layer.layer_num_ = layer.layer_num_ + self.mem_layer_start
+        return
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 95f0c9951..da8e095ef 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -35,6 +35,7 @@
 from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
 from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
 from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
+from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel
 from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
 from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
 from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
@@ -281,9 +282,12 @@
             }
 
             mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir)
-            assert mtp_model_cfg["model_type"] == "deepseek_v3"
-            assert mtp_model_cfg["architectures"][0] == "DeepseekV3ForCausalLMNextN"
-            self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs))
+            if mtp_model_cfg["model_type"] == "deepseek_v3":
+                self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs))
+            elif mtp_model_cfg["model_type"] == "qwen3_moe":
+                self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs))
+            else:
+                assert False, f"unsupported mtp draft model_type {mtp_model_cfg['model_type']}"
 
             self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}")
         return
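For reference, the dispatch that the last hunk performs, written out as a standalone helper; the function name is made up, and the real code builds mtp_model_kvargs (including the shared main_model) before instantiating the class directly:

```python
def select_mtp_draft_model_class(model_type: str):
    # model_type comes from the draft checkpoint's HF config (PretrainedConfig.get_config_dict).
    if model_type == "deepseek_v3":
        from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
        return Deepseek3MTPModel
    if model_type == "qwen3_moe":
        from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel
        return Qwen3MOEMTPModel
    raise ValueError(f"unsupported mtp draft model_type: {model_type}")
```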