
feat: DeepSpec - DSpark trainer support #129

Open
Dogacel wants to merge 3 commits into lightseekorg:main from Dogacel:dspark

@Dogacel Dogacel commented Jun 27, 2026

Collaborator

Overview

Support DeepSpec - DSpark training as released in: https://github.com/deepseek-ai/DeepSpec

DSparkTrainer is designed as a subclass of DFlashTrainer to reduce code duplication. The two trainers also share a common parameter, dflash_block_size.
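
A minimal sketch of that inheritance arrangement, under stated assumptions: only the class names DFlashTrainer and DSparkTrainer and the parameter name dflash_block_size come from this PR. The constructor signature and the `model_name()` and `describe()` methods are hypothetical stand-ins, not TorchSpec's actual API.

```python
class DFlashTrainer:
    """Hypothetical stand-in for the base trainer (not TorchSpec's real class)."""

    def __init__(self, dflash_block_size: int) -> None:
        if dflash_block_size < 1:
            raise ValueError("dflash_block_size must be >= 1")
        # Shared setting: both trainers read the same block-size parameter.
        self.dflash_block_size = dflash_block_size

    def model_name(self) -> str:
        # Hypothetical hook that a subclass overrides to swap in its own model.
        return "DFlashModel"

    def describe(self) -> str:
        # Common logic is written once, in the base class.
        return f"{self.model_name()}(block_size={self.dflash_block_size})"


class DSparkTrainer(DFlashTrainer):
    """Hypothetical subclass: reuses everything except the model hook."""

    def model_name(self) -> str:
        return "DSparkModel"


print(DSparkTrainer(dflash_block_size=7).describe())
```

With this shape, only the parts that differ between the two trainers need to be written in the subclass, which is the duplication saving the overview refers to.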

The file torchspec/models/dspark.py contains the forward pass and loss function definitions and is mostly vibe-coded. The other files were created with AI assistance; however, I have much higher confidence in their correctness.

Testing

A full training run has not been done; however, a 500-step validation run was completed with the Qwen3-8B model using the command ./examples/qwen3-8b-single-node/run.sh configs/sglang_qwen3_8b_dspark.yaml. Also, no inference engine supports DSpark at the moment, so the trainer code might require updates once vLLM and SGLang add DSpark support.

The training run is OK but doesn't look great; I think more validation of the loss / forward pass is needed.

[3 attached images]

Dogacel added 2 commits June 27, 2026 21:07
Signed-off-by: Doğaç Eldenk <dogacel@gmail.com>
Signed-off-by: Doğaç Eldenk <dogacel@gmail.com>

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e1ae37cb7c

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

# SOFTWARE.

from torchspec.models.dflash import DFlashModel
from torchspec.models.dspark import DSparkModel


P1: Add the missing DSparkModel module

This new import makes torchspec.models (and therefore top-level torchspec) depend on torchspec.models.dspark, but the reviewed commit never adds that module; git ls-tree -r 6e22b865f35fa6738545abe20811bc70e0b6b364 | rg 'models/dspark|dspark.py' only finds torchspec/models/draft/dspark.py. In any environment with the normal dependencies installed, importing the package will raise ModuleNotFoundError: No module named 'torchspec.models.dspark', blocking existing Eagle3/DFlash users as well as the new DSpark trainer.
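
One way to catch this class of failure before merging is to check that every module a package's `__init__` imports can actually be located. The sketch below uses the standard library's `importlib.util.find_spec`; the helper name `missing_modules` is hypothetical, and the demonstration uses stdlib module names rather than torchspec so that it runs anywhere. Note that resolving a dotted name imports its parent package as a side effect.

```python
import importlib.util


def missing_modules(names):
    """Return the dotted module names that cannot be located."""
    missing = []
    for name in names:
        try:
            spec = importlib.util.find_spec(name)
        except ModuleNotFoundError:
            # Raised when a parent package of a dotted name does not exist.
            spec = None
        if spec is None:
            missing.append(name)
    return missing


# Demonstration with stdlib names; for this PR the list would instead hold
# "torchspec.models.dflash" and "torchspec.models.dspark".
print(missing_modules(["json", "json.decoder", "json.no_such_submodule"]))
```

A check like this could run in CI as a cheap smoke test, since it reports a missing submodule without executing the package's own top-level imports.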


Collaborator Author

Fixed.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 6a62189b5d



last_hidden_states = batch.get("last_hidden_states", None)
if last_hidden_states is not None:
    last_hidden_states = last_hidden_states.to(device, non_blocking=True)


P2: Normalize vLLM last hidden states before LM head

When inference_engine_type=vllm (which this DSpark path allows), the engine captures the final hidden state before the model's final norm and config_to_flat_args marks last_hidden_states_prenorm=True; Eagle3 handles this by applying the verifier norm before using the LM head. Here the raw tensor is passed into DSparkModel, where the L1 and confidence targets are built with F.linear(last_hidden_states, lm_head_weight), so DSpark training with vLLM uses a misnormalized target distribution. Load and apply the target norm when last_hidden_states_prenorm is set.
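
The effect described here can be illustrated with a small, dependency-free sketch. It assumes the target model's final norm is an RMSNorm with a learned per-channel gain; the helper names and the toy numbers are illustrative only, and plain Python lists stand in for tensors. It is not the trainer's code.

```python
import math


def rms_norm(hidden, gain, eps=1e-6):
    """RMSNorm: scale by the reciprocal root-mean-square, then by a learned gain."""
    mean_sq = sum(v * v for v in hidden) / len(hidden)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * g for v, g in zip(hidden, gain)]


def lm_head(hidden, weight):
    """Project one hidden vector to logits; weight has one row per vocab entry."""
    return [sum(h * w for h, w in zip(hidden, row)) for row in weight]


def target_logits(hidden, weight, gain, prenorm):
    """Apply the final norm first only when the hidden state was captured pre-norm."""
    if prenorm:
        hidden = rms_norm(hidden, gain)
    return lm_head(hidden, weight)


# Toy values (assumed for illustration, not taken from any real model).
hidden = [4.0, 1.0]                # a hidden state captured before the final norm
gain = [0.1, 2.0]                  # learned RMSNorm gain
weight = [[1.0, 0.0], [0.0, 1.0]]  # identity LM head over a 2-token vocab

raw = target_logits(hidden, weight, gain, prenorm=False)
fixed = target_logits(hidden, weight, gain, prenorm=True)
print(raw.index(max(raw)), fixed.index(max(fixed)))
```

In this toy case, skipping the norm on a pre-norm hidden state changes which token has the highest logit, which is the kind of target mismatch the comment warns about for the L1 and confidence targets.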


@Dogacel

Dogacel commented Jun 27, 2026

Collaborator Author

To validate, I had Claude compare the forward/backward outputs of our model against DeepSpec's implementation.

I'm still not 100% sure, but I think this is a strong signal that the implementation is adapted correctly from DeepSeek's repository.

"""Exhaustive white-box validation of models/dspark.py:DSparkModel.forward.

For each config: run the REAL forward, re-derive its (deterministic) backbone
output, then INDEPENDENTLY rebuild the loss/metrics from that output using
DeepSpec's compute_dspark_loss / VanillaMarkov / AcceptRatePredictor / build_eval_mask
and the spec'd indexing. Any mis-wiring (label offset, which logits feed
markov/L1, aligned-target index, confidence features, alpha/denominator
aggregation, metric masks) makes forward and reference diverge.

Sweeps: block_size (TTT analog), batch/seq, markov on/off, confidence on/off &
with/without markov fusion, loss-alpha mixes, decay on/off, loss-mask layouts.
"""
import os
import sys

sys.path.insert(0, "/tmp/claude-1012/-storage-dogac-TorchSpec/b24e5b98-f597-4388-b2da-cb3ee04cfc18/scratchpad/DeepSpec")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29572")

import torch
import torch.nn.functional as F
import torch.distributed as dist
if not dist.is_initialized():
    dist.init_process_group("gloo", rank=0, world_size=1)

from torchspec.models.draft.dspark import DSparkConfig, DSparkDraftModel
from torchspec.models.dspark import DSparkModel
from deepspec.modeling.dspark.markov_head import VanillaMarkov as DS_Markov
from deepspec.modeling.dspark.common import AcceptRatePredictor as DS_Conf, DSparkForwardOutput, build_eval_mask
from deepspec.modeling.dspark.loss import compute_dspark_loss


def make_mask(B, S, mode):
    m = torch.ones(B, S)
    m[:, : max(1, S // 6)] = 0  # prompt
    if mode == "gaps":
        m[0, S // 2 : S // 2 + 2] = 0      # interior gap -> tests cumprod truncation
        if B > 1:
            m[1, S - 4 :] = 0
    elif mode == "all":
        m[:] = 0
    return m


def run_config(c, seed):
    torch.manual_seed(seed)
    H, V, R, NTL, Bk = c["H"], c["V"], c["R"], c["NTL"], c["block_size"]
    cfg = DSparkConfig(hidden_size=H, intermediate_size=2 * H, num_hidden_layers=2,
                       num_attention_heads=4, num_key_value_heads=2, vocab_size=V,
                       num_target_layers=NTL, target_hidden_size=H, target_num_hidden_layers=8,
                       target_layer_ids=[1, 5][:NTL] if NTL <= 2 else None, mask_token_id=3,
                       markov_rank=R, enable_confidence_head=c["enable_conf"],
                       confidence_head_with_markov=c["with_markov"])
    draft = DSparkDraftModel(cfg); draft.freeze_embedding(); draft.eval()
    model = DSparkModel(draft, block_size=Bk, num_anchors=c["num_anchors"],
                        loss_decay_gamma=c["gamma"], ce_loss_alpha=c["ce"],
                        l1_loss_alpha=c["l1"], confidence_head_alpha=c["conf"])

    ds_markov = None
    if draft.markov_head is not None:
        ds_markov = DS_Markov(vocab_size=V, markov_rank=R); ds_markov.load_state_dict(draft.markov_head.state_dict())
    ds_conf = None
    if draft.confidence_head is not None:
        ci = H + (R if c["with_markov"] else 0)
        ds_conf = DS_Conf(input_dim=ci); ds_conf.load_state_dict(draft.confidence_head.state_dict())

    B, S = c["B"], c["S"]
    input_ids = torch.randint(0, V, (B, S))
    hsl = [torch.randn(B, S, H) for _ in range(NTL)]
    loss_mask = make_mask(B, S, c["mask"])
    lm = torch.randn(V, H)
    lhs = torch.randn(B, S, H)

    rng = torch.get_rng_state()
    out_fwd = model(input_ids=input_ids, hidden_states_list=hsl, loss_mask=loss_mask,
                    lm_head_weight=lm, last_hidden_states=lhs)
    loss_fwd, acc_fwd, lpp_fwd, app_fwd, cpp_fwd, comps_fwd = out_fwd

    # re-derive the SAME backbone output deterministically
    torch.set_rng_state(rng)
    ctx_feat = model.draft_model.extract_context_feature(hsl)
    anchors, keep = model._sample_anchor_positions(S, loss_mask, input_ids.device)
    noise = model._create_noise_embed(input_ids, anchors, keep)
    cpos, dpos = model._create_position_ids(anchors, S)
    dh = model.draft_model(draft_input_ids=None, context_feature=ctx_feat, draft_position_ids=dpos,
                           context_position_ids=cpos, block_mask=None, noise_embedding=noise)
    nb = anchors.shape[1]; dh4 = dh.view(B, nb, Bk, H)

    # ---- independent reference (DeepSpec primitives + spec indexing) ----
    base = F.linear(dh, lm).view(B, nb, Bk, V)
    label_idx = anchors.unsqueeze(-1) + torch.arange(1, Bk + 1).view(1, 1, -1)
    safe = label_idx.clamp(max=S - 1)
    safe = torch.where(keep.unsqueeze(-1), safe, torch.zeros_like(safe))
    target_ids = torch.gather(input_ids.unsqueeze(1).expand(-1, nb, -1), 2, safe)
    eval_mask = build_eval_mask(seq_len=S, loss_mask=loss_mask, label_indices=label_idx,
                                safe_label_indices=safe, block_keep_mask=keep)
    anchor_tok = torch.gather(input_ids, 1, anchors)
    prev = torch.cat([anchor_tok.unsqueeze(-1), target_ids[:, :, :-1]], dim=-1)
    draft_logits = ds_markov.apply_block_logits(base, token_ids=prev, hidden_states=dh4) if ds_markov is not None else base
    tgt_idx = (safe - 1).clamp(min=0)
    aligned_h = torch.gather(lhs.unsqueeze(1).expand(-1, nb, -1, -1), 2,
                             tgt_idx.unsqueeze(-1).expand(-1, -1, -1, H))
    aligned_logits = F.linear(aligned_h, lm)
    conf_pred = None
    if ds_conf is not None and c["conf"] > 0:
        feat = torch.cat([dh4, ds_markov.get_prev_embeddings(prev)], -1) if c["with_markov"] else dh4
        conf_pred = ds_conf(feat).float()
    out = DSparkForwardOutput(draft_logits=draft_logits, target_ids=target_ids, eval_mask=eval_mask,
                              block_keep_mask=keep, confidence_pred=conf_pred, aligned_target_logits=aligned_logits)
    ref_loss = compute_dspark_loss(outputs=out, loss_decay_gamma=c["gamma"], ce_loss_alpha=c["ce"],
                                   l1_loss_alpha=c["l1"], confidence_head_alpha=c["conf"])

    # ---- independent metrics (spec recompute) ----
    em = eval_mask.float()
    pred = draft_logits.argmax(-1)
    correct = (pred == target_ids) & (em > 0.5)
    cpp_ref = em.sum(dim=(0, 1))
    cpp_c = cpp_ref.clamp(min=1.0)
    ce_pt = F.cross_entropy(draft_logits.reshape(-1, V), target_ids.reshape(-1), reduction="none").view(B, nb, Bk)
    lpp_ref = (ce_pt * em).sum(dim=(0, 1)) / cpp_c
    app_ref = correct.float().sum(dim=(0, 1)) / cpp_c

    # ---- comparisons ----
    res = {}
    res["loss"] = torch.allclose(loss_fwd, ref_loss, atol=1e-4, rtol=1e-4)
    res["lpp"] = torch.allclose(lpp_fwd, lpp_ref, atol=1e-4, rtol=1e-4)
    res["app"] = torch.allclose(app_fwd, app_ref, atol=1e-4, rtol=1e-4)
    res["cpp"] = torch.equal(cpp_fwd, cpp_ref)
    # internal identity: combined loss == sum alpha*component (single process)
    recombined = c["ce"] * comps_fwd["ce_loss"] + c["l1"] * comps_fwd["l1_loss"] + c["conf"] * comps_fwd["confidence_loss"]
    res["components_identity"] = torch.allclose(loss_fwd.detach(), recombined, atol=1e-4, rtol=1e-4)
    return res, (loss_fwd.item(), ref_loss.item())


# ---- config grid ----
HEADS = {
    "full(mk+cf+wmk)": dict(R=16, enable_conf=True, with_markov=True),
    "dflash-like(no mk,no cf)": dict(R=0, enable_conf=False, with_markov=False),
    "markov-only(no cf)": dict(R=16, enable_conf=False, with_markov=False),
    "cf-no-markov(R=0)": dict(R=0, enable_conf=True, with_markov=False),
    "cf-no-fusion(R=16)": dict(R=16, enable_conf=True, with_markov=False),
}
ALPHAS = {
    "ce.1/l1.9/cf1": dict(ce=0.1, l1=0.9, conf=1.0),
    "ce1/l10/cf0": dict(ce=1.0, l1=0.0, conf=0.0),
    "ce0/l11/cf0": dict(ce=0.0, l1=1.0, conf=0.0),
}
SHAPES = [dict(B=1, S=16), dict(B=2, S=28), dict(B=3, S=40), dict(B=2, S=12)]
BLOCKS = [1, 2, 4, 7, 8, 16]

configs = []
# core sweep: every head x every alpha (block 7, shape (2,28), normal mask)
for hn, h in HEADS.items():
    for an, a in ALPHAS.items():
        if not h["enable_conf"] and a["conf"] > 0:  # conf alpha needs a conf head
            continue
        if h["enable_conf"] and a["conf"] == 0:      # skip building a head we won't use
            continue
        configs.append(dict(name=f"{hn} | {an}", H=64, V=96, NTL=2, num_anchors=10,
                            gamma=4.0, mask="normal", block_size=7, **SHAPES[1], **h, **a))
# block_size sweep (full heads)
for bk in BLOCKS:
    configs.append(dict(name=f"block_size={bk} (full)", H=64, V=96, NTL=2, num_anchors=8, gamma=4.0,
                        mask="normal", block_size=bk, **SHAPES[1], **HEADS["full(mk+cf+wmk)"], **ALPHAS["ce.1/l1.9/cf1"]))
# shape sweep (full heads, incl short seq -> dummy anchors)
for sh in SHAPES:
    configs.append(dict(name=f"shape B={sh['B']} S={sh['S']} (full)", H=64, V=96, NTL=2, num_anchors=8,
                        gamma=4.0, mask="normal", block_size=4, **sh, **HEADS["full(mk+cf+wmk)"], **ALPHAS["ce.1/l1.9/cf1"]))
# mask layout sweep (full heads)
for mk in ["normal", "gaps", "all"]:
    configs.append(dict(name=f"mask={mk} (full)", H=64, V=96, NTL=2, num_anchors=10, gamma=4.0,
                        mask=mk, block_size=7, **SHAPES[1], **HEADS["full(mk+cf+wmk)"], **ALPHAS["ce.1/l1.9/cf1"]))
# decay off + bigger vocab + NTL=3
configs.append(dict(name="gamma=0 (no decay)", H=64, V=96, NTL=2, num_anchors=10, gamma=0.0,
                    mask="normal", block_size=7, **SHAPES[1], **HEADS["full(mk+cf+wmk)"], **ALPHAS["ce.1/l1.9/cf1"]))
configs.append(dict(name="bigV=400,NTL=3,H=48", H=48, V=400, NTL=3, num_anchors=8, gamma=4.0,
                    mask="normal", block_size=7, **SHAPES[1], **HEADS["full(mk+cf+wmk)"], **ALPHAS["ce.1/l1.9/cf1"]))

# ---- run ----
total = 0; passed = 0; fails = []
for c in configs:
    for seed in (0, 1):
        total += 1
        try:
            res, (lf, lr) = run_config(c, seed)
            ok = all(res.values())
        except Exception as e:
            ok = False; res = {"EXC": str(e)[:80]}
        passed += ok
        if not ok:
            fails.append((c["name"], seed, res))
        flag = "PASS" if ok else "FAIL"
        bad = "" if ok else "  <-- " + ",".join(k for k, v in res.items() if v is not True)
        print(f"[{flag}] seed{seed}  {c['name']}{bad}")

print(f"\n==== {passed}/{total} config×seed checks passed ====")
if fails:
    print("FAILURES:")
    for n, s, r in fails:
        print(f"  {n} (seed {s}): {r}")
dist.destroy_process_group()
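
The block indexing that the reference rebuild depends on can be traced by hand on a tiny example. The sketch below restates, with plain Python lists for a single sequence and a single anchor, the three index computations in `run_config`: labels at `anchor + 1 .. anchor + block_size` clamped to the last position, previous tokens formed by placing the anchor token in front of the targets shifted by one, and aligned target positions at `label - 1`. The function name `block_indices` is hypothetical, the `keep` mask is omitted for brevity, and the `in_range` flag is only an assumption about how out-of-range labels get excluded; the script delegates that to `build_eval_mask`.

```python
def block_indices(input_ids, anchor, block_size):
    """Mirror the script's per-block indexing for one anchor in one sequence."""
    seq_len = len(input_ids)
    # Labels are the block_size positions that follow the anchor.
    label_idx = [anchor + k for k in range(1, block_size + 1)]
    # Clamp so that out-of-range labels still index a valid position.
    safe = [min(i, seq_len - 1) for i in label_idx]
    target_ids = [input_ids[i] for i in safe]
    # Token preceding each draft position: the anchor token, then targets shifted by one.
    prev = [input_ids[anchor]] + target_ids[:-1]
    # Position of the target hidden state whose logits are compared against each label.
    aligned_idx = [max(i - 1, 0) for i in safe]
    # Assumed validity flag: labels past the end of the sequence should not count.
    in_range = [i < seq_len for i in label_idx]
    return target_ids, prev, aligned_idx, in_range


ids = [10, 11, 12, 13, 14, 15]
print(block_indices(ids, anchor=3, block_size=4))
```

For this input the labels are positions 4 to 7, of which only 4 and 5 exist, so the last two targets are clamped copies of position 5. An off-by-one in any of these three index lists is the kind of mis-wiring the validation script is designed to expose.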

Signed-off-by: Doğaç Eldenk <dogacel@gmail.com>
@Dogacel Dogacel requested a review from yubofredwang June 27, 2026 22:07
@Dogacel

Dogacel commented Jun 27, 2026

Collaborator Author
[image] 50K samples on perfectblend for 3 epochs

@Dogacel Dogacel requested a review from lightseek-bot June 28, 2026 02:53