From e1ae37cb7c0f935d9e915de8e18f80f8ec048818 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fa=C3=A7=20Eldenk?= Date: Sat, 27 Jun 2026 21:07:59 +0000 Subject: [PATCH 1/4] feat: DeepSpec - DSpark trainer support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Doğaç Eldenk --- configs/sglang_qwen3_8b_dspark.yaml | 89 +++++++++++++ torchspec/__init__.py | 5 + torchspec/config/dspark_draft_config.json | 23 ++++ torchspec/config/train_config.py | 8 ++ torchspec/models/__init__.py | 2 + torchspec/models/draft/__init__.py | 3 + torchspec/models/draft/auto.py | 3 + torchspec/models/draft/dspark.py | 154 ++++++++++++++++++++++ torchspec/train_entry.py | 9 +- torchspec/training/dflash_trainer.py | 59 ++++++--- torchspec/training/dspark_trainer.py | 104 +++++++++++++++ torchspec/training/trainer_actor.py | 13 +- 12 files changed, 448 insertions(+), 24 deletions(-) create mode 100644 configs/sglang_qwen3_8b_dspark.yaml create mode 100644 torchspec/config/dspark_draft_config.json create mode 100644 torchspec/models/draft/dspark.py create mode 100644 torchspec/training/dspark_trainer.py diff --git a/configs/sglang_qwen3_8b_dspark.yaml b/configs/sglang_qwen3_8b_dspark.yaml new file mode 100644 index 0000000..136bc27 --- /dev/null +++ b/configs/sglang_qwen3_8b_dspark.yaml @@ -0,0 +1,89 @@ +# DSpark training config for Qwen3-8B target model +# +# DSpark = DFlash block-diffusion drafter + EAGLE-style Markov & confidence +# heads, trained with cross-entropy + L1 distribution distillation + confidence +# BCE. The L1 / confidence terms need the target's final hidden state, so +# inference.store_last_hidden_states MUST be true (DFlash leaves it false). +# +# GPU allocation (8x GPU): +# - 4 GPUs for inference (SGLang engine, tp_size=1, duplicate mode) +# - 4 GPUs for training (FSDP FULL_SHARD) +# +# Usage: +# python -m torchspec.train_entry --config configs/sglang_qwen3_8b_dspark.yaml +# ./examples/qwen3-8b-single-node/run.sh configs/sglang_qwen3_8b_dspark.yaml + +model: + target_model_path: Qwen/Qwen3-8B + trust_remote_code: true + draft_model_config: torchspec/config/dspark_draft_config.json + +dataset: + train_data_path: ../examples/data/sample_conversations.jsonl + eval_data_path: null + eval_interval: 100 + chat_template: qwen + prompt_key: conversations + min_loss_tokens: 32 + +training: + attention_backend: flex_attention + micro_batch_size: 1 + draft_accumulation_steps: 2 + learning_rate: 6e-4 + min_lr: 6e-5 + weight_decay: 0.0 + max_concurrent_batches: 1 + max_grad_norm: 1.0 + max_seq_length: 2048 + num_epochs: 3 + seed: 42 + training_num_gpus_per_node: 4 + training_num_nodes: 1 + ttt_length: 7 + fsdp_strategy: FULL_SHARD + fsdp_reduce_dtype: bfloat16 + prefetch_depth: 8 + save_interval: 1000 + save_per_epoch: true + max_checkpoints: 2 + warmup_ratio: 0.04 + + # DSpark-specific parameters + dflash_block_size: 7 + dspark_num_anchors: 512 + dspark_num_target_layers: 5 + dspark_loss_decay_gamma: 4.0 + dspark_ce_loss_alpha: 0.1 + dspark_l1_loss_alpha: 0.9 + dspark_confidence_head_alpha: 1.0 + +inference: + inference_engine_type: sgl + store_last_hidden_states: true + inference_num_gpus: 4 + inference_num_gpus_per_engine: 1 + inference_num_gpus_per_node: 4 + max_sample_pool_size: 64 + inference_buffer_threshold: 32 + inference_batch_size: 8 + sglang: + tp_size: 1 + mem_fraction_static: 0.7 + +mooncake: + master_server_address: null + metadata_server: null + protocol: tcp + global_segment_size: 16GB + local_buffer_size: 4GB + enable_hard_pin: true + +output_dir: ./outputs/qwen3-8b-dspark +cache_dir: ./cache/qwen3-8b-dspark +model_download_dir: null + +debug: + save_debug_train_data: null + debug_train_only: false + debug_inference_only: false diff --git a/torchspec/__init__.py b/torchspec/__init__.py index 6efef6c..05f90cd 100644 --- a/torchspec/__init__.py +++ b/torchspec/__init__.py @@ -24,12 +24,17 @@ from torchspec.models.dflash import DFlashModel from torchspec.models.draft import AutoDraftModelConfig, AutoEagle3DraftModel from torchspec.models.draft.dflash import DFlashConfig, DFlashDraftModel +from torchspec.models.draft.dspark import DSparkConfig, DSparkDraftModel +from torchspec.models.dspark import DSparkModel __all__ = [ "Eagle3Model", "DFlashModel", "DFlashConfig", "DFlashDraftModel", + "DSparkModel", + "DSparkConfig", + "DSparkDraftModel", "AutoDraftModelConfig", "AutoEagle3DraftModel", ] diff --git a/torchspec/config/dspark_draft_config.json b/torchspec/config/dspark_draft_config.json new file mode 100644 index 0000000..c091754 --- /dev/null +++ b/torchspec/config/dspark_draft_config.json @@ -0,0 +1,23 @@ +{ + "architectures": ["DSparkDraftModel"], + "model_type": "dspark", + "hidden_size": 4096, + "intermediate_size": 12288, + "num_hidden_layers": 5, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "vocab_size": 151936, + "rms_norm_eps": 1e-6, + "max_position_embeddings": 40960, + "rope_theta": 1000000.0, + "num_target_layers": 5, + "target_hidden_size": 4096, + "target_num_hidden_layers": 36, + "target_layer_ids": [1, 9, 17, 25, 33], + "mask_token_id": 151669, + "markov_rank": 256, + "markov_head_type": "vanilla", + "enable_confidence_head": true, + "confidence_head_with_markov": true, + "tie_word_embeddings": false +} diff --git a/torchspec/config/train_config.py b/torchspec/config/train_config.py index 06a5bd1..53b1d23 100644 --- a/torchspec/config/train_config.py +++ b/torchspec/config/train_config.py @@ -154,6 +154,14 @@ class TrainingConfig: dflash_num_anchors: int = 512 dflash_num_target_layers: int = 5 + # DSpark-specific parameters (used by DSpark trainer only) + dspark_num_anchors: int = 512 + dspark_num_target_layers: int = 5 + dspark_loss_decay_gamma: float = 4.0 + dspark_ce_loss_alpha: float = 0.1 + dspark_l1_loss_alpha: float = 0.9 + dspark_confidence_head_alpha: float = 1.0 + @dataclass class DecodeConfig: diff --git a/torchspec/models/__init__.py b/torchspec/models/__init__.py index 3fedf32..0d4316e 100644 --- a/torchspec/models/__init__.py +++ b/torchspec/models/__init__.py @@ -19,6 +19,7 @@ # SOFTWARE. from torchspec.models.dflash import DFlashModel +from torchspec.models.dspark import DSparkModel from torchspec.models.eagle3 import ( Eagle3Model, compute_lazy_target_padded, @@ -30,6 +31,7 @@ __all__ = [ "Eagle3Model", "DFlashModel", + "DSparkModel", "compute_lazy_target_padded", "compute_target_p_padded", "compiled_forward_kl_loss", diff --git a/torchspec/models/draft/__init__.py b/torchspec/models/draft/__init__.py index 0357645..daa95b6 100644 --- a/torchspec/models/draft/__init__.py +++ b/torchspec/models/draft/__init__.py @@ -22,6 +22,7 @@ from torchspec.models.draft.base import Eagle3DraftModel from torchspec.models.draft.deepseek_eagle import Eagle3DeepseekV2ForCausalLM from torchspec.models.draft.dflash import DFlashConfig, DFlashDraftModel +from torchspec.models.draft.dspark import DSparkConfig, DSparkDraftModel from torchspec.models.draft.llama3_eagle import LlamaForCausalLMEagle3 __all__ = [ @@ -31,5 +32,7 @@ "Eagle3DraftModel", "DFlashConfig", "DFlashDraftModel", + "DSparkConfig", + "DSparkDraftModel", "LlamaForCausalLMEagle3", ] diff --git a/torchspec/models/draft/auto.py b/torchspec/models/draft/auto.py index 9f765ba..b0dbe05 100644 --- a/torchspec/models/draft/auto.py +++ b/torchspec/models/draft/auto.py @@ -28,6 +28,7 @@ from torchspec.models.draft.deepseek_eagle import Eagle3DeepseekV2ForCausalLM from torchspec.models.draft.dflash import DFlashConfig, DFlashDraftModel +from torchspec.models.draft.dspark import DSparkConfig, DSparkDraftModel from torchspec.models.draft.llama3_eagle import LlamaForCausalLMEagle3 from torchspec.utils.logging import logger @@ -37,6 +38,7 @@ class AutoEagle3DraftModel(AutoModelForCausalLMBase): LlamaConfig: LlamaForCausalLMEagle3, DeepseekV3Config: Eagle3DeepseekV2ForCausalLM, DFlashConfig: DFlashDraftModel, + DSparkConfig: DSparkDraftModel, } @classmethod @@ -77,6 +79,7 @@ class AutoDraftModelConfig: "LlamaForCausalLMEagle3": LlamaConfig, "Eagle3DeepseekV2ForCausalLM": DeepseekV3Config, "DFlashDraftModel": DFlashConfig, + "DSparkDraftModel": DSparkConfig, } @classmethod diff --git a/torchspec/models/draft/dspark.py b/torchspec/models/draft/dspark.py new file mode 100644 index 0000000..429a667 --- /dev/null +++ b/torchspec/models/draft/dspark.py @@ -0,0 +1,154 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +""" +DSpark draft model: DFlash backbone + EAGLE-style Markov and confidence heads. + +DSpark shares DFlash's block-diffusion drafter (dual-source KV injection, anchor +sampling, MASK-token noise stream) and adds two heads on top: + + - Markov head: a low-rank learned bigram bias added to the draft logits, + conditioned on the (teacher-forced) previous token. Improves the per-token + distribution without touching the backbone. + - Confidence head (AcceptRatePredictor): predicts a per-draft-position + acceptance probability, trained against the empirical draft-vs-target + accept rate (used at inference time for adaptive block length). + +Markov / confidence modeling code is adapted from DeepSeek's DeepSpec +(deepspec/modeling/dspark/{markov_head,common}.py, MIT License). +""" + +from typing import Optional + +import torch +import torch.nn as nn + +from torchspec.models.draft.dflash import DFlashConfig, DFlashDraftModel + + +class DSparkConfig(DFlashConfig): + """ + Configuration for the DSpark draft model. Extends :class:`DFlashConfig`. + """ + + model_type = "dspark" + + def __init__( + self, + markov_rank: int = 256, + markov_head_type: str = "vanilla", + enable_confidence_head: bool = True, + confidence_head_with_markov: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.markov_rank = markov_rank + self.markov_head_type = markov_head_type + self.enable_confidence_head = enable_confidence_head + self.confidence_head_with_markov = confidence_head_with_markov + + +class VanillaMarkov(nn.Module): + """ + Adapted from DeepSpec's ``deepspec/modeling/dspark/markov_head.py``. + """ + + def __init__(self, *, vocab_size: int, markov_rank: int): + super().__init__() + self.vocab_size = int(vocab_size) + self.markov_rank = int(markov_rank) + self.markov_head_type = "vanilla" + assert self.markov_rank > 0, ( + f"VanillaMarkov requires markov_rank > 0, got {self.markov_rank}." + ) + self.markov_w1 = nn.Embedding(self.vocab_size, self.markov_rank) + self.markov_w2 = nn.Linear(self.markov_rank, self.vocab_size, bias=False) + + def get_prev_embeddings(self, token_ids: torch.Tensor) -> torch.Tensor: + return self.markov_w1(token_ids.long()) + + def project_bias(self, latent_states: torch.Tensor) -> torch.Tensor: + return self.markov_w2(latent_states) + + def compute_step_bias(self, token_ids: torch.Tensor) -> torch.Tensor: + return self.project_bias(self.get_prev_embeddings(token_ids)) + + def apply_block_logits( + self, + base_logits: torch.Tensor, + *, + token_ids: torch.Tensor, + ) -> torch.Tensor: + if base_logits.size(2) == 0: + return base_logits + return base_logits + self.compute_step_bias(token_ids) + + +class AcceptRatePredictor(nn.Module): + """ + Adapted from DeepSpec's ``deepspec/modeling/dspark/common.py``. + """ + + def __init__(self, input_dim: int): + super().__init__() + self.proj = nn.Linear(int(input_dim), 1) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + return self.proj(features).squeeze(-1) + + +def build_markov_head(config) -> Optional[nn.Module]: + markov_rank = int(getattr(config, "markov_rank", 0)) + assert markov_rank >= 0, f"markov_rank must be >= 0, got {markov_rank}" + if markov_rank == 0: + return None + + markov_head_type = str(getattr(config, "markov_head_type", "vanilla")).lower() + if markov_head_type == "vanilla": + return VanillaMarkov(vocab_size=config.vocab_size, markov_rank=markov_rank) + raise NotImplementedError( + f"markov_head_type={markov_head_type!r} is not supported yet; only 'vanilla' " + "is implemented in TorchSpec as it is recommended by the authors." + ) + + +class DSparkDraftModel(DFlashDraftModel): + config_class = DSparkConfig + + def __init__(self, config: DSparkConfig): + super().__init__(config) + + self.markov_rank = int(getattr(config, "markov_rank", 0)) + self.confidence_head_with_markov = bool( + getattr(config, "confidence_head_with_markov", True) + ) + + self.markov_head = build_markov_head(config) + + self.confidence_head: Optional[nn.Module] = None + if getattr(config, "enable_confidence_head", False): + conf_input_dim = self.hidden_size + if self.confidence_head_with_markov: + if self.markov_head is None: + raise ValueError( + "confidence_head_with_markov=True requires a Markov head (markov_rank > 0)." + ) + conf_input_dim += self.markov_rank + self.confidence_head = AcceptRatePredictor(conf_input_dim) diff --git a/torchspec/train_entry.py b/torchspec/train_entry.py index 3b0189f..7848a54 100644 --- a/torchspec/train_entry.py +++ b/torchspec/train_entry.py @@ -250,14 +250,19 @@ def _validate_and_configure_dflash(args, draft_model_config) -> None: Called before dataset loading to fail fast on misconfigurations. """ from torchspec.models.draft.dflash import DFlashConfig + from torchspec.models.draft.dspark import DSparkConfig if not isinstance(draft_model_config, DFlashConfig): return + # DSparkConfig subclasses DFlashConfig + is_dspark = isinstance(draft_model_config, DSparkConfig) + algo = "DSpark" if is_dspark else "DFlash" + engine_type = getattr(args, "inference_engine_type", "hf") if engine_type not in ("vllm", "sgl"): raise NotImplementedError( - f"DFlash supports inference_engine_type in ('vllm', 'sgl'), got '{engine_type}'." + f"{algo} supports inference_engine_type in ('vllm', 'sgl'), got '{engine_type}'." ) if getattr(args, "defer_tokenization", False): raise NotImplementedError("DFlash does not support defer_tokenization=True.") @@ -265,7 +270,7 @@ def _validate_and_configure_dflash(args, draft_model_config) -> None: min_loss = getattr(args, "min_loss_tokens", 0) if min_loss < 2 * block_size: raise ValueError( - f"DFlash requires dataset.min_loss_tokens >= 2 * training.dflash_block_size " + f"{algo} requires dataset.min_loss_tokens >= 2 * training.dflash_block_size " f"({min_loss} < {2 * block_size}). Set dataset.min_loss_tokens={2 * block_size}." ) diff --git a/torchspec/training/dflash_trainer.py b/torchspec/training/dflash_trainer.py index feff5c8..264d76c 100644 --- a/torchspec/training/dflash_trainer.py +++ b/torchspec/training/dflash_trainer.py @@ -27,7 +27,7 @@ import torch.distributed as dist from torchspec.models.dflash import DFlashModel -from torchspec.models.draft.dflash import DFlashDraftModel +from torchspec.models.draft.dflash import DFlashConfig, DFlashDraftModel from torchspec.training import checkpoint from torchspec.training.fsdp import apply_fsdp2, fsdp2_load_full_state_dict from torchspec.training.optimizer import BF16Optimizer @@ -41,8 +41,30 @@ class DFlashTrainer(Trainer): Extends ``Trainer`` with DFlash model initialisation (dual-source KV draft model), forward/backward with anchor sampling + block-causal mask, and metric aggregation. + + DSparkTrainer is a subclass that overrides the build hooks: + - `_draft_config_class` + - `_build_draft_model` + - `_build_training_wrapper` """ + _draft_config_class = DFlashConfig + + def _build_draft_model(self, config): + """Instantiate the draft network. Overridden by subclasses.""" + return DFlashDraftModel(config) + + def _build_training_wrapper(self, draft_model): + """Wrap the draft network with the training-objective module.""" + return DFlashModel( + draft_model=draft_model, + block_size=self.block_size, + num_anchors=self.num_anchors, + loss_objective=self.loss_objective, + dpace_alpha=self.dpace_alpha, + loss_decay_gamma=self.loss_decay_gamma, + ) + def __init__(self, args: Namespace): super().__init__(args) self.target_lm_head: Optional[torch.nn.Module] = None @@ -52,6 +74,10 @@ def __init__(self, args: Namespace): self.loss_objective = getattr(args, "dflash_loss_objective", "decay") self.dpace_alpha = getattr(args, "dflash_dpace_alpha", 0.5) self.loss_decay_gamma = getattr(args, "dflash_loss_decay_gamma", 7.0) + # Number of leading within-block slots to drop from per-position metrics. + # DFlash slot 0 is the masked anchor (drop it); DSpark predicts at every + # slot (overrides to 0). + self._anchor_slot_offset = 1 def init_model( self, @@ -71,18 +97,18 @@ def init_model( init_context = self._get_init_weight_context_manager() with init_context(): - from torchspec.models.draft.dflash import DFlashConfig + cfg_cls = self._draft_config_class if isinstance(draft_model_config, str): - config = DFlashConfig.from_pretrained(draft_model_config) + config = cfg_cls.from_pretrained(draft_model_config) elif isinstance(draft_model_config, dict): - config = DFlashConfig(**draft_model_config) + config = cfg_cls(**draft_model_config) elif isinstance(draft_model_config, DFlashConfig): config = draft_model_config else: raise TypeError( f"Unsupported draft_model_config type: {type(draft_model_config).__name__}. " - f"Expected str, dict, or DFlashConfig." + f"Expected str, dict, or {cfg_cls.__name__}." ) if not hasattr(config, "num_target_layers") or config.num_target_layers is None: @@ -101,7 +127,7 @@ def init_model( ) config.target_num_hidden_layers = target_config.num_hidden_layers - draft_model = DFlashDraftModel(config) + draft_model = self._build_draft_model(config) if dist.get_rank() == 0: draft_model.load_embedding( @@ -121,14 +147,7 @@ def init_model( f"{frozen_count:,} frozen (embedding) parameters" ) - dflash_model = DFlashModel( - draft_model=draft_model, - block_size=self.block_size, - num_anchors=self.num_anchors, - loss_objective=self.loss_objective, - dpace_alpha=self.dpace_alpha, - loss_decay_gamma=self.loss_decay_gamma, - ) + dflash_model = self._build_training_wrapper(draft_model) full_state = dflash_model.state_dict() if dist.get_rank() == 0 else {} @@ -380,9 +399,9 @@ def _aggregate_eval_metrics(self, all_step_metrics: list[dict]) -> dict: ) # Drop anchor slot (index 0) — see _aggregate_metrics for rationale. - pred_loss_pp = avg_loss_pp[1:] - pred_acc_pp = avg_acc_pp[1:] - pred_count_pp = count_pp[1:] + pred_loss_pp = avg_loss_pp[self._anchor_slot_offset :] + pred_acc_pp = avg_acc_pp[self._anchor_slot_offset :] + pred_count_pp = count_pp[self._anchor_slot_offset :] cumulative = 1.0 simulated_acc_len = 0.0 @@ -465,9 +484,9 @@ def _aggregate_metrics( # Skip index 0 (anchor slot, always zero); indices 1..B-1 are the # predicted tokens at 1..B-1 steps past the anchor. Re-index to 0..B-2 # so the naming matches Eagle3 (acc_0 = first predicted token). - pred_loss_pp = avg_loss_pp[1:] - pred_acc_pp = avg_acc_pp[1:] - pred_count_pp = count_pp[1:] + pred_loss_pp = avg_loss_pp[self._anchor_slot_offset :] + pred_acc_pp = avg_acc_pp[self._anchor_slot_offset :] + pred_count_pp = count_pp[self._anchor_slot_offset :] # Simulated accepted length: acc_0 + acc_0*acc_1 + ... + prod(acc_0..acc_{B-2}) # Models the expected number of consecutively accepted draft tokens. diff --git a/torchspec/training/dspark_trainer.py b/torchspec/training/dspark_trainer.py new file mode 100644 index 0000000..7407b69 --- /dev/null +++ b/torchspec/training/dspark_trainer.py @@ -0,0 +1,104 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""DSpark trainer — DFlash trainer + Markov/confidence heads and L1 distillation. + +Reuses the entire DFlash training pipeline (FSDP init, optimizer, checkpoint, +metric aggregation, hidden-state capture/transfer) via subclass hooks, and +additionally feeds the target ``last_hidden_states`` into the forward so the +L1 distribution-distillation and confidence-head losses can be computed. +""" + +from argparse import Namespace +from typing import Tuple + +import torch + +from torchspec.models.draft.dspark import DSparkConfig, DSparkDraftModel +from torchspec.models.dspark import DSparkModel +from torchspec.training.dflash_trainer import DFlashTrainer + + +class DSparkTrainer(DFlashTrainer): + """DSpark-specific trainer (DFlash backbone + EAGLE-style heads).""" + + _draft_config_class = DSparkConfig + + def __init__(self, args: Namespace): + super().__init__(args) + # DSpark uses its own knobs; override the dflash_* defaults read by the + # parent so the shared init_model / wrapper builders pick them up. + self.block_size = getattr(args, "dflash_block_size", 7) + self.num_anchors = getattr(args, "dspark_num_anchors", 512) + self.num_target_layers = getattr(args, "dspark_num_target_layers", 5) + self.loss_decay_gamma = getattr(args, "dspark_loss_decay_gamma", 4.0) + self.ce_loss_alpha = getattr(args, "dspark_ce_loss_alpha", 0.1) + self.l1_loss_alpha = getattr(args, "dspark_l1_loss_alpha", 0.9) + self.confidence_head_alpha = getattr(args, "dspark_confidence_head_alpha", 1.0) + self._anchor_slot_offset = 0 + + # ------------------------------------------------------------------ + # Build hooks (override DFlashTrainer's defaults) + # ------------------------------------------------------------------ + + def _build_draft_model(self, config): + return DSparkDraftModel(config) + + def _build_training_wrapper(self, draft_model): + return DSparkModel( + draft_model=draft_model, + block_size=self.block_size, + num_anchors=self.num_anchors, + loss_decay_gamma=self.loss_decay_gamma, + ce_loss_alpha=self.ce_loss_alpha, + l1_loss_alpha=self.l1_loss_alpha, + confidence_head_alpha=self.confidence_head_alpha, + ) + + # ------------------------------------------------------------------ + # Forward — adds target last_hidden_states for L1 / confidence losses + # ------------------------------------------------------------------ + + def _forward( + self, batch: dict + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + device = torch.device("cuda") + input_ids = batch["input_ids"].to(device, non_blocking=True) + hidden_states = batch["hidden_states"].to(device, non_blocking=True) + + loss_mask = batch["loss_mask"] + if loss_mask.dim() == 3: + loss_mask = loss_mask.squeeze(-1) + loss_mask = loss_mask.to(device, non_blocking=True) + + last_hidden_states = batch.get("last_hidden_states", None) + if last_hidden_states is not None: + last_hidden_states = last_hidden_states.to(device, non_blocking=True) + + hidden_states_list = self._split_hidden_states(hidden_states) + del hidden_states + + return self.model( + input_ids=input_ids, + hidden_states_list=hidden_states_list, + loss_mask=loss_mask, + lm_head_weight=self.target_lm_head_weight, + last_hidden_states=last_hidden_states, + ) diff --git a/torchspec/training/trainer_actor.py b/torchspec/training/trainer_actor.py index 09fc38d..a9ce645 100644 --- a/torchspec/training/trainer_actor.py +++ b/torchspec/training/trainer_actor.py @@ -26,6 +26,7 @@ from torchspec import AutoDraftModelConfig from torchspec.models.draft.dflash import DFlashConfig +from torchspec.models.draft.dspark import DSparkConfig from torchspec.ray.ray_actor import RayActor from torchspec.training.eagle3_trainer import Eagle3Trainer from torchspec.utils.distributed import init_gloo_group, init_usp_groups @@ -75,8 +76,16 @@ def init(self, args: Namespace, role: str, mooncake_config=None, with_ref: bool if draft_model_config is None and getattr(args, "draft_model_config", None): draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config) - # Config-based trainer dispatch: DFlashConfig → DFlashTrainer, else Eagle3 - if isinstance(draft_model_config, DFlashConfig): + # Config-based trainer dispatch. + # DSparkConfig subclasses DFlashConfig, so it must be checked first. + # - DSparkConfig → DSparkTrainer + # - DFlashConfig → DFlashTrainer + # - else Eagle3. + if isinstance(draft_model_config, DSparkConfig): + from torchspec.training.dspark_trainer import DSparkTrainer + + self._trainer = DSparkTrainer(args) + elif isinstance(draft_model_config, DFlashConfig): from torchspec.training.dflash_trainer import DFlashTrainer self._trainer = DFlashTrainer(args) From 6a62189b5d1fa6ed6458b89cd2c46f35350435e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fa=C3=A7=20Eldenk?= Date: Sat, 27 Jun 2026 21:18:09 +0000 Subject: [PATCH 2/4] add DSpec forward/loss MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Doğaç Eldenk --- torchspec/models/dspark.py | 282 +++++++++++++++++++++++++++++++++++++ 1 file changed, 282 insertions(+) create mode 100644 torchspec/models/dspark.py diff --git a/torchspec/models/dspark.py b/torchspec/models/dspark.py new file mode 100644 index 0000000..56d2afb --- /dev/null +++ b/torchspec/models/dspark.py @@ -0,0 +1,282 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""DSpark training model: DFlash training wrapper + Markov / L1 / confidence losses. + +Reuses :class:`DFlashModel`'s anchor sampling, block-causal FlexAttention mask, +and MASK-token noise construction verbatim, then layers on the DSpark training +objective: + + - Markov-biased draft logits (teacher-forced previous token). + - Cross-entropy against the ground-truth next tokens (hard labels). + - L1 distribution distillation: ``|softmax(draft) - softmax(target)|`` where the + target distribution is the frozen LM head applied to the *target's* final + hidden state at the aligned position (requires ``last_hidden_states``). + - Confidence head BCE against the empirical per-token accept rate. + +Combined: ``ce_alpha*ce + l1_alpha*l1 + confidence_alpha*confidence``. + +Loss formulation adapted from DeepSeek's DeepSpec +(deepspec/modeling/dspark/loss.py, MIT), including its pooled global-mean +reduction: local numerators over a cross-rank all-reduced denominator, scaled +by world_size to cancel FSDP's mean gradient reduction. +""" + +from typing import List, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +from torchspec.models.dflash import DFlashModel, _create_dflash_mask_mod +from torchspec.models.ops.flex_attention import compile_friendly_create_block_mask + + +class DSparkModel(DFlashModel): + """DSpark training wrapper (DFlash backbone + Markov/L1/confidence heads).""" + + def __init__( + self, + draft_model, + block_size: int = 7, + num_anchors: int = 512, + loss_decay_gamma: float = 4.0, + ce_loss_alpha: float = 0.1, + l1_loss_alpha: float = 0.9, + confidence_head_alpha: float = 1.0, + ): + # Reuse DFlash anchor/mask/noise machinery. The "decay" objective drives + # the per-within-block position weighting shared by all DSpark terms. + super().__init__( + draft_model=draft_model, + block_size=block_size, + num_anchors=num_anchors, + loss_objective="decay", + dpace_alpha=0.5, + loss_decay_gamma=loss_decay_gamma, + ) + self.ce_loss_alpha = float(ce_loss_alpha) + self.l1_loss_alpha = float(l1_loss_alpha) + self.confidence_head_alpha = float(confidence_head_alpha) + + def _decay_weights(self, device: torch.device) -> torch.Tensor: + """exp(-k/gamma) over within-block position k (DeepSpec convention). + + Every slot 0..B-1 is a real prediction here, so slot 0 (the first + predicted token) gets weight 1.0 and later slots decay. + """ + k = torch.arange(self.block_size, device=device).view(1, 1, -1) + if self.loss_decay_gamma is not None and self.loss_decay_gamma > 0: + return torch.exp(-k.float() / self.loss_decay_gamma) + return torch.ones_like(k, dtype=torch.float32) + + def forward( + self, + input_ids: torch.Tensor, + hidden_states_list: List[torch.Tensor], + loss_mask: torch.Tensor, + lm_head_weight: torch.Tensor, + last_hidden_states: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """DSpark training forward. + + Returns the same 5-tuple as :meth:`DFlashModel.forward` + (loss, accuracy, loss_per_position, acc_per_position, count_per_position) + so the DFlash trainer's metric aggregation is reused unchanged. ``loss`` + is the combined ce+l1+confidence objective; the per-position metrics are + cross-entropy based (acceptance proxy), matching DFlash semantics. + """ + bsz, seq_len = input_ids.shape + device = input_ids.device + + # ---- DFlash backbone (identical to DFlashModel.forward steps 1-7) ---- + context_feature = self.draft_model.extract_context_feature(hidden_states_list) + anchor_positions, block_keep_mask = self._sample_anchor_positions( + seq_len, loss_mask, device + ) + n_blocks = anchor_positions.shape[1] + noise_embedding = self._create_noise_embed(input_ids, anchor_positions, block_keep_mask) + context_position_ids, draft_position_ids = self._create_position_ids( + anchor_positions, seq_len + ) + + draft_len = n_blocks * self.block_size + kv_len = seq_len + draft_len + block_mask = None + if device.type == "cuda": + mask_mod = _create_dflash_mask_mod( + anchor_positions=anchor_positions, + block_keep_mask=block_keep_mask, + ctx_len=seq_len, + block_size=self.block_size, + ) + block_mask = compile_friendly_create_block_mask( + mask_mod=mask_mod, + B=bsz, + H=None, + Q_LEN=draft_len, + KV_LEN=kv_len, + device=device, + ) + + draft_hidden = self.draft_model( + draft_input_ids=None, + context_feature=context_feature, + draft_position_ids=draft_position_ids, + context_position_ids=context_position_ids, + block_mask=block_mask, + noise_embedding=noise_embedding, + ) + hidden_4d = draft_hidden.view(bsz, n_blocks, self.block_size, -1) + + base_logits = F.linear(draft_hidden, lm_head_weight) + base_logits_4d = base_logits.view(bsz, n_blocks, self.block_size, -1) + vocab_size = base_logits_4d.size(-1) + + # ---- Labels + eval mask (DeepSpec convention) ---- + # Slot j predicts the token at anchor+j+1 (the real anchor token seeds + # slot 0). All block_size slots are supervised — there is no masked + # anchor slot, unlike DFlash. + label_offsets = torch.arange(1, self.block_size + 1, device=device).view(1, 1, -1) + label_indices = anchor_positions.unsqueeze(-1) + label_offsets # [B, nb, bs] + valid_label_mask = label_indices < seq_len + safe_label_indices = label_indices.clamp(max=seq_len - 1) + safe_label_indices = torch.where( + block_keep_mask.unsqueeze(-1), + safe_label_indices, + torch.zeros_like(safe_label_indices), + ) + + target_ids = torch.gather( + input_ids.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) # [B, nb, bs] + + # eval mask = contiguous supervised prefix per block (DeepSpec + # build_eval_mask): block kept, label in-bounds, target token supervised, + # then cumprod so a gap truncates the rest of the block. + target_loss_mask = torch.gather( + loss_mask.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) + eval_bool = block_keep_mask.unsqueeze(-1) & valid_label_mask & (target_loss_mask > 0.5) + eval_bool = eval_bool.to(torch.int32).cumprod(dim=-1).bool() + eval_mask = eval_bool.float() # [B, nb, bs] + + decay_weight_mask = eval_mask * self._decay_weights(device) + local_den = decay_weight_mask.sum() + + # ---- Markov-biased draft logits ---- + # prev token for slot j is the ground-truth token immediately before the + # one slot j predicts: slot 0's prev is the real anchor token, slot j's + # is target_ids[j-1]. Matches DeepSpec prev_token_ids. + anchor_token_ids = torch.gather(input_ids, 1, anchor_positions) # [B, nb] + prev_token_ids = torch.cat([anchor_token_ids.unsqueeze(-1), target_ids[:, :, :-1]], dim=-1) + logits_4d = base_logits_4d + if self.draft_model.markov_head is not None: + logits_4d = self.draft_model.markov_head.apply_block_logits( + base_logits_4d, token_ids=prev_token_ids + ) + + # ---- Cross entropy (hard labels) ---- + flat_logits = logits_4d.reshape(-1, vocab_size) + flat_targets = target_ids.reshape(-1) + ce_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none").view( + bsz, n_blocks, self.block_size + ) + ce_num = (ce_per_token * decay_weight_mask).sum() + + # ---- L1 distribution distillation + accept rate ---- + l1_num = base_logits.new_zeros((), dtype=torch.float32) + accept_rate = None + need_target = (self.l1_loss_alpha > 0) or ( + self.draft_model.confidence_head is not None and self.confidence_head_alpha > 0 + ) + if need_target: + if last_hidden_states is None: + raise ValueError( + "DSpark L1/confidence losses require target last_hidden_states; set " + "inference.store_last_hidden_states=true in the run config." + ) + # target distribution for the token at label_indices = target LM head + # applied to the target hidden one position earlier (anchor+j). + tgt_idx = (safe_label_indices - 1).clamp(min=0) # [B, nb, bs] + hdim = last_hidden_states.size(-1) + gather_idx = tgt_idx.reshape(bsz, -1, 1).expand(-1, -1, hdim) + aligned_hidden = torch.gather(last_hidden_states, 1, gather_idx) + aligned_target_logits = F.linear(aligned_hidden, lm_head_weight).view( + bsz, n_blocks, self.block_size, vocab_size + ) + draft_probs = torch.softmax(logits_4d.float(), dim=-1) + target_probs = torch.softmax(aligned_target_logits.float(), dim=-1) + l1_per_token = (draft_probs - target_probs).abs().sum(dim=-1) # [B, nb, bs] + if self.l1_loss_alpha > 0: + l1_num = (l1_per_token * decay_weight_mask).sum() + accept_rate = (1.0 - 0.5 * l1_per_token).clamp(0.0, 1.0) + + # ---- Confidence head BCE ---- + conf_num = base_logits.new_zeros((), dtype=torch.float32) + if self.draft_model.confidence_head is not None and self.confidence_head_alpha > 0: + if self.draft_model.confidence_head_with_markov: + prev_emb = self.draft_model.markov_head.get_prev_embeddings(prev_token_ids).to( + hidden_4d.dtype + ) + conf_features = torch.cat([hidden_4d, prev_emb], dim=-1) + else: + conf_features = hidden_4d + confidence_pred = self.draft_model.confidence_head(conf_features).float() + conf_bce = ( + F.binary_cross_entropy_with_logits( + confidence_pred, accept_rate.detach(), reduction="none" + ) + * decay_weight_mask + ) + conf_num = conf_bce.sum() + + # ---- Pooled global loss (DeepSpec _build_loss) ---- + # Local numerators over a cross-rank-summed denominator, x world_size to + # cancel FSDP's mean gradient reduction -> a true token-pooled global mean + # rather than a mean-of-per-rank-means. + # NOTE: uses the global training group size; correct for plain DP. With a + # multi-dim mesh (e.g. USP) the FSDP shard group differs from world_size. + world_size = dist.get_world_size() if dist.is_initialized() else 1 + global_den = local_den.detach().clone() + if world_size > 1: + dist.all_reduce(global_den, op=dist.ReduceOp.SUM) + global_den = global_den + 1e-6 + loss = ( + self.ce_loss_alpha * ce_num / global_den + + self.l1_loss_alpha * l1_num / global_den + + self.confidence_head_alpha * conf_num / global_den + ) * world_size + + # ---- Metrics (cross-entropy based; all block_size slots are productive) ---- + with torch.no_grad(): + flat_binary = eval_mask.reshape(-1) + pred_ids = torch.argmax(flat_logits, dim=-1) + correct = (pred_ids == flat_targets) & (flat_binary > 0.5) + accuracy = correct.sum().float() / flat_binary.sum().clamp(min=1e-6) + + count_per_position = eval_mask.sum(dim=(0, 1)) + count_pp = count_per_position.clamp(min=1.0) + loss_per_position = (ce_per_token * eval_mask).sum(dim=(0, 1)) / count_pp + acc_per_position = ( + correct.view(bsz, n_blocks, self.block_size).float().sum(dim=(0, 1)) / count_pp + ) + + return loss, accuracy, loss_per_position, acc_per_position, count_per_position From 7548b79e8bd9218da274cfb41dc5d3fd9198efb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fa=C3=A7=20Eldenk?= Date: Sat, 27 Jun 2026 22:05:16 +0000 Subject: [PATCH 3/4] add more tests and report more loss MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Doğaç Eldenk --- tests/test_dspark.py | 255 +++++++++++++++++++++++++++ torchspec/models/dspark.py | 32 +++- torchspec/training/dflash_trainer.py | 5 +- torchspec/training/dspark_trainer.py | 48 ++++- 4 files changed, 328 insertions(+), 12 deletions(-) create mode 100644 tests/test_dspark.py diff --git a/tests/test_dspark.py b/tests/test_dspark.py new file mode 100644 index 0000000..119f1cc --- /dev/null +++ b/tests/test_dspark.py @@ -0,0 +1,255 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Tests for DSpark (DFlash backbone + Markov / confidence heads + L1 distillation). + +Pins the DSpark wiring so future refactors can't silently break the objective: + +1. DSparkConfig / DSparkDraftModel: head construction, subclass relationship. +2. forward returns the 6-tuple with detached per-component losses. +3. Loss-wiring invariants (no DeepSpec dependency): + - internal identity: combined loss == ce_a*ce + l1_a*l1 + cf_a*conf (so the + logged loss_components are trustworthy) + - all-masked batch -> loss 0 + - gradients reach markov + confidence + backbone; embedding stays frozen + - next-token convention: every within-block slot is supervised (B predictions) +4. Markov / confidence head unit math. +5. Algorithm dispatch (DSparkConfig resolves from the JSON and is checked before + DFlashConfig since it subclasses it). +""" + +import unittest + +import torch + +from torchspec.models.draft.auto import AutoDraftModelConfig +from torchspec.models.draft.dflash import DFlashConfig +from torchspec.models.draft.dspark import ( + AcceptRatePredictor, + DSparkConfig, + DSparkDraftModel, + VanillaMarkov, +) +from torchspec.models.dspark import DSparkModel + +CE_A, L1_A, CF_A = 0.1, 0.9, 1.0 + + +def _make_dspark_config( + H=64, + V=128, + num_target_layers=2, + markov_rank=16, + enable_confidence_head=True, + confidence_head_with_markov=True, +): + return DSparkConfig( + hidden_size=H, + intermediate_size=256, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=2, + vocab_size=V, + rms_norm_eps=1e-6, + max_position_embeddings=512, + rope_theta=10000.0, + num_target_layers=num_target_layers, + target_hidden_size=H, + target_num_hidden_layers=12, + mask_token_id=V - 1, + markov_rank=markov_rank, + markov_head_type="vanilla", + enable_confidence_head=enable_confidence_head, + confidence_head_with_markov=confidence_head_with_markov, + ) + + +def _make_dspark_model(block_size=4, num_anchors=6, **cfg_kw): + config = _make_dspark_config(**cfg_kw) + draft = DSparkDraftModel(config).to(dtype=torch.float32) + draft.freeze_embedding() + return DSparkModel( + draft_model=draft, + block_size=block_size, + num_anchors=num_anchors, + loss_decay_gamma=4.0, + ce_loss_alpha=CE_A, + l1_loss_alpha=L1_A, + confidence_head_alpha=CF_A, + ) + + +def _batch(B=2, S=24, H=64, V=128, num_target_layers=2, all_masked=False, seed=0): + g = torch.Generator().manual_seed(seed) + input_ids = torch.randint(0, V, (B, S), generator=g) + hidden_states_list = [torch.randn(B, S, H, generator=g) for _ in range(num_target_layers)] + loss_mask = torch.zeros(B, S) if all_masked else torch.ones(B, S) + if not all_masked: + loss_mask[:, :2] = 0 # prompt + lm_head_weight = torch.randn(V, H, generator=g) + last_hidden_states = torch.randn(B, S, H, generator=g) + return dict( + input_ids=input_ids, + hidden_states_list=hidden_states_list, + loss_mask=loss_mask, + lm_head_weight=lm_head_weight, + last_hidden_states=last_hidden_states, + ) + + +class TestDSparkConfig(unittest.TestCase): + def test_subclasses_dflash_and_attrs(self): + cfg = _make_dspark_config(markov_rank=32) + self.assertIsInstance(cfg, DFlashConfig) # ordering hazard: check DSpark first + self.assertEqual(cfg.model_type, "dspark") + self.assertEqual(cfg.markov_rank, 32) + self.assertTrue(cfg.enable_confidence_head) + + def test_draft_model_heads(self): + cfg = _make_dspark_config(H=64, markov_rank=16) + m = DSparkDraftModel(cfg) + self.assertIsInstance(m.markov_head, VanillaMarkov) + self.assertIsInstance(m.confidence_head, AcceptRatePredictor) + # confidence input = hidden + markov_rank when fused + self.assertEqual(m.confidence_head.proj.in_features, 64 + 16) + + def test_no_heads(self): + cfg = _make_dspark_config( + markov_rank=0, enable_confidence_head=False, confidence_head_with_markov=False + ) + m = DSparkDraftModel(cfg) + self.assertIsNone(m.markov_head) + self.assertIsNone(m.confidence_head) + + +class TestDSparkForward(unittest.TestCase): + def test_returns_six_tuple_with_detached_components(self): + m = _make_dspark_model() + out = m(**_batch()) + self.assertEqual(len(out), 6) + loss, acc, lpp, app, cpp, comps = out + self.assertEqual(set(comps), {"ce_loss", "l1_loss", "confidence_loss"}) + for v in comps.values(): + self.assertTrue(torch.isfinite(v).all()) + self.assertFalse(v.requires_grad) # detached for logging + self.assertTrue(torch.isfinite(loss)) + self.assertEqual(lpp.shape[0], m.block_size) + + def test_internal_loss_identity(self): + # At world_size==1 (no process group), the combined loss must equal the + # alpha-weighted sum of the logged components — so the components are a + # faithful decomposition of what's actually optimized. + m = _make_dspark_model() + loss, _, _, _, _, comps = m(**_batch(seed=1)) + recomputed = ( + CE_A * comps["ce_loss"] + L1_A * comps["l1_loss"] + CF_A * comps["confidence_loss"] + ) + self.assertTrue( + torch.allclose(loss, recomputed, atol=1e-4), f"{loss.item()} vs {recomputed.item()}" + ) + + def test_all_masked_is_zero(self): + m = _make_dspark_model() + loss, _, _, _, _, comps = m(**_batch(all_masked=True)) + self.assertAlmostEqual(loss.item(), 0.0, places=5) + for v in comps.values(): + self.assertAlmostEqual(v.item(), 0.0, places=5) + + def test_next_token_convention_all_slots_supervised(self): + # Fix 1: every within-block slot predicts a real token (B predictions), + # unlike DFlash where slot 0 is the masked anchor. With a long fully + # supervised sequence, every position should accumulate supervised tokens. + m = _make_dspark_model(block_size=4, num_anchors=8) + b = _batch(B=2, S=40) + b["loss_mask"] = torch.ones(2, 40) + _, _, _, _, count_per_position, _ = m(**b) + self.assertEqual(count_per_position.shape[0], 4) + self.assertTrue( + (count_per_position > 0).all(), f"some slot unsupervised: {count_per_position.tolist()}" + ) + + def test_grad_flow_and_frozen_embedding(self): + m = _make_dspark_model() + loss, *_ = m(**_batch(seed=2)) + loss.backward() + draft = m.draft_model + self.assertIsNotNone(draft.markov_head.markov_w2.weight.grad) + self.assertGreater(draft.markov_head.markov_w2.weight.grad.abs().sum().item(), 0) + self.assertIsNotNone(draft.confidence_head.proj.weight.grad) + self.assertGreater(draft.confidence_head.proj.weight.grad.abs().sum().item(), 0) + self.assertIsNotNone(draft.context_proj.weight.grad) + self.assertIsNone(draft.embed_tokens.weight.grad) # frozen + + def test_ce_only_without_target(self): + # ce-only (l1=0, no confidence) must run without last_hidden_states. + m = _make_dspark_model( + markov_rank=16, enable_confidence_head=False, confidence_head_with_markov=False + ) + m.l1_loss_alpha = 0.0 + m.ce_loss_alpha = 1.0 + m.confidence_head_alpha = 0.0 + b = _batch() + b["last_hidden_states"] = None + loss, *_ = m(**b) + self.assertTrue(torch.isfinite(loss)) + + +class TestHeadMath(unittest.TestCase): + def test_vanilla_markov_is_bigram_bias(self): + torch.manual_seed(0) + mk = VanillaMarkov(vocab_size=50, markov_rank=8) + base = torch.randn(2, 3, 4, 50) + prev = torch.randint(0, 50, (2, 3, 4)) + out = mk.apply_block_logits(base, token_ids=prev) + expected = base + mk.markov_w2(mk.markov_w1(prev)) + self.assertTrue(torch.allclose(out, expected, atol=1e-6)) + + def test_confidence_head_is_linear(self): + torch.manual_seed(0) + head = AcceptRatePredictor(20) + feats = torch.randn(2, 3, 4, 20) + out = head(feats) + expected = head.proj(feats).squeeze(-1) + self.assertTrue(torch.allclose(out, expected, atol=1e-6)) + self.assertEqual(out.shape, (2, 3, 4)) + + +class TestDispatch(unittest.TestCase): + def test_json_resolves_to_dspark_config(self): + cfg = AutoDraftModelConfig.from_dict( + { + "architectures": ["DSparkDraftModel"], + "model_type": "dspark", + "hidden_size": 64, + "vocab_size": 128, + "num_hidden_layers": 1, + "num_target_layers": 2, + "markov_rank": 16, + "enable_confidence_head": True, + } + ) + self.assertIsInstance(cfg, DSparkConfig) + # Subclass of DFlashConfig -> any isinstance(DFlashConfig) dispatch must + # test DSparkConfig first (trainer_actor / train_entry rely on this). + self.assertIsInstance(cfg, DFlashConfig) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchspec/models/dspark.py b/torchspec/models/dspark.py index 56d2afb..4fd03a9 100644 --- a/torchspec/models/dspark.py +++ b/torchspec/models/dspark.py @@ -94,14 +94,16 @@ def forward( loss_mask: torch.Tensor, lm_head_weight: torch.Tensor, last_hidden_states: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]: """DSpark training forward. - Returns the same 5-tuple as :meth:`DFlashModel.forward` - (loss, accuracy, loss_per_position, acc_per_position, count_per_position) - so the DFlash trainer's metric aggregation is reused unchanged. ``loss`` - is the combined ce+l1+confidence objective; the per-position metrics are - cross-entropy based (acceptance proxy), matching DFlash semantics. + Returns DFlashModel.forward's 5-tuple (loss, accuracy, loss_per_position, + acc_per_position, count_per_position) plus a 6th element: a dict of + detached per-component loss scalars (ce_loss / l1_loss / confidence_loss, + per-rank local means) for logging. ``loss`` is the combined + ce+l1+confidence objective; the per-position metrics are cross-entropy + based (acceptance proxy). DSparkTrainer unpacks the 6th element; the rest + of DFlash's aggregation is reused unchanged. """ bsz, seq_len = input_ids.shape device = input_ids.device @@ -265,6 +267,15 @@ def forward( + self.confidence_head_alpha * conf_num / global_den ) * world_size + # Per-component loss values (per-rank local means) for logging only — + # lets you watch L1 fall while the greedy-CE proxy plateaus. + local_den_eps = local_den + 1e-6 + loss_components = { + "ce_loss": (ce_num / local_den_eps).detach(), + "l1_loss": (l1_num / local_den_eps).detach(), + "confidence_loss": (conf_num / local_den_eps).detach(), + } + # ---- Metrics (cross-entropy based; all block_size slots are productive) ---- with torch.no_grad(): flat_binary = eval_mask.reshape(-1) @@ -279,4 +290,11 @@ def forward( correct.view(bsz, n_blocks, self.block_size).float().sum(dim=(0, 1)) / count_pp ) - return loss, accuracy, loss_per_position, acc_per_position, count_per_position + return ( + loss, + accuracy, + loss_per_position, + acc_per_position, + count_per_position, + loss_components, + ) diff --git a/torchspec/training/dflash_trainer.py b/torchspec/training/dflash_trainer.py index 264d76c..195a819 100644 --- a/torchspec/training/dflash_trainer.py +++ b/torchspec/training/dflash_trainer.py @@ -49,6 +49,7 @@ class DFlashTrainer(Trainer): """ _draft_config_class = DFlashConfig + _anchor_slot_offset = 1 def _build_draft_model(self, config): """Instantiate the draft network. Overridden by subclasses.""" @@ -74,10 +75,6 @@ def __init__(self, args: Namespace): self.loss_objective = getattr(args, "dflash_loss_objective", "decay") self.dpace_alpha = getattr(args, "dflash_dpace_alpha", 0.5) self.loss_decay_gamma = getattr(args, "dflash_loss_decay_gamma", 7.0) - # Number of leading within-block slots to drop from per-position metrics. - # DFlash slot 0 is the masked anchor (drop it); DSpark predicts at every - # slot (overrides to 0). - self._anchor_slot_offset = 1 def init_model( self, diff --git a/torchspec/training/dspark_trainer.py b/torchspec/training/dspark_trainer.py index 7407b69..d24cc87 100644 --- a/torchspec/training/dspark_trainer.py +++ b/torchspec/training/dspark_trainer.py @@ -30,6 +30,7 @@ from typing import Tuple import torch +import torch.distributed as dist from torchspec.models.draft.dspark import DSparkConfig, DSparkDraftModel from torchspec.models.dspark import DSparkModel @@ -95,10 +96,55 @@ def _forward( hidden_states_list = self._split_hidden_states(hidden_states) del hidden_states - return self.model( + # DSparkModel.forward returns a 6th element: a dict of per-component loss + # scalars (ce/l1/confidence). Stash it for _train_step to log; return the + # 5-tuple the base trainer expects. + ( + loss, + accuracy, + loss_per_position, + acc_per_position, + count_per_position, + self._last_loss_components, + ) = self.model( input_ids=input_ids, hidden_states_list=hidden_states_list, loss_mask=loss_mask, lm_head_weight=self.target_lm_head_weight, last_hidden_states=last_hidden_states, ) + return loss, accuracy, loss_per_position, acc_per_position, count_per_position + + # ------------------------------------------------------------------ + # Per-component loss logging (ce / l1 / confidence) + # ------------------------------------------------------------------ + + def _train_step( + self, + batch: dict, + accumulation_steps: int, + step: int, + batch_idx: int, + num_batches: int, + ) -> dict: + metrics = super()._train_step(batch, accumulation_steps, step, batch_idx, num_batches) + # Carry the components from the forward that _train_step just ran. + for key, value in getattr(self, "_last_loss_components", {}).items(): + metrics[key] = value + return metrics + + def _aggregate_metrics( + self, all_step_metrics: list[dict], step: int, *, grad_norm: torch.Tensor = None + ) -> dict: + metrics = super()._aggregate_metrics(all_step_metrics, step, grad_norm=grad_norm) + if all_step_metrics: + for key in ("ce_loss", "l1_loss", "confidence_loss"): + vals = [m[key] for m in all_step_metrics if key in m] + if not vals: + continue + value = torch.stack([v.float() for v in vals]).mean() + if dist.is_initialized() and dist.get_world_size() > 1: + dist.all_reduce(value, op=dist.ReduceOp.SUM) + value = value / dist.get_world_size() + metrics[f"train/{key}"] = value.item() + return metrics From b74deb41756a5e656598d0925a4aaf1909178cb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fa=C3=A7=20Eldenk?= Date: Sun, 28 Jun 2026 20:25:37 -0500 Subject: [PATCH 4/4] fix lint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Doğaç Eldenk --- torchspec/train_entry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchspec/train_entry.py b/torchspec/train_entry.py index f320382..12415fd 100644 --- a/torchspec/train_entry.py +++ b/torchspec/train_entry.py @@ -245,7 +245,7 @@ def _get_draft_model_config(args): def _validate_and_configure_dflash(args, draft_model_config) -> None: - """Validate + """Validate -specific config and auto-set aux layer IDs. Called before dataset loading to fail fast on misconfigurations.