Skip to content

feat(dflash): Add DFlash support#472

Merged
lightseek-bot merged 6 commits into
mainfrom
yweng/dflash_support
Jun 25, 2026
Merged

feat(dflash): Add DFlash support#472
lightseek-bot merged 6 commits into
mainfrom
yweng/dflash_support

Conversation

@yweng0828

@yweng0828 yweng0828 commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

Summary

Background

This PR is created based on #263. Thanks @mesaleh's contribution! However, since I was unable to push new code to his branch, I created a new PR to implement this functionality.

Acceptance rate verification

  • attn_tp4_moe_tp4
  • B200
  • dataset: swe_smith
                                         Detailed Performance Metrics                                         
┏━━━━━━┳━━━━━━┳━━━━━┳━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┓
┃      ┃      ┃     ┃      ┃     Avg ┃     P99 ┃      Avg ┃      P99 ┃      Avg ┃      P99 ┃   Gen. ┃ Success┃
┃Conc. ┃ Rate ┃ Num ┃  RPS ┃ Lat.(s) ┃ Lat.(s) ┃ TTFT(ms) ┃ TTFT(ms) ┃ TPOT(ms) ┃ TPOT(ms) ┃ toks/s ┃    Rate┃
┡━━━━━━╇━━━━━━╇━━━━━╇━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━┩
│    1 │  INF │  50 │ 0.50 │   2.001 │   4.110 │    363.1 │   1716.9 │      3.3 │      4.8 │ 249.79 │  100.0%│
│    2 │  INF │  91 │ 0.78 │   2.535 │   6.970 │    375.2 │   1797.7 │      4.3 │     10.5 │ 389.28 │  100.0%│
│    4 │  INF │  92 │ 0.98 │   3.792 │  10.790 │    522.6 │   6682.7 │      6.5 │     16.6 │ 487.90 │  100.0%│
│    8 │  INF │ 181 │ 1.35 │   5.524 │  17.360 │    687.8 │  11239.7 │      9.7 │     26.6 │ 672.76 │  100.0%│
│   16 │  INF │ 372 │ 1.41 │  10.519 │  47.210 │   2763.4 │  22870.4 │     15.5 │     85.6 │ 706.72 │  100.0%│
└──────┴──────┴─────┴──────┴─────────┴─────────┴──────────┴──────────┴──────────┴──────────┴────────┴────────┘

Accuracy verification

Test Plan

