From f1b49d86ff47fbae3dc22e62707727766edec268 Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Tue, 18 Nov 2025 17:50:25 +0800
Subject: [PATCH 1/6] [Bugfix] Resolve MTP > 1 issue when lm head tp > 1

Previously, the dummy run executed compute_logits only once, regardless
of num_speculative_tokens. Because compute_logits is a collective
operation when lm head tensor parallelism exceeds 1, ranks warming up
through the dummy run issued fewer collectives than ranks running the
real workload, so execute_model hung on compute_logits. The fix makes
the dummy run execute compute_logits once per speculative step,
matching num_speculative_tokens.

Signed-off-by: Jade Zheng
---
 vllm_ascend/spec_decode/eagle_proposer.py |  4 +++-
 vllm_ascend/spec_decode/mtp_proposer.py   |  7 +++++--
 vllm_ascend/worker/model_runner_v1.py     | 17 +++++++++++------
 3 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 4d076ac117f..3c276af748a 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -123,7 +123,8 @@ def dummy_run(self,
                   num_reqs: int = 0,
                   num_tokens_across_dp: Optional[torch.Tensor] = None,
                   aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-                  batch_descriptor=None):
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None):
         moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
         with set_ascend_forward_context(None,
                                         self.vllm_config,
@@ -134,6 +135,7 @@ def dummy_run(self,
                 positions=self.positions[:num_tokens],
                 hidden_states=self.hidden_states[:num_tokens],
             )
+            dummy_compute_logits(self.hidden_states)
 
     def generate_token_ids(self,
                            valid_sampled_token_ids: list[list[int]],
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 556a917fc4a..8cc9d8f01cb 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -215,7 +215,8 @@ def dummy_run(self,
                   num_reqs: int = 0,
                   num_tokens_across_dp=None,
                   aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-                  batch_descriptor=None) -> None:
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None) -> None:
 
         (
             num_tokens,
@@ -298,6 +299,7 @@ def dummy_run(self,
                     self.update_stream, forward_context,
                     positions.shape[0],
                     self.vllm_config.speculative_config)
+            dummy_compute_logits(previous_hidden_states)
 
             if with_prefill:
                 break
@@ -756,6 +758,7 @@ def _propose(
         logits = self.model.compute_logits(sample_hidden_states)
         if lmhead_tp_enable() and num_indices < logits.shape[0]:
             logits = logits[:num_indices]
+            last_token_indices = last_token_indices[:num_indices]
         draft_token_ids = logits.argmax(dim=-1)
 
         if self.num_speculative_tokens == 1:
@@ -821,7 +824,7 @@ def _propose(
             # For the requests that exceed the max model length, we set the
             # sequence length to 1 to minimize their overheads in attention.
             exceeds_max_model_len_cpu = exceeds_max_model_len.to(
-                attn_metadata_i.seq_lens.device, non_blocking=True)
+                attn_metadata_i.seq_lens.device, non_blocking=False)
             attn_metadata_i.seq_lens[:batch_size].masked_fill_(
                 exceeds_max_model_len_cpu, 1)
             # Mask out the slot mappings that exceed the max model length.
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 5677550bbe4..6b91f4ef837 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2973,6 +2973,14 @@ def dummy_compute_logits(hidden_states):
                 return self.model.compute_logits(
                     hidden_states[dummy_indices])
 
+            def dummy_drafter_compute_logits(hidden_states):
+                return self.drafter.compute_logits(
+                    hidden_states[dummy_indices])
+
+        else:
+            dummy_compute_logits = lambda hidden_states: None
+            dummy_drafter_compute_logits = lambda hidden_states: None
+
         with set_ascend_forward_context(
             attn_metadata,
             self.vllm_config,
@@ -2992,8 +3000,7 @@ def dummy_compute_logits(hidden_states):
                     with_prefill, is_torchair_compile, input_ids, positions,
                     attn_metadata, num_tokens, intermediate_tensors,
                     inputs_embeds)
-            if need_dummy_logits:
-                dummy_compute_logits(hidden_states)
+            dummy_compute_logits(hidden_states)
 
         if self.drafter:
             self.drafter.dummy_run(
@@ -3002,10 +3009,8 @@ def dummy_compute_logits(hidden_states):
                 num_reqs=num_reqs,
                 num_tokens_across_dp=num_tokens_across_dp,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
-                batch_descriptor=batch_descriptor)
-            if need_dummy_logits:
-                self.drafter.model.compute_logits(
-                    hidden_states[dummy_indices])
+                batch_descriptor=batch_descriptor,
+                dummy_compute_logits=dummy_drafter_compute_logits)
         if self.in_profile_run and self.dynamic_eplb:
             self.model.clear_all_moe_loads()
         if not self.in_profile_run and self.dynamic_eplb:

From d3246486f70beb88ffa6efb00b230c912b50f80e Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Tue, 18 Nov 2025 18:17:08 +0800
Subject: [PATCH 2/6] Use drafter.model.compute_logits in dummy drafter logits

Signed-off-by: Jade Zheng
---
 vllm_ascend/worker/model_runner_v1.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 6b91f4ef837..7480e5ac19e 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2974,7 +2974,7 @@ def dummy_compute_logits(hidden_states):
                     hidden_states[dummy_indices])
 
             def dummy_drafter_compute_logits(hidden_states):
-                return self.drafter.compute_logits(
+                return self.drafter.model.compute_logits(
                     hidden_states[dummy_indices])
 
         else:

From d2beb9c26b0b058866652a6538ed842abfcbc3d6 Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Fri, 21 Nov 2025 23:14:08 +0800
Subject: [PATCH 3/6] Move need_dummy_logits checks into the dummy logits helpers

Signed-off-by: Jade Zheng
---
 vllm_ascend/worker/model_runner_v1.py | 29 ++++++++++++++---------------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 7480e5ac19e..be8d975a938 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2963,24 +2963,23 @@ def _dummy_run(
         need_dummy_logits = (not self.in_profile_run
                              and lmhead_tp_enable())
-
-        if need_dummy_logits:
-            max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
-            dummy_indices = torch.zeros(max_num_reqs_across_dp,
-                                        dtype=torch.int32)
-
-            def dummy_compute_logits(hidden_states):
-                return self.model.compute_logits(
-                    hidden_states[dummy_indices])
-
-            def dummy_drafter_compute_logits(hidden_states):
+        max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
+        dummy_indices = torch.zeros(max_num_reqs_across_dp,
+                                    dtype=torch.int32)
+
+        def dummy_compute_logits(hidden_states):
+            if not need_dummy_logits:
+                return None
+            return self.model.compute_logits(hidden_states[dummy_indices])
+
+        def dummy_drafter_compute_logits(hidden_states):
+            if not need_dummy_logits:
+                return
+            if hasattr(self.drafter, "model") and hasattr(
+                    self.drafter.model, "compute_logits"):
                 return self.drafter.model.compute_logits(
                     hidden_states[dummy_indices])
 
-        else:
-            dummy_compute_logits = lambda hidden_states: None
-            dummy_drafter_compute_logits = lambda hidden_states: None
-
         with set_ascend_forward_context(
             attn_metadata,
             self.vllm_config,

From 684a0a28364e5ebc0c62894c103884282c88567e Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Fri, 21 Nov 2025 23:40:40 +0800
Subject: [PATCH 4/6] Skip dummy drafter logits when no drafter is configured

Signed-off-by: Jade Zheng
---
 vllm_ascend/worker/model_runner_v1.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index be8d975a938..49a16204789 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2973,7 +2973,7 @@ def dummy_compute_logits(hidden_states):
             return self.model.compute_logits(hidden_states[dummy_indices])
 
         def dummy_drafter_compute_logits(hidden_states):
-            if not need_dummy_logits:
+            if not need_dummy_logits or self.drafter is None:
                 return
             if hasattr(self.drafter, "model") and hasattr(
                     self.drafter.model, "compute_logits"):

From 6f7f8441b6b8320667da6386503bdec46bf3eae9 Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Sun, 23 Nov 2025 12:04:33 +0800
Subject: [PATCH 5/6] Thread dummy_compute_logits through torchair MTP dummy run

Signed-off-by: Jade Zheng
---
 vllm_ascend/torchair/torchair_mtp_proposer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm_ascend/torchair/torchair_mtp_proposer.py b/vllm_ascend/torchair/torchair_mtp_proposer.py
index b816b8d8412..7df6592fe4d 100644
--- a/vllm_ascend/torchair/torchair_mtp_proposer.py
+++ b/vllm_ascend/torchair/torchair_mtp_proposer.py
@@ -80,7 +80,8 @@ def dummy_run(self,
                   num_reqs: int = 0,
                   num_tokens_across_dp=None,
                   aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-                  batch_descriptor=None) -> None:
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None) -> None:
         moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
 
         if not with_prefill:
@@ -142,6 +143,7 @@ def dummy_run(self,
             self.model(input_ids=input_ids,
                        positions=positions,
                        hidden_states=previous_hidden_states)
+            dummy_compute_logits(previous_hidden_states)
 
             if with_prefill:
                 break

From 0642908eb6a9dd74513fcaf6ba4be2563913c6c3 Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Sun, 23 Nov 2025 20:04:21 +0800
Subject: [PATCH 6/6] Accept dummy_compute_logits in ngram proposer dummy_run

Signed-off-by: Jade Zheng
---
 vllm_ascend/spec_decode/ngram_proposer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm_ascend/spec_decode/ngram_proposer.py b/vllm_ascend/spec_decode/ngram_proposer.py
index 932a127cf01..63b2711a32e 100644
--- a/vllm_ascend/spec_decode/ngram_proposer.py
+++ b/vllm_ascend/spec_decode/ngram_proposer.py
@@ -26,7 +26,8 @@ def dummy_run(self,
                   num_reqs=None,
                   num_tokens_across_dp=None,
                   aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-                  batch_descriptor=None):
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None):
         pass
 
     def generate_token_ids(self,
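
The invariant this series enforces: under lm head tensor parallelism,
compute_logits is a collective, so every rank must enter it the same number
of times, and the dummy_compute_logits callback keeps the warm-up path's
call count equal to the real decode path's. Below is a minimal standalone
sketch of that contract; ToyDrafter and warm_up are hypothetical stand-ins,
not vllm-ascend code.

class ToyDrafter:
    """Stands in for an MTP/Eagle proposer that runs k speculative steps."""

    def __init__(self, num_speculative_tokens: int) -> None:
        self.num_speculative_tokens = num_speculative_tokens

    def dummy_run(self, dummy_compute_logits=lambda hidden_states: None):
        # One dummy forward per speculative step; the callback mirrors the
        # compute_logits call (a collective under lm head TP) that the real
        # execute_model issues at every step.
        for _ in range(self.num_speculative_tokens):
            hidden_states = [0.0] * 4  # placeholder for a forward pass
            dummy_compute_logits(hidden_states)


def warm_up(drafter: ToyDrafter, lmhead_tp_enabled: bool) -> int:
    """Warm-up helper; returns how many logits computations were issued."""
    calls = 0

    def dummy_compute_logits(hidden_states):
        # In the real runner this would be self.model.compute_logits(...),
        # which every lm head TP rank must enter the same number of times.
        nonlocal calls
        if lmhead_tp_enabled:
            calls += 1

    drafter.dummy_run(dummy_compute_logits=dummy_compute_logits)
    return calls


# With num_speculative_tokens = 3 the warm-up issues three logits calls,
# matching real decoding, instead of the single call that left peer ranks
# hung on a collective nobody else joined.
assert warm_up(ToyDrafter(3), lmhead_tp_enabled=True) == 3
assert warm_up(ToyDrafter(3), lmhead_tp_enabled=False) == 0

The no-op default (lambda hidden_states: None) lets every proposer,
including the ngram one that computes no logits at all, accept the argument
unconditionally, while the model runner alone decides via need_dummy_logits
whether the warm-up actually computes logits.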