Skip to content

Commit dad0388

Browse files
committed
[Bugfix] Resolve MTP > 1 issue when lm head tp > 1
Previously, the dummy run executed compute_logits only once, regardless of num_speculative_tokens. This caused execute_model to hang on compute_logits when lm head tensor parallelism exceeded 1. The fix ensures that compute_logits executes the correct number of times during the dummy run, matching num_speculative_tokens. Signed-off-by: Jade Zheng <[email protected]>
1 parent e985432 commit dad0388

File tree

3 files changed

+19
-9
lines changed

3 files changed

+19
-9
lines changed

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def dummy_run(self,
136136
num_reqs: int = 0,
137137
num_tokens_across_dp: Optional[torch.Tensor] = None,
138138
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
139-
batch_descriptor=None):
139+
batch_descriptor=None,
140+
dummy_compute_logits=lambda hidden_states: None):
140141
moe_comm_type = self.runner._select_moe_comm_method(
141142
num_tokens, with_prefill)
142143
with set_ascend_forward_context(None,
@@ -148,6 +149,7 @@ def dummy_run(self,
148149
positions=self.positions[:num_tokens],
149150
hidden_states=self.hidden_states[:num_tokens],
150151
)
152+
dummy_compute_logits(self.hidden_states)
151153

152154
def generate_token_ids(self,
153155
valid_sampled_token_ids: list[list[int]],

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,8 @@ def dummy_run(self,
211211
num_reqs: int = 0,
212212
num_tokens_across_dp=None,
213213
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
214-
batch_descriptor=None) -> None:
214+
batch_descriptor=None,
215+
dummy_compute_logits=lambda hidden_states: None) -> None:
215216

216217
(
217218
num_tokens,
@@ -243,6 +244,7 @@ def dummy_run(self,
243244
self.model(input_ids=input_ids,
244245
positions=positions,
245246
hidden_states=previous_hidden_states)
247+
dummy_compute_logits(previous_hidden_states)
246248
if with_prefill:
247249
break
248250

@@ -665,6 +667,7 @@ def _propose(
665667
logits = self.model.compute_logits(sample_hidden_states)
666668
if lmhead_tp_enable() and num_indices < logits.shape[0]:
667669
logits = logits[:num_indices]
670+
last_token_indices = last_token_indices[:num_indices]
668671
draft_token_ids = logits.argmax(dim=-1)
669672

670673
if self.num_speculative_tokens == 1:
@@ -721,7 +724,7 @@ def _propose(
721724
# For the requests that exceed the max model length, we set the
722725
# sequence length to 1 to minimize their overheads in attention.
723726
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
724-
attn_metadata_i.seq_lens.device, non_blocking=True)
727+
attn_metadata_i.seq_lens.device, non_blocking=False)
725728
attn_metadata_i.seq_lens[:batch_size].masked_fill_(
726729
exceeds_max_model_len_cpu, 1)
727730
# Mask out the slot mappings that exceed the max model length.

vllm_ascend/worker/model_runner_v1.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3049,6 +3049,14 @@ def dummy_compute_logits(hidden_states):
30493049
return self.model.compute_logits(
30503050
hidden_states[dummy_indices])
30513051

3052+
def dummy_drafter_compute_logits(hidden_states):
3053+
return self.drafter.compute_logits(
3054+
hidden_states[dummy_indices])
3055+
3056+
else:
3057+
dummy_compute_logits = lambda hidden_states: None
3058+
dummy_drafter_compute_logits = lambda hidden_states: None
3059+
30523060
with set_ascend_forward_context(
30533061
attn_metadata,
30543062
self.vllm_config,
@@ -3068,8 +3076,7 @@ def dummy_compute_logits(hidden_states):
30683076
with_prefill, is_torchair_compile, input_ids, positions,
30693077
attn_metadata, num_tokens, intermediate_tensors,
30703078
inputs_embeds)
3071-
if need_dummy_logits:
3072-
dummy_compute_logits(hidden_states)
3079+
dummy_compute_logits(hidden_states)
30733080

30743081
if self.drafter:
30753082
self.drafter.dummy_run(
@@ -3079,10 +3086,8 @@ def dummy_compute_logits(hidden_states):
30793086
num_reqs=num_reqs,
30803087
num_tokens_across_dp=num_tokens_across_dp,
30813088
aclgraph_runtime_mode=aclgraph_runtime_mode,
3082-
batch_descriptor=batch_descriptor)
3083-
if need_dummy_logits:
3084-
self.drafter.model.compute_logits(
3085-
hidden_states[dummy_indices])
3089+
batch_descriptor=batch_descriptor,
3090+
dummy_compute_logits=dummy_drafter_compute_logits)
30863091
if self.in_profile_run and self.dynamic_eplb:
30873092
self.model.clear_all_moe_loads()
30883093
if not self.in_profile_run and self.dynamic_eplb:

0 commit comments

Comments (0)