@yweng0828 yweng0828 force-pushed the yweng/dflash_support branch from 5c4426d to fab46ea Compare June 18, 2026 15:26
@yweng0828 yweng0828 marked this pull request as ready for review June 18, 2026 15:26
@yweng0828 yweng0828 requested a review from a team as a code owner June 18, 2026 15:26

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: fab46ea098

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +208 to +214
if self.is_draft:
self.forward_decode_metadata = MHADecodeMetadata(
page_table=page_table,
seq_lens=seq_lens,
# DFLASH drafts a whole block in one forward; the decode
# kernel assumes query length 1, so expand each request
# into spec_num_tokens rows. Give every row the SAME full
# seq_len so each block query attends over the entire block
# (non-causal block-diffusion drafting).
self._fill_spec_metadata_uniform(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Gate draft metadata expansion to DFlash only

When an existing EAGLE3/MTP draft uses an MHA-family backend in eager decode (for example with CUDA graphs disabled or an uncaptured batch size), cuda_graph_wrapper._init_forward_metadata calls this DECODE path for the draft backend, and Eagle._run_multi_step_decode still sends only one query row per request (input_num_tokens=bs). This new if self.is_draft branch now expands the draft metadata to bs * spec_num_tokens rows for all draft models, not just DFlash, so the decode kernel receives page tables/seq_lens sized for many more rows than the q tensor and will either fail or attend with mismatched metadata. The expansion should be conditioned on draft_block_decode; legacy draft decode should keep the unexpanded page_table/seq_lens behavior.

Useful? React with 👍 / 👎.

@lightseek-bot

Copy link
Copy Markdown
Contributor

@codex review

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 894f7c3d13

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

base_ctx.num_extends,
self.spec_num_tokens,
)
return self.draft(current_tokens)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid drafting during mid-chunk prefill

When every row is still in a chunked prefill (input_buffers.all_extends_mid_chunk), the DFlash block returned here is discarded by the scheduler until the prompt finishes, but self.draft(current_tokens) still writes spec_num_tokens draft KV rows starting at the current chunk end. For long prompts where those future positions cross pages the scheduler has not allocated yet, req_to_page still contains zeroes for the next block, so the draft write can overwrite page 0 in the draft cache and poison the prompt context used once prefill completes. The EAGLE path already avoids extra draft work in this exact mid-chunk case; DFlash should similarly skip the block draft or avoid writing draft KV until the request can transition to decode.

Useful? React with 👍 / 👎.

@lightseek-bot lightseek-bot requested a review from minedec June 19, 2026 04:12
mesaleh and others added 3 commits June 22, 2026 19:30
Signed-off-by: Moustafa Saleh <8815169+mesaleh@users.noreply.github.com>
@yweng0828 yweng0828 force-pushed the yweng/dflash_support branch from 894f7c3 to ed5e0fc Compare June 23, 2026 03:27

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ed5e0fc1c1

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +384 to +387
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Reduce the final DFlash MLP output before norm

In dense tensor-parallel runs (--dense-tp-size >1, including the added Kimi DFlash launch), DFlashMLP.down_proj returns an unreduced row-parallel partial because it is constructed with reduce_results=False; intermediate layers get corrected by the next layer's input all-reduce, but after the last layer there is no next layer. This final norm therefore consumes only the local shard's MLP contribution and the draft sampler proposes tokens from incorrect hidden states, so add a final all-reduce (or reduce the last down_proj) before normalizing.

Useful? React with 👍 / 👎.

@yweng0828 yweng0828 force-pushed the yweng/dflash_support branch from c23bb44 to aa31274 Compare June 23, 2026 12:19

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: aa31274cdc

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +469 to +473
cache_locs = self.draft_out_cache_loc_buf[: bs * self.spec_num_tokens]
compute_out_cache_loc_uniform(
out_cache_loc_ptr=cache_locs,
req_pool_indices=req_pool_indices,
uniform_input_length=self.spec_num_tokens,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Reserve enough slots for DFlash block writes

In non-overlap speculative decode, ScheduleBatch.prealloc_for_draft_decode reserves only spec_num_steps - 1 slots, and DFLASH enforces spec_num_steps == spec_num_tokens - 1; this path nevertheless computes cache locations for a full spec_num_tokens block starting at the accepted prefix. When a request accepts enough tokens near a page boundary, the tail of this block can land past the pages the scheduler reserved, so compute_out_cache_loc_uniform reads zero/unallocated req_to_page entries and writes draft KV into the wrong page. DFlash needs a matching full-block reservation (or a shorter write) before these locations are computed.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 733c23bf26

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

f"outputs. Got invalid ids {invalid}; valid range is "
f"[0, {num_layers - 2}] for {num_layers} target layers."
)
self.model.layers_to_capture = {val + 1 for val in layer_ids}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve DFlash target layer order

When a DFlash checkpoint declares dflash_config.target_layer_ids in a non-sorted order (for example [20, 4]), this conversion to a set loses the checkpoint’s requested order. DeepseekV3Model.forward appends captured auxiliary states in layer traversal order and the logits processor concatenates them in that appended order, so the tensor passed into the DFlash fc projection is sorted by layer instead of matching the checkpoint’s feature order, silently permuting the feature blocks and producing bad draft cache/tokens. Preserve the requested order or reject unsorted configs explicitly.

Useful? React with 👍 / 👎.

@mesaleh

mesaleh commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

Thanks for carrying this forward. I saw this PR was created from #263 because the branch was not editable. Since it keeps my original commit/author and adds the requested test coverage, I'll track this PR and can help address remaining review/CI issues if useful.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 4437969881

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +269 to +271
if ctx.forward_mode.is_idle():
hidden_states = self.mlp(hidden_states)
return hidden_states, residual

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Mirror DFlash draft collectives on idle ranks

When DFLASH runs with data parallelism and one DP rank has no requests while another rank is drafting, execute_idle_forward calls the draft model with ForwardMode.IDLE; this branch skips attention and the active path's norm/all-reduce sites, and the only MLP it runs uses reduce_results=False, so the idle rank never enters the NCCL collectives that active DFlash layers enter later in this function. That mismatched collective sequence can hang the worker group; the idle path should issue the same zero-token reductions as the active draft path.

Useful? React with 👍 / 👎.

language_model = getattr(target_model, "language_model", target_model)
self.target_model = target_model
self.target_language_model = language_model
self.embed_tokens = target_model.get_input_embeddings()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Fall back to target model embeddings when binding DFlash

When DFLASH is enabled for a target that advertises the new set_dflash_layers_to_capture support but does not implement get_input_embeddings (for example DeepseekV3ForCausalLM in this repo exposes model.embed_tokens/get_embed_and_head instead), ModelExecutor calls bind_target_model during startup and this line raises AttributeError before serving. Resolve the embedding module through the language model or model.embed_tokens fallback so direct DeepSeek/Kimi language-model targets can actually use the DFlash path.

Useful? React with 👍 / 👎.

@lightseek-bot lightseek-bot merged commit beed887 into main Jun 25, 2026
94 of 112 checks passed
@lightseek-bot lightseek-bot deleted the yweng/dflash_support branch June 25, 2026 09:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants