feat(dflash): Add DFlash support #472
5c4426d to fab46ea
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: fab46ea098
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
if self.is_draft:
self.forward_decode_metadata = MHADecodeMetadata(
page_table=page_table,
seq_lens=seq_lens,
# DFLASH drafts a whole block in one forward; the decode
# kernel assumes query length 1, so expand each request
# into spec_num_tokens rows. Give every row the SAME full
# seq_len so each block query attends over the entire block
# (non-causal block-diffusion drafting).
self._fill_spec_metadata_uniform(
Gate draft metadata expansion to DFlash only
When an existing EAGLE3/MTP draft uses an MHA-family backend in eager decode (for example with CUDA graphs disabled or an uncaptured batch size), cuda_graph_wrapper._init_forward_metadata calls this DECODE path for the draft backend, and Eagle._run_multi_step_decode still sends only one query row per request (input_num_tokens=bs). This new if self.is_draft branch now expands the draft metadata to bs * spec_num_tokens rows for all draft models, not just DFlash, so the decode kernel receives page tables/seq_lens sized for many more rows than the q tensor and will either fail or attend with mismatched metadata. The expansion should be conditioned on draft_block_decode; legacy draft decode should keep the unexpanded page_table/seq_lens behavior.
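A minimal sketch of the gating suggested above, in plain Python lists rather than the repo's tensor code. `DecodeMetadata` and `build_decode_metadata` are hypothetical stand-ins (the real class is `MHADecodeMetadata`); only `is_draft`, `draft_block_decode`, and `spec_num_tokens` are names taken from the review.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class DecodeMetadata:
    """Hypothetical stand-in for MHADecodeMetadata, using lists instead of tensors."""
    page_table: List[List[int]]
    seq_lens: List[int]


def build_decode_metadata(page_table, seq_lens, *, is_draft,
                          draft_block_decode, spec_num_tokens):
    """Expand rows only for block-drafting (DFlash-style) draft models."""
    if is_draft and draft_block_decode:
        # One row per block query; every row of a request shares the same
        # page table row and the same full seq_len.
        expanded_pages = [list(row) for row in page_table
                          for _ in range(spec_num_tokens)]
        expanded_lens = [n for n in seq_lens for _ in range(spec_num_tokens)]
        return DecodeMetadata(expanded_pages, expanded_lens)
    # Legacy draft decode (one query row per request) keeps the
    # unexpanded metadata, so it still matches a q tensor of bs rows.
    return DecodeMetadata([list(row) for row in page_table], list(seq_lens))


page_table = [[1, 2], [3, 4]]
seq_lens = [5, 9]
dflash_meta = build_decode_metadata(
    page_table, seq_lens,
    is_draft=True, draft_block_decode=True, spec_num_tokens=3)
legacy_meta = build_decode_metadata(
    page_table, seq_lens,
    is_draft=True, draft_block_decode=False, spec_num_tokens=3)
```

With two requests and `spec_num_tokens=3`, the block-decode path yields six metadata rows while the legacy path stays at two.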
@codex review
💡 Codex Review
Reviewed commit: 894f7c3d13
    base_ctx.num_extends,
    self.spec_num_tokens,
)
return self.draft(current_tokens)
Avoid drafting during mid-chunk prefill
When every row is still in a chunked prefill (input_buffers.all_extends_mid_chunk), the DFlash block returned here is discarded by the scheduler until the prompt finishes, but self.draft(current_tokens) still writes spec_num_tokens draft KV rows starting at the current chunk end. For long prompts where those future positions cross pages the scheduler has not allocated yet, req_to_page still contains zeroes for the next block, so the draft write can overwrite page 0 in the draft cache and poison the prompt context used once prefill completes. The EAGLE path already avoids extra draft work in this exact mid-chunk case; DFlash should similarly skip the block draft or avoid writing draft KV until the request can transition to decode.
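The hazard described above can be illustrated with a small sketch. The helper names `pages_hit_by_draft_write` and `maybe_draft` are hypothetical, and the position-to-page mapping is an assumed simplification of how `req_to_page` is indexed; only `all_extends_mid_chunk` and `spec_num_tokens` come from the review.

```python
from typing import Callable, List, Optional, Sequence


def pages_hit_by_draft_write(req_to_page: Sequence[int], start_pos: int,
                             num_tokens: int, page_size: int) -> List[int]:
    """Physical page touched by each draft KV row (assumed simple mapping)."""
    return [req_to_page[(start_pos + i) // page_size]
            for i in range(num_tokens)]


def maybe_draft(all_extends_mid_chunk: bool,
                draft_fn: Callable[[List[int]], List[int]],
                current_tokens: List[int]) -> Optional[List[int]]:
    """Skip the block draft while every row is still mid-chunk."""
    if all_extends_mid_chunk:
        # The scheduler would discard the block anyway, so avoid the KV write.
        return None
    return draft_fn(current_tokens)


# Pages 7 and 9 are allocated; the remaining entries are still zero.
hit = pages_hit_by_draft_write([7, 9, 0, 0], start_pos=6,
                               num_tokens=4, page_size=4)

calls: List[List[int]] = []


def fake_draft(tokens: List[int]) -> List[int]:
    """Stand-in for the draft forward; records that it was invoked."""
    calls.append(tokens)
    return tokens


skipped = maybe_draft(True, fake_draft, [1, 2, 3])
drafted = maybe_draft(False, fake_draft, [1, 2, 3])
```

In this example the last two draft rows resolve to page 0, which is the overwrite the review warns about; the guard avoids issuing the write at all in the mid-chunk case.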
Signed-off-by: Moustafa Saleh <8815169+mesaleh@users.noreply.github.com>
894f7c3 to ed5e0fc
💡 Codex Review
Reviewed commit: ed5e0fc1c1
if residual is None:
    hidden_states = self.norm(hidden_states)
else:
    hidden_states, _ = self.norm(hidden_states, residual)
Reduce the final DFlash MLP output before norm
In dense tensor-parallel runs (--dense-tp-size >1, including the added Kimi DFlash launch), DFlashMLP.down_proj returns an unreduced row-parallel partial because it is constructed with reduce_results=False; intermediate layers get corrected by the next layer's input all-reduce, but after the last layer there is no next layer. This final norm therefore consumes only the local shard's MLP contribution and the draft sampler proposes tokens from incorrect hidden states, so add a final all-reduce (or reduce the last down_proj) before normalizing.
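To see why the reduction has to happen before the norm, here is a toy sketch in pure Python with two simulated tensor-parallel ranks. `all_reduce_sum` and `rms_norm` are hypothetical helpers (a weightless RMS norm is assumed for simplicity); the real code would use the repo's collective and norm layers.

```python
import math
from typing import List


def all_reduce_sum(partials: List[List[float]]) -> List[float]:
    """Simulate an all-reduce: elementwise sum of every rank's partial output."""
    return [sum(vals) for vals in zip(*partials)]


def rms_norm(x: List[float], eps: float = 1e-6) -> List[float]:
    """Weightless RMS norm, used here only to show the nonlinearity."""
    scale = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / scale for v in x]


# Unreduced row-parallel partials from two ranks (reduce_results=False).
partials = [[1.0, 2.0], [3.0, -1.0]]

# What the review asks for: reduce first, then normalize.
correct = rms_norm(all_reduce_sum(partials))

# What happens without the final reduction: each rank normalizes its own shard.
local_only = rms_norm(partials[0])

# Normalizing per rank and summing afterwards does not recover the result.
sum_of_norms = all_reduce_sum([rms_norm(p) for p in partials])
```

Because the norm is nonlinear, neither the local-only result nor a late reduction matches the reduce-then-normalize output.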
c23bb44 to aa31274
💡 Codex Review
Reviewed commit: aa31274cdc
cache_locs = self.draft_out_cache_loc_buf[: bs * self.spec_num_tokens]
compute_out_cache_loc_uniform(
    out_cache_loc_ptr=cache_locs,
    req_pool_indices=req_pool_indices,
    uniform_input_length=self.spec_num_tokens,
Reserve enough slots for DFlash block writes
In non-overlap speculative decode, ScheduleBatch.prealloc_for_draft_decode reserves only spec_num_steps - 1 slots, and DFLASH enforces spec_num_steps == spec_num_tokens - 1; this path nevertheless computes cache locations for a full spec_num_tokens block starting at the accepted prefix. When a request accepts enough tokens near a page boundary, the tail of this block can land past the pages the scheduler reserved, so compute_out_cache_loc_uniform reads zero/unallocated req_to_page entries and writes draft KV into the wrong page. DFlash needs a matching full-block reservation (or a shorter write) before these locations are computed.
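The mismatch can be worked through with a little arithmetic. This sketch assumes a contiguous write starting at position `start_pos` and a fixed `page_size`; `pages_needed` is a hypothetical helper, and the concrete numbers are illustrative only.

```python
def pages_needed(start_pos: int, num_tokens: int, page_size: int) -> int:
    """Number of pages a request must own to hold tokens up to the write's end."""
    last_pos = start_pos + num_tokens - 1
    return last_pos // page_size + 1


spec_num_tokens = 8
spec_num_steps = spec_num_tokens - 1      # constraint DFLASH enforces per the review
reserved_slots = spec_num_steps - 1       # what the prealloc reserves per the review

page_size = 16
start_pos = 26                            # accepted prefix ends near a page boundary

pages_reserved = pages_needed(start_pos, reserved_slots, page_size)
pages_written = pages_needed(start_pos, spec_num_tokens, page_size)
shortfall = pages_written - pages_reserved
```

Here the reservation covers positions 26 to 31 (two pages), while the full block writes positions 26 to 33 and spills into a third page that was never allocated.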
💡 Codex Review
Reviewed commit: 733c23bf26
f"outputs. Got invalid ids {invalid}; valid range is "
f"[0, {num_layers - 2}] for {num_layers} target layers."
)
self.model.layers_to_capture = {val + 1 for val in layer_ids}
Preserve DFlash target layer order
When a DFlash checkpoint declares dflash_config.target_layer_ids in a non-sorted order (for example [20, 4]), this conversion to a set loses the checkpoint’s requested order. DeepseekV3Model.forward appends captured auxiliary states in layer traversal order and the logits processor concatenates them in that appended order, so the tensor passed into the DFlash fc projection is sorted by layer instead of matching the checkpoint’s feature order, silently permuting the feature blocks and producing bad draft cache/tokens. Preserve the requested order or reject unsorted configs explicitly.
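Both remedies mentioned above can be sketched briefly. `reorder_features` and `require_sorted` are hypothetical helpers, and the strings stand in for captured hidden-state tensors; the sketch assumes capture happens in ascending layer order, as the review describes.

```python
from typing import Dict, List, Sequence


def reorder_features(captured: Sequence[str], captured_ids: Sequence[int],
                     requested_ids: Sequence[int]) -> List[str]:
    """Permute traversal-order captures back into the checkpoint's order."""
    position: Dict[int, int] = {lid: i for i, lid in enumerate(captured_ids)}
    return [captured[position[lid]] for lid in requested_ids]


def require_sorted(layer_ids: Sequence[int]) -> None:
    """Alternative remedy: reject unsorted configs explicitly."""
    if list(layer_ids) != sorted(layer_ids):
        raise ValueError(f"target_layer_ids must be ascending, got {list(layer_ids)}")


requested = [20, 4]

# Going through a set and capturing during layer traversal yields ascending order.
traversal_ids = sorted(set(requested))
captured = [f"hidden@{lid}" for lid in traversal_ids]

restored = reorder_features(captured, traversal_ids, requested)

try:
    require_sorted(requested)
    rejected = False
except ValueError:
    rejected = True
```

Either option makes the behavior explicit: the first restores the feature blocks to the order the `fc` projection expects, the second fails fast at load time.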
Thanks for carrying this forward. I saw this PR was created from #263 because the branch was not editable. Since it keeps my original commit/author and adds the requested test coverage, I'll track this PR and can help address remaining review/CI issues if useful.
💡 Codex Review
Reviewed commit: 4437969881
if ctx.forward_mode.is_idle():
    hidden_states = self.mlp(hidden_states)
    return hidden_states, residual
Mirror DFlash draft collectives on idle ranks
When DFLASH runs with data parallelism and one DP rank has no requests while another rank is drafting, execute_idle_forward calls the draft model with ForwardMode.IDLE; this branch skips attention and the active path's norm/all-reduce sites, and the only MLP it runs uses reduce_results=False, so the idle rank never enters the NCCL collectives that active DFlash layers enter later in this function. That mismatched collective sequence can hang the worker group; the idle path should issue the same zero-token reductions as the active draft path.
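A toy trace makes the mismatch concrete. This is only a model of the call sequence, not real NCCL code: `layer_collective_trace` is hypothetical, and it assumes one reduction site per layer purely for illustration.

```python
from typing import List


def layer_collective_trace(num_layers: int, is_idle: bool,
                           mirror_on_idle: bool) -> List[str]:
    """List the collectives a rank would enter, in order (assumed one per layer)."""
    trace: List[str] = []
    for layer in range(num_layers):
        if is_idle and not mirror_on_idle:
            # Behavior described in the review: the idle branch runs a local
            # MLP with reduce_results=False and returns without a collective.
            continue
        # Active ranks, and idle ranks when mirroring, enter the reduction.
        # An idle rank would contribute a zero-token tensor.
        trace.append(f"all_reduce(layer={layer})")
    return trace


def ranks_agree(traces: List[List[str]]) -> bool:
    """Collectives only complete if every rank issues the same sequence."""
    return all(t == traces[0] for t in traces)


active = layer_collective_trace(3, is_idle=False, mirror_on_idle=False)
idle_buggy = layer_collective_trace(3, is_idle=True, mirror_on_idle=False)
idle_fixed = layer_collective_trace(3, is_idle=True, mirror_on_idle=True)

hang_risk = not ranks_agree([active, idle_buggy])
safe = ranks_agree([active, idle_fixed])
```

When the traces differ, the active rank waits on a reduction the idle rank never joins, which is the hang the review describes.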
language_model = getattr(target_model, "language_model", target_model)
self.target_model = target_model
self.target_language_model = language_model
self.embed_tokens = target_model.get_input_embeddings()
Fall back to target model embeddings when binding DFlash
When DFLASH is enabled for a target that advertises the new set_dflash_layers_to_capture support but does not implement get_input_embeddings (for example DeepseekV3ForCausalLM in this repo exposes model.embed_tokens/get_embed_and_head instead), ModelExecutor calls bind_target_model during startup and this line raises AttributeError before serving. Resolve the embedding module through the language model or model.embed_tokens fallback so direct DeepSeek/Kimi language-model targets can actually use the DFlash path.
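One possible shape for the fallback, as a sketch. `resolve_embed_tokens` is a hypothetical helper, and the lookup order (wrapper's `get_input_embeddings`, then the language model's, then `model.embed_tokens`) is an assumption rather than the repo's actual resolution logic. The `SimpleNamespace` objects only mimic the attribute layouts named in the review.

```python
from types import SimpleNamespace
from typing import Any


def resolve_embed_tokens(target_model: Any) -> Any:
    """Find the input embedding module without assuming get_input_embeddings exists."""
    language_model = getattr(target_model, "language_model", target_model)

    # Prefer the accessor when either the wrapper or the language model has it.
    for owner in (target_model, language_model):
        getter = getattr(owner, "get_input_embeddings", None)
        if callable(getter):
            return getter()

    # Fallback for targets that expose model.embed_tokens directly.
    inner = getattr(language_model, "model", None)
    embed = getattr(inner, "embed_tokens", None)
    if embed is not None:
        return embed

    raise AttributeError(
        f"{type(target_model).__name__} exposes no input embeddings for DFlash"
    )


# Mock targets: one with the accessor, one with only model.embed_tokens,
# and one multimodal-style wrapper around the latter.
accessor_style = SimpleNamespace(get_input_embeddings=lambda: "accessor_embed")
direct_style = SimpleNamespace(model=SimpleNamespace(embed_tokens="direct_embed"))
wrapped_style = SimpleNamespace(language_model=direct_style)
```

Raising a descriptive `AttributeError` when nothing matches keeps the startup failure readable instead of surfacing a bare missing-attribute error.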
Summary
Background
This PR is based on #263. Thanks to @mesaleh for the contribution! Because I was unable to push new commits to his branch, I created this new PR to implement the functionality.
Acceptance rate verification
Accuracy verification
Test Plan