Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions configs/sglang_qwen3_8b_dspark.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# DSpark training config for Qwen3-8B target model
#
# DSpark = DFlash block-diffusion drafter + EAGLE-style Markov & confidence
# heads, trained with three losses: cross-entropy, L1 distribution
# distillation, and a binary cross-entropy (BCE) loss on the confidence head.
# The L1 and confidence terms need the target's final hidden state, so
# inference.store_last_hidden_states MUST be true (DFlash leaves it false).
#
# GPU allocation (8x GPU):
# - 4 GPUs for inference (SGLang engine, tp_size=1, duplicate mode)
# - 4 GPUs for training (FSDP FULL_SHARD)
#
# Usage:
# python -m torchspec.train_entry --config configs/sglang_qwen3_8b_dspark.yaml
# ./examples/qwen3-8b-single-node/run.sh configs/sglang_qwen3_8b_dspark.yaml

model:
target_model_path: Qwen/Qwen3-8B
trust_remote_code: true
draft_model_config: torchspec/config/dspark_draft_config.json

dataset:
train_data_path: ../examples/data/sample_conversations.jsonl
eval_data_path: null
eval_interval: 100
chat_template: qwen
prompt_key: conversations
min_loss_tokens: 32

training:
attention_backend: flex_attention
micro_batch_size: 1
draft_accumulation_steps: 2
learning_rate: 6e-4
min_lr: 6e-5
weight_decay: 0.0
max_concurrent_batches: 1
max_grad_norm: 1.0
max_seq_length: 2048
num_epochs: 3
seed: 42
training_num_gpus_per_node: 4
training_num_nodes: 1
ttt_length: 7
fsdp_strategy: FULL_SHARD
fsdp_reduce_dtype: bfloat16
prefetch_depth: 8
save_interval: 1000
save_per_epoch: true
max_checkpoints: 2
warmup_ratio: 0.04

# DSpark-specific parameters
dflash_block_size: 7
dspark_num_anchors: 512
dspark_num_target_layers: 5
dspark_loss_decay_gamma: 4.0
dspark_ce_loss_alpha: 0.1
dspark_l1_loss_alpha: 0.9
dspark_confidence_head_alpha: 1.0

inference:
inference_engine_type: sgl
store_last_hidden_states: true
inference_num_gpus: 4
inference_num_gpus_per_engine: 1
inference_num_gpus_per_node: 4
max_sample_pool_size: 64
inference_buffer_threshold: 32
inference_batch_size: 8
sglang:
tp_size: 1
mem_fraction_static: 0.7

mooncake:
master_server_address: null
metadata_server: null
protocol: tcp
global_segment_size: 16GB
local_buffer_size: 4GB
enable_hard_pin: true

output_dir: ./outputs/qwen3-8b-dspark
cache_dir: ./cache/qwen3-8b-dspark
model_download_dir: null

debug:
save_debug_train_data: null
debug_train_only: false
debug_inference_only: false
255 changes: 255 additions & 0 deletions tests/test_dspark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
# Copyright (c) 2026 LightSeek Foundation
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Tests for DSpark (DFlash backbone + Markov / confidence heads + L1 distillation).

Pins the DSpark wiring so future refactors can't silently break the objective:

