4 changes: 3 additions & 1 deletion vllm_ascend/spec_decode/eagle_proposer.py
@@ -123,7 +123,8 @@ def dummy_run(self,
num_reqs: int = 0,
num_tokens_across_dp: Optional[torch.Tensor] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None):
batch_descriptor=None,
dummy_compute_logits=lambda hidden_states: None):
moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
with set_ascend_forward_context(None,
self.vllm_config,
@@ -134,6 +135,7 @@ def dummy_run(self,
positions=self.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
)
dummy_compute_logits(self.hidden_states)

def generate_token_ids(self,
valid_sampled_token_ids: list[list[int]],
7 changes: 5 additions & 2 deletions vllm_ascend/spec_decode/mtp_proposer.py
@@ -215,7 +215,8 @@ def dummy_run(self,
num_reqs: int = 0,
num_tokens_across_dp=None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None) -> None:
batch_descriptor=None,
dummy_compute_logits=lambda hidden_states: None) -> None:

(
num_tokens,
@@ -298,6 +299,7 @@ def dummy_run(self,
self.update_stream, forward_context,
positions.shape[0],
self.vllm_config.speculative_config)
dummy_compute_logits(previous_hidden_states)
if with_prefill:
break

@@ -756,6 +758,7 @@ def _propose(
logits = self.model.compute_logits(sample_hidden_states)
if lmhead_tp_enable() and num_indices < logits.shape[0]:
logits = logits[:num_indices]
last_token_indices = last_token_indices[:num_indices]
draft_token_ids = logits.argmax(dim=-1)

if self.num_speculative_tokens == 1:
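A side note on the two truncation lines just above: when lm-head tensor parallelism pads the batch, `logits` can end up with more rows than there are real requests, so `last_token_indices` has to be cut to the same length or the rows and their sampling positions drift out of alignment. A minimal sketch with made-up shapes; only `num_indices`, `logits`, and `last_token_indices` come from the diff, the padding setup is assumed:

```python
import torch

# Hypothetical lmhead-TP padded batch: 8 logits rows, but only 5 real requests.
num_indices = 5
logits = torch.randn(8, 32000)           # [padded_batch, vocab_size]
last_token_indices = torch.arange(8)     # sampling position per row

# Truncate both together so row i of `logits` still pairs with
# last_token_indices[i]; slicing only one of them would misalign them.
logits = logits[:num_indices]
last_token_indices = last_token_indices[:num_indices]

draft_token_ids = logits.argmax(dim=-1)  # [num_indices]
assert draft_token_ids.shape[0] == last_token_indices.shape[0]
```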
@@ -821,7 +824,7 @@ def _propose(
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
attn_metadata_i.seq_lens.device, non_blocking=True)
attn_metadata_i.seq_lens.device, non_blocking=False)
Collaborator comment:

@jianzs I noticed you explained the reason we disable the non-blocking copy here. But IMO, the stream should keep the data copy and the subsequent operations on the same stream in order. I don't see why this causes an accuracy issue; is this a bug in torch-npu?

Collaborator comment:

In an offline discussion, @jianzs mentioned that the H2D non-blocking copy is fine and that the accuracy issue occurs with the D2H copy. I think it might be a bug in torch-npu, so I'm fine with this change as a workaround, and we'll report it to torch-npu so it can be fixed properly.
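As background for the thread above: a device-to-host copy issued with `non_blocking=True` is only ordered with respect to other work on the same device stream; the host may read the destination tensor before the copy has finished unless it synchronizes first. The sketch below is illustrative only (written against the CUDA API for clarity, not torch-npu, and it is not the author's fix), showing the usual safe pattern:

```python
import torch

def copy_to_host(dev_tensor: torch.Tensor, non_blocking: bool) -> torch.Tensor:
    # Device -> host copy. With non_blocking=True the call may return before
    # the copy has completed, so the host must synchronize before reading.
    host_tensor = dev_tensor.to("cpu", non_blocking=non_blocking)
    if non_blocking:
        # Without this synchronization the CPU could observe stale values,
        # which is the kind of accuracy issue described in the thread above.
        torch.cuda.synchronize()
    return host_tensor
```

Setting `non_blocking=False`, as this diff does, sidesteps the question entirely at the cost of a synchronous copy.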

attn_metadata_i.seq_lens[:batch_size].masked_fill_(
exceeds_max_model_len_cpu, 1)
# Mask out the slot mappings that exceed the max model length.
3 changes: 2 additions & 1 deletion vllm_ascend/spec_decode/ngram_proposer.py
@@ -26,7 +26,8 @@ def dummy_run(self,
num_reqs=None,
num_tokens_across_dp=None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None):
batch_descriptor=None,
dummy_compute_logits=lambda hidden_states: None):
pass

def generate_token_ids(self,
4 changes: 3 additions & 1 deletion vllm_ascend/torchair/torchair_mtp_proposer.py
@@ -80,7 +80,8 @@ def dummy_run(self,
num_reqs: int = 0,
num_tokens_across_dp=None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None) -> None:
batch_descriptor=None,
dummy_compute_logits=lambda hidden_states: None) -> None:
moe_comm_type = self.runner._select_moe_comm_method(num_tokens)

if not with_prefill:
@@ -142,6 +143,7 @@ def dummy_run(self,
self.model(input_ids=input_ids,
positions=positions,
hidden_states=previous_hidden_states)
dummy_compute_logits(previous_hidden_states)
if with_prefill:
break

32 changes: 18 additions & 14 deletions vllm_ascend/worker/model_runner_v1.py
@@ -2963,14 +2963,21 @@ def _dummy_run(

need_dummy_logits = (not self.in_profile_run
and lmhead_tp_enable())

if need_dummy_logits:
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
dummy_indices = torch.zeros(max_num_reqs_across_dp,
dtype=torch.int32)

def dummy_compute_logits(hidden_states):
return self.model.compute_logits(
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
dummy_indices = torch.zeros(max_num_reqs_across_dp,
dtype=torch.int32)

def dummy_compute_logits(hidden_states):
if not need_dummy_logits:
return None
return self.model.compute_logits(hidden_states[dummy_indices])

def dummy_drafter_compute_logits(hidden_states):
if not need_dummy_logits or self.drafter is None:
return
if hasattr(self.drafter, "model") and hasattr(
self.drafter.model, "compute_logits"):
return self.drafter.model.compute_logits(
hidden_states[dummy_indices])

with set_ascend_forward_context(
Expand All @@ -2992,8 +2999,7 @@ def dummy_compute_logits(hidden_states):
with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens, intermediate_tensors,
inputs_embeds)
if need_dummy_logits:
dummy_compute_logits(hidden_states)
dummy_compute_logits(hidden_states)

if self.drafter:
self.drafter.dummy_run(
@@ -3002,10 +3008,8 @@ def dummy_compute_logits(hidden_states):
num_reqs=num_reqs,
num_tokens_across_dp=num_tokens_across_dp,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor)
if need_dummy_logits:
self.drafter.model.compute_logits(
hidden_states[dummy_indices])
batch_descriptor=batch_descriptor,
dummy_compute_logits=dummy_drafter_compute_logits)
if self.in_profile_run and self.dynamic_eplb:
self.model.clear_all_moe_loads()
if not self.in_profile_run and self.dynamic_eplb:
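Overall, the pattern in this PR is to thread an optional `dummy_compute_logits` callback (defaulting to a no-op lambda) through every proposer's `dummy_run`, so that when lm-head tensor parallelism is enabled the dummy run also drives `compute_logits` on both the main model and the drafter, keeping all ranks in step. A stripped-down sketch of that callback-injection pattern; the class and attribute names here are hypothetical stand-ins, not the vllm-ascend API:

```python
from typing import Callable, Optional

import torch


class Proposer:
    """Hypothetical drafter whose dummy_run accepts an injected logits callback."""

    def dummy_run(
        self,
        num_tokens: int,
        dummy_compute_logits: Callable[[torch.Tensor],
                                       Optional[torch.Tensor]] = lambda h: None,
    ) -> None:
        hidden_states = torch.zeros(num_tokens, 16)  # stand-in forward pass
        # The callback decides internally whether logits are actually needed
        # (e.g. only when lmhead TP is enabled and we are not profiling).
        dummy_compute_logits(hidden_states)


class Runner:
    """Hypothetical runner that builds the callback once and reuses it."""

    def __init__(self, need_dummy_logits: bool, vocab_size: int = 32) -> None:
        self.need_dummy_logits = need_dummy_logits
        self.lm_head = torch.nn.Linear(16, vocab_size)
        self.drafter = Proposer()

    def _dummy_run(self, num_tokens: int) -> None:
        dummy_indices = torch.zeros(num_tokens, dtype=torch.int64)

        def dummy_compute_logits(hidden_states: torch.Tensor):
            if not self.need_dummy_logits:
                return None
            return self.lm_head(hidden_states[dummy_indices])

        hidden_states = torch.zeros(num_tokens, 16)  # main-model dummy forward
        dummy_compute_logits(hidden_states)
        # The same callback is handed to the drafter; proposers without an
        # lm head (like the ngram proposer) can simply ignore it.
        self.drafter.dummy_run(num_tokens,
                               dummy_compute_logits=dummy_compute_logits)


Runner(need_dummy_logits=True)._dummy_run(num_tokens=4)
```

The no-op default keeps proposers that have no lm head untouched, while lmhead-TP setups get matching dummy logits work on every rank.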