[Bugfix] Resolve MTP > 1 issue when lm head tp > 1 #4254
Conversation
Code Review
This pull request aims to fix an issue with speculative decoding (MTP) when tensor parallelism is used on the language model head. The core of the fix is to ensure the dummy run correctly simulates the multiple compute_logits calls that occur in a real run. While the fix is correctly applied for MtpProposer, it seems to be incomplete for EagleProposer, which could lead to the same issue in that scenario. Additionally, a refactoring in model_runner_v1.py appears to have introduced an AttributeError by calling a non-existent method on the drafter object. I've provided critical comments and suggestions for both issues.
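To make the mechanics concrete, here is a minimal, self-contained sketch of the shape of the fix; ToyDraftModel, ToyDrafter, dummy_run, and all argument names are invented for illustration and are not the actual vLLM Ascend classes:

```python
import torch

class ToyDraftModel(torch.nn.Module):
    """Toy stand-in for the draft model that the MTP drafter wraps."""

    def __init__(self, hidden: int = 16, vocab: int = 32):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden, vocab)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # With lm-head tensor parallelism this call would end in a collective op,
        # so every rank must reach it the same number of times.
        return self.lm_head(hidden_states)

class ToyDrafter:
    def __init__(self) -> None:
        # compute_logits lives on .model, not on the drafter itself
        self.model = ToyDraftModel()

def dummy_run(drafter: ToyDrafter, hidden_states: torch.Tensor,
              num_speculative_tokens: int) -> None:
    # The gist of the fix: one compute_logits call per speculative token,
    # mirroring the call count of a real MTP > 1 decode step instead of a single call.
    for _ in range(num_speculative_tokens):
        drafter.model.compute_logits(hidden_states)

dummy_run(ToyDrafter(), torch.randn(4, 16), num_speculative_tokens=3)
```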
        hidden_states[dummy_indices])

    def dummy_drafter_compute_logits(hidden_states):
        return self.drafter.compute_logits(
The dummy_drafter_compute_logits function calls self.drafter.compute_logits, but the compute_logits method is on the model attribute of the drafter object, not on the drafter itself. This will result in an AttributeError. The call should be self.drafter.model.compute_logits.
-        return self.drafter.compute_logits(
+        return self.drafter.model.compute_logits(
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
whx-sjtu left a comment
We first fixed the hanging issue when running MTP=1 with lm head TP in PR #3915. This PR refactors that fix to run dummy_compute_logits in the drafter's dummy_run and further fixes the MTP > 1 scenario. LGTM.
This pull request has conflicts, please resolve those before we can evaluate the pull request.
I've tested it on DeepSeek and it proves to be useful, please make CI happy.
Force-pushed from 28752d2 to 84253da
Force-pushed from 7b9f38b to 577c77b
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Previously, the dummy run executed compute_logits only once, regardless of num_speculative_tokens. This caused execute_model to hang on compute_logits when lm head tensor parallelism exceeded 1. The fix ensures compute_logits executes the correct number of times during the dummy run, matching num_speculative_tokens. Signed-off-by: Jade Zheng <[email protected]>
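As an aside, the hang mechanism can be illustrated without Ascend hardware. The sketch below is not the PR's code: it models compute_logits under lm-head tensor parallelism as ending in an all_gather over vocab shards (gloo backend, toy shapes, all names invented for this example):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

VOCAB, HIDDEN, WORLD = 32, 16, 2

def compute_logits_tp(lm_head_shard: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
    local = hidden @ lm_head_shard.T                       # [tokens, VOCAB // WORLD]
    shards = [torch.empty_like(local) for _ in range(WORLD)]
    dist.all_gather(shards, local)                         # collective: every rank must join
    return torch.cat(shards, dim=-1)                       # [tokens, VOCAB]

def worker(rank: int, num_speculative_tokens: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group("gloo", rank=rank, world_size=WORLD)
    lm_head_shard = torch.randn(VOCAB // WORLD, HIDDEN)
    hidden = torch.randn(4, HIDDEN)
    # The gist of the fix: the dummy run loops exactly num_speculative_tokens
    # times, so its collective call count matches a real MTP > 1 decode step.
    for _ in range(num_speculative_tokens):
        compute_logits_tp(lm_head_shard, hidden)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(3,), nprocs=WORLD)
```

Every rank loops the same number of times here, so the run completes; if the counts diverged, the lagging ranks would sit in all_gather indefinitely, which is the hang this PR removes for MTP > 1.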
Force-pushed from 577c77b to 0642908
        # sequence length to 1 to minimize their overheads in attention.
        exceeds_max_model_len_cpu = exceeds_max_model_len.to(
-           attn_metadata_i.seq_lens.device, non_blocking=True)
+           attn_metadata_i.seq_lens.device, non_blocking=False)
@jianzs I noticed you explained the reason why we disable non-blocking here. But IMO, the stream will keep the right order between the data copy and the following operations on the same stream. I don't get why this causes an accuracy issue; is this a bug in torch-npu?
In an offline discussion, @jianzs mentioned that non-blocking H2D copies are fine and that the accuracy issue occurs with D2H copies. I think it might be a bug in torch-npu. Thus I'm fine with this change as a workaround, and we'll report it to torch-npu so it can eventually be fixed there.
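For reference, a small plain-PyTorch sketch of the pitfall being discussed; it uses CUDA naming, and the assumption (not verified here) is that torch-npu mirrors these stream semantics:

```python
import torch

def to_cpu_for_host_read(t: torch.Tensor) -> torch.Tensor:
    """Copy a device tensor to the host so it is safe to read immediately."""
    if not t.is_cuda:                       # already on the host; nothing to sync
        return t
    cpu = t.to("cpu", non_blocking=True)    # D2H copy is only enqueued on the stream
    torch.cuda.current_stream(t.device).synchronize()  # wait before the host reads it
    return cpu
```

Reading `cpu` without the synchronize may observe stale or partially written values, which matches the D2H accuracy issue described above; the PR sidesteps it by passing non_blocking=False so the copy itself is synchronous.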
What this PR does / why we need it?
Previously, the dummy run executed compute_logits only once, regardless of num_speculative_tokens. This caused execute_model to hang on compute_logits when lm head tensor parallelism exceeded 1. The fix ensures compute_logits executes the correct number of times during the dummy run, matching num_speculative_tokens.
I set the non_blocking argument to False when moving exceeds_max_model_len to the CPU. From what I understand, using non_blocking=True and immediately accessing the tensor on the CPU can cause accuracy problems. However, this issue doesn't happen when transferring data to a device. ref: https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/18

Does this PR introduce any user-facing change?
No.
How was this patch tested?