1. DSparkConfig / DSparkDraftModel: head construction, subclass relationship.
2. forward returns the 6-tuple with detached per-component losses.
3. Loss-wiring invariants (no DeepSpec dependency):
- internal identity: combined loss == ce_a*ce + l1_a*l1 + cf_a*conf (so the
logged loss_components are trustworthy)
- all-masked batch -> loss 0
- gradients reach markov + confidence + backbone; embedding stays frozen
- next-token convention: every within-block slot is supervised (block_size predictions)
4. Markov / confidence head unit math.
5. Algorithm dispatch (DSparkConfig resolves from the JSON and is checked before
DFlashConfig since it subclasses it).
"""

import unittest

import torch

from torchspec.models.draft.auto import AutoDraftModelConfig
from torchspec.models.draft.dflash import DFlashConfig
from torchspec.models.draft.dspark import (
AcceptRatePredictor,
DSparkConfig,
DSparkDraftModel,
VanillaMarkov,
)
from torchspec.models.dspark import DSparkModel

# Loss weights (cross-entropy, L1, confidence) shared by the model factory and
# the loss-identity test so both use the same alphas.
CE_A, L1_A, CF_A = 0.1, 0.9, 1.0


def _make_dspark_config(
    H=64,
    V=128,
    num_target_layers=2,
    markov_rank=16,
    enable_confidence_head=True,
    confidence_head_with_markov=True,
):
    """Return a miniature ``DSparkConfig`` for unit tests.

    ``H`` is used as both the draft and the target hidden size, ``V`` is the
    vocabulary size, and the last vocabulary id (``V - 1``) is the mask token.
    The remaining arguments are forwarded unchanged to the config.
    """
    # Draft backbone dimensions.
    backbone_kwargs = {
        "hidden_size": H,
        "intermediate_size": 256,
        "num_hidden_layers": 1,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "vocab_size": V,
        "rms_norm_eps": 1e-6,
        "max_position_embeddings": 512,
        "rope_theta": 10000.0,
    }
    # Target-model description and the mask token id.
    target_kwargs = {
        "num_target_layers": num_target_layers,
        "target_hidden_size": H,
        "target_num_hidden_layers": 12,
        "mask_token_id": V - 1,
    }
    # Markov / confidence head switches.
    head_kwargs = {
        "markov_rank": markov_rank,
        "markov_head_type": "vanilla",
        "enable_confidence_head": enable_confidence_head,
        "confidence_head_with_markov": confidence_head_with_markov,
    }
    return DSparkConfig(**backbone_kwargs, **target_kwargs, **head_kwargs)


def _make_dspark_model(block_size=4, num_anchors=6, **cfg_kw):
    """Build a float32 ``DSparkModel`` around a tiny draft model.

    ``cfg_kw`` is passed straight to ``_make_dspark_config``. The draft model's
    embedding is frozen before it is wrapped, and the loss weights come from
    the module-level ``CE_A`` / ``L1_A`` / ``CF_A`` constants.
    """
    draft = DSparkDraftModel(_make_dspark_config(**cfg_kw))
    draft = draft.to(dtype=torch.float32)
    draft.freeze_embedding()

    wrapper_kwargs = {
        "draft_model": draft,
        "block_size": block_size,
        "num_anchors": num_anchors,
        "loss_decay_gamma": 4.0,
        "ce_loss_alpha": CE_A,
        "l1_loss_alpha": L1_A,
        "confidence_head_alpha": CF_A,
    }
    return DSparkModel(**wrapper_kwargs)


def _batch(B=2, S=24, H=64, V=128, num_target_layers=2, all_masked=False, seed=0):
g = torch.Generator().manual_seed(seed)
input_ids = torch.randint(0, V, (B, S), generator=g)
hidden_states_list = [torch.randn(B, S, H, generator=g) for _ in range(num_target_layers)]
loss_mask = torch.zeros(B, S) if all_masked else torch.ones(B, S)
if not all_masked:
loss_mask[:, :2] = 0 # prompt
lm_head_weight = torch.randn(V, H, generator=g)
last_hidden_states = torch.randn(B, S, H, generator=g)
return dict(
input_ids=input_ids,
hidden_states_list=hidden_states_list,
loss_mask=loss_mask,
lm_head_weight=lm_head_weight,
last_hidden_states=last_hidden_states,
)


class TestDSparkConfig(unittest.TestCase):
    """Construction checks for DSparkConfig and the heads DSparkDraftModel builds."""

    def test_subclasses_dflash_and_attrs(self):
        config = _make_dspark_config(markov_rank=32)
        # A DSparkConfig also passes isinstance(DFlashConfig), so any dispatch
        # on config type has to test for DSpark before DFlash.
        self.assertIsInstance(config, DFlashConfig)
        self.assertEqual(config.model_type, "dspark")
        self.assertEqual(config.markov_rank, 32)
        self.assertTrue(config.enable_confidence_head)

    def test_draft_model_heads(self):
        hidden, rank = 64, 16
        model = DSparkDraftModel(_make_dspark_config(H=hidden, markov_rank=rank))
        self.assertIsInstance(model.markov_head, VanillaMarkov)
        self.assertIsInstance(model.confidence_head, AcceptRatePredictor)
        # When the confidence head is fused with the Markov head, its input
        # width is hidden size plus markov_rank.
        self.assertEqual(model.confidence_head.proj.in_features, hidden + rank)

    def test_no_heads(self):
        headless = _make_dspark_config(
            markov_rank=0,
            enable_confidence_head=False,
            confidence_head_with_markov=False,
        )
        model = DSparkDraftModel(headless)
        self.assertIsNone(model.markov_head)
        self.assertIsNone(model.confidence_head)


class TestDSparkForward(unittest.TestCase):
    """Forward-pass and loss-wiring checks for DSparkModel on tiny random batches."""

    def test_returns_six_tuple_with_detached_components(self):
        model = _make_dspark_model()
        outputs = model(**_batch())
        self.assertEqual(len(outputs), 6)
        loss, _acc, loss_pp, _app, _cpp, components = outputs
        expected_keys = {"ce_loss", "l1_loss", "confidence_loss"}
        self.assertEqual(set(components), expected_keys)
        for component in components.values():
            self.assertTrue(torch.isfinite(component).all())
            # Components are reported for logging only, so they must be detached.
            self.assertFalse(component.requires_grad)
        self.assertTrue(torch.isfinite(loss))
        self.assertEqual(loss_pp.shape[0], model.block_size)

    def test_internal_loss_identity(self):
        # With world_size == 1 (no process group) the combined loss has to be
        # the alpha-weighted sum of the logged components; otherwise the
        # components would not faithfully decompose what is being optimized.
        model = _make_dspark_model()
        loss, _, _, _, _, comps = model(**_batch(seed=1))
        ce_term = CE_A * comps["ce_loss"]
        l1_term = L1_A * comps["l1_loss"]
        conf_term = CF_A * comps["confidence_loss"]
        recomputed = ce_term + l1_term + conf_term
        failure_msg = f"{loss.item()} vs {recomputed.item()}"
        self.assertTrue(torch.allclose(loss, recomputed, atol=1e-4), failure_msg)

    def test_all_masked_is_zero(self):
        model = _make_dspark_model()
        loss, _, _, _, _, comps = model(**_batch(all_masked=True))
        self.assertAlmostEqual(loss.item(), 0.0, places=5)
        for component in comps.values():
            self.assertAlmostEqual(component.item(), 0.0, places=5)

    def test_next_token_convention_all_slots_supervised(self):
        # DSpark's next-token convention: each within-block slot predicts a
        # real token (block_size predictions), unlike DFlash where slot 0 is
        # the masked anchor. On a long, fully supervised sequence every slot
        # should therefore collect supervised tokens.
        block_size = 4
        batch_size, seq_len = 2, 40
        model = _make_dspark_model(block_size=block_size, num_anchors=8)
        batch = _batch(B=batch_size, S=seq_len)
        batch["loss_mask"] = torch.ones(batch_size, seq_len)
        _, _, _, _, count_per_position, _ = model(**batch)
        self.assertEqual(count_per_position.shape[0], block_size)
        failure_msg = f"some slot unsupervised: {count_per_position.tolist()}"
        self.assertTrue((count_per_position > 0).all(), failure_msg)

    def test_grad_flow_and_frozen_embedding(self):
        model = _make_dspark_model()
        loss = model(**_batch(seed=2))[0]
        loss.backward()
        draft = model.draft_model

        markov_grad = draft.markov_head.markov_w2.weight.grad
        self.assertIsNotNone(markov_grad)
        self.assertGreater(markov_grad.abs().sum().item(), 0)

        confidence_grad = draft.confidence_head.proj.weight.grad
        self.assertIsNotNone(confidence_grad)
        self.assertGreater(confidence_grad.abs().sum().item(), 0)

        self.assertIsNotNone(draft.context_proj.weight.grad)
        # The embedding was frozen by the factory, so it must receive no gradient.
        self.assertIsNone(draft.embed_tokens.weight.grad)

    def test_ce_only_without_target(self):
        # A cross-entropy-only setup (l1 weight 0, no confidence head) has to
        # run even when last_hidden_states is not supplied.
        model = _make_dspark_model(
            markov_rank=16,
            enable_confidence_head=False,
            confidence_head_with_markov=False,
        )
        overrides = (
            ("l1_loss_alpha", 0.0),
            ("ce_loss_alpha", 1.0),
            ("confidence_head_alpha", 0.0),
        )
        for attr_name, weight in overrides:
            setattr(model, attr_name, weight)
        batch = {**_batch(), "last_hidden_states": None}
        loss = model(**batch)[0]
        self.assertTrue(torch.isfinite(loss))


class TestHeadMath(unittest.TestCase):
    """Closed-form checks for the Markov and confidence heads in isolation."""

    def test_vanilla_markov_is_bigram_bias(self):
        torch.manual_seed(0)
        vocab, rank = 50, 8
        lead_shape = (2, 3, 4)
        head = VanillaMarkov(vocab_size=vocab, markov_rank=rank)
        base_logits = torch.randn(*lead_shape, vocab)
        prev_tokens = torch.randint(0, vocab, lead_shape)
        actual = head.apply_block_logits(base_logits, token_ids=prev_tokens)
        # Expected: base logits plus markov_w2(markov_w1(previous token)).
        low_rank = head.markov_w1(prev_tokens)
        bias = head.markov_w2(low_rank)
        expected = base_logits + bias
        self.assertTrue(torch.allclose(actual, expected, atol=1e-6))

    def test_confidence_head_is_linear(self):
        torch.manual_seed(0)
        feature_dim = 20
        lead_shape = (2, 3, 4)
        head = AcceptRatePredictor(feature_dim)
        features = torch.randn(*lead_shape, feature_dim)
        actual = head(features)
        # Expected: the head's own projection with the trailing dim squeezed.
        projected = head.proj(features)
        expected = projected.squeeze(-1)
        self.assertTrue(torch.allclose(actual, expected, atol=1e-6))
        self.assertEqual(actual.shape, lead_shape)


class TestDispatch(unittest.TestCase):
    """AutoDraftModelConfig must map a DSpark config dict onto DSparkConfig."""

    def test_json_resolves_to_dspark_config(self):
        raw_config = dict(
            architectures=["DSparkDraftModel"],
            model_type="dspark",
            hidden_size=64,
            vocab_size=128,
            num_hidden_layers=1,
            num_target_layers=2,
            markov_rank=16,
            enable_confidence_head=True,
        )
        resolved = AutoDraftModelConfig.from_dict(raw_config)
        self.assertIsInstance(resolved, DSparkConfig)
        # DSparkConfig subclasses DFlashConfig, so every isinstance(DFlashConfig)
        # dispatch has to check DSparkConfig first (trainer_actor / train_entry
        # rely on this).
        self.assertIsInstance(resolved, DFlashConfig)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
5 changes: 5 additions & 0 deletions torchspec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,17 @@
from torchspec.models.dflash import DFlashModel
from torchspec.models.draft import AutoDraftModelConfig, AutoEagle3DraftModel
from torchspec.models.draft.dflash import DFlashConfig, DFlashDraftModel
from torchspec.models.draft.dspark import DSparkConfig, DSparkDraftModel
from torchspec.models.dspark import DSparkModel

# Names exported by ``from torchspec import *``.
__all__ = [
    "Eagle3Model",
    "DFlashModel",
    "DFlashConfig",
    "DFlashDraftModel",
    "DSparkModel",
    "DSparkConfig",
    "DSparkDraftModel",
    "AutoDraftModelConfig",
    "AutoEagle3DraftModel",
]
23 changes: 23 additions & 0 deletions torchspec/config/dspark_draft_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"architectures": ["DSparkDraftModel"],
"model_type": "dspark",
"hidden_size": 4096,
"intermediate_size": 12288,
"num_hidden_layers": 5,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"vocab_size": 151936,
"rms_norm_eps": 1e-6,
"max_position_embeddings": 40960,
"rope_theta": 1000000.0,
"num_target_layers": 5,
"target_hidden_size": 4096,
"target_num_hidden_layers": 36,
"target_layer_ids": [1, 9, 17, 25, 33],
"mask_token_id": 151669,
"markov_rank": 256,
"markov_head_type": "vanilla",
"enable_confidence_head": true,
"confidence_head_with_markov": true,
"tie_word_embeddings": false
}
Loading
Loading