Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions configs/vllm_qwen3_8b_dflash.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# DFlash training config for Qwen3-8B — vLLM inference backend
#
# GPU allocation (4x GPU; cf. vllm_qwen3_8b.yaml):
# - 2 GPUs for inference (vLLM, tp_size=2, single engine copy)
# - 2 GPUs for training (FSDP FULL_SHARD)
# - global_batch = micro_batch_size × dp_size × accum = 1 × 2 × 2 = 4
#
# Usage:
# python -m torchspec.train_entry --config configs/vllm_qwen3_8b_dflash.yaml \
# dataset.train_data_path=/path/to/data.jsonl output_dir=./outputs/dflash-vllm

model:
target_model_path: Qwen/Qwen3-8B
trust_remote_code: true
draft_model_config: torchspec/config/dflash_draft_config.json

dataset:
train_data_path: ../examples/data/sample_conversations.jsonl
eval_data_path: null
eval_interval: 100
chat_template: qwen
prompt_key: conversations
min_loss_tokens: 32 # DFlash requires >= 2 * dflash_block_size (= 32)

training:
attention_backend: flex_attention
micro_batch_size: 1
draft_accumulation_steps: 2 # 2x optimizer steps for better convergence
learning_rate: 6e-4
min_lr: 6e-5 # 10% of peak — prevents LR death in later epochs
weight_decay: 0.01 # AdamW regularization for better generalization
max_concurrent_batches: 1
max_grad_norm: 1.0
max_seq_length: 2048
num_epochs: 3
seed: 42
training_num_gpus_per_node: 2
training_num_nodes: 1
ttt_length: 7
fsdp_strategy: FULL_SHARD
fsdp_reduce_dtype: bfloat16
prefetch_depth: 8
save_interval: 1000
save_per_epoch: true
max_checkpoints: 2
warmup_ratio: 0.04

# DFlash-specific parameters
dflash_block_size: 16
dflash_num_anchors: 512
dflash_loss_decay_gamma: 7.0
dflash_num_target_layers: 5

inference:
inference_engine_type: vllm
store_last_hidden_states: false # DFlash projects from the 5 aux layers, not the last hidden state
inference_num_gpus: 2
inference_num_gpus_per_engine: 2
inference_num_gpus_per_node: 4
max_sample_pool_size: 64
inference_buffer_threshold: 32
inference_batch_size: 8
vllm:
tp_size: 2
mem_fraction_static: 0.7
extra_args:
max_num_batched_tokens: 8192

mooncake:
master_server_address: null
metadata_server: null
protocol: tcp
global_segment_size: 32GB
local_buffer_size: 4GB
# Hard-pin: master-side TTL is disabled; we rely on our explicit
# batch_remove(force=True) (see mooncake/eagle_store.py). Requires
# mooncake-transfer-engine >= 0.3.10.post1.
enable_hard_pin: true

output_dir: ./outputs/qwen3-8b-dflash-vllm
cache_dir: ./cache/qwen3-8b-dflash-vllm
model_download_dir: null

debug:
save_debug_train_data: null
debug_train_only: false
debug_inference_only: false
7 changes: 5 additions & 2 deletions torchspec/train_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,11 @@ def _validate_and_configure_dflash(args, draft_model_config) -> None:
if not isinstance(draft_model_config, DFlashConfig):
return

if getattr(args, "inference_engine_type", "hf") != "sgl":
raise NotImplementedError("DFlash currently supports only inference_engine_type='sgl'.")
engine_type = getattr(args, "inference_engine_type", "hf")
if engine_type not in ("vllm", "sgl"):
raise NotImplementedError(
f"DFlash supports inference_engine_type in ('vllm', 'sgl'), got '{engine_type}'."
)
if getattr(args, "defer_tokenization", False):
raise NotImplementedError("DFlash does not support defer_tokenization=True.")
block_size = getattr(args, "dflash_block_size", 16)
Expand Down
Loading