Skip to content

Commit 5ff01bd

Browse files
committed
reasoning benchmarks
1 parent b19e808 commit 5ff01bd

2 files changed

Lines changed: 245 additions & 65 deletions

File tree

README.md

Lines changed: 50 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -149,102 +149,87 @@ No GPU? TIDE works in pure PyTorch (CPU fallback, no CUDA kernels needed).
149149

150150
## Benchmark Results
151151

152-
All benchmarks on **NVIDIA A100-SXM4-40GB**, bf16 precision, 2000 WikiText calibration samples.
153-
16 real text prompts (science, code, history, economics).
152+
All benchmarks on **NVIDIA A100-SXM4-40GB**, bf16, 2000 WikiText calibration samples.
153+
16 prompts (8 reasoning/math + 8 general knowledge).
154154

155155
### Prefill Exit Rates
156156

157157
```
158-
Model Layers Threshold Exit Rate Where Exits Happen
158+
Model Layers Threshold Exit Rate Exit Distribution
159159
========================== ====== ========= ========= ==========================
160-
Qwen3 8B 36 0.95 100.0% L35: 155 tokens
160+
DeepSeek R1 Distill 8B 32 0.85 100.0% L11: 16 tokens L31: 306
161+
DeepSeek R1 Distill 8B 32 0.50 100.0% L11: 16 tokens L31: 306
161162
Qwen3 8B 36 0.85 100.0% L35: 155 tokens
162163
Qwen3 8B 36 0.50 100.0% L11:11 L23:5 L35:139
163-
DeepSeek R1 Distill 8B 32 0.95 100.0% L31: 176 tokens
164-
DeepSeek R1 Distill 8B 32 0.85 100.0% L11:16 L31:160
165-
DeepSeek R1 Distill 8B 32 0.50 100.0% L11:16 L31:160
166164
```
167165

168-
100% of tokens converge by the last checkpoint. At lower thresholds, earlier exits
169-
appear — up to 10% of tokens exit at Layer 11, only 1/3 of the way through the model.
166+
100% of tokens exit early. 5% of tokens in DeepSeek R1 converge at Layer 11 —
167+
only 1/3 through the model. Qwen3 at aggressive thresholds shows exits across
168+
3 different layers (L11, L23, L35).
170169

171170
### Prefill Latency
172171

173-
Single prompt, 20 runs averaged:
172+
Single reasoning prompt, 20 runs averaged:
174173

175174
```
176-
Model Baseline TIDE (t=0.85) Change
177-
========================== ========== ============= ======
178-
Qwen3 8B (36 layers) 46.82ms 44.14ms -5.7%
179-
DeepSeek R1 Distill 8B 31.66ms 31.89ms +0.7%
175+
Model Configuration Latency vs Baseline
176+
===================== ==================== ========= ===========
177+
DeepSeek R1 Distill 8B Baseline (no TIDE) 39.08ms --
178+
DeepSeek R1 Distill 8B TIDE (threshold=0.85) 36.94ms -5.5%
179+
DeepSeek R1 Distill 8B TIDE (threshold=0.50) 36.26ms -7.2%
180+
Qwen3 8B Baseline (no TIDE) 46.82ms --
181+
Qwen3 8B TIDE (threshold=0.85) 44.14ms -5.7%
180182
```
181183

182-
### Batch Throughput
184+
### Throughput
183185

184186
```
185-
Model BS Baseline (tok/s) TIDE (tok/s) Change
186-
========================== == ================ ============ ======
187-
Qwen3 8B 1 258 271 +5.0%
188-
Qwen3 8B 4 923 961 +4.2%
189-
Qwen3 8B 8 1,781 1,926 +8.1%
190-
DeepSeek R1 Distill 8B 1 403 403 +0.0%
191-
DeepSeek R1 Distill 8B 8 2,997 2,833 -5.5%
187+
Model BS Baseline (tok/s) TIDE (tok/s) Change
188+
===================== == ================ ============ ======
189+
DeepSeek R1 Distill 8B 1 973 1,037 +6.5%
190+
Qwen3 8B 1 258 271 +5.0%
191+
Qwen3 8B 4 923 961 +4.2%
192+
Qwen3 8B 8 1,781 1,926 +8.1%
192193
```
193194

194-
Qwen3 (36 layers) shows consistent improvement. DeepSeek R1 Distill (32 layers,
195-
already optimized via distillation) has minimal headroom.
195+
### Reasoning Generation Quality
196196

197-
### Generation Quality
198-
199-
100 tokens, `temperature=0`, same prompt across thresholds:
197+
DeepSeek R1 Distill 8B solving a math word problem, 256 tokens, `temperature=0`:
200198

201199
```
202-
Model Threshold Exit Rate Output
203-
===================== ========= ========= ==============================
204-
DeepSeek R1 Distill 8B 1.0 (off) 0% "Transformers are a type of
205-
neural network architecture
206-
that uses self-attention
207-
mechanisms to capture long-
208-
range dependencies..."
209-
210-
DeepSeek R1 Distill 8B 0.85 100% "Transformers are neural
211-
networks that use self-
212-
attention mechanisms to
213-
process sequential data.
214-
They are particularly
215-
effective for tasks like
216-
machine translation..."
217-
218-
Qwen3 8B 1.0 (off) 0% "...the basic principles,
219-
the role of the core, the
220-
function of the windings,
221-
and the importance of the
222-
magnetic field..."
223-
224-
Qwen3 8B 0.85 100% Quality degrades at 100%
225-
exit rate (10 unique tokens).
226-
Use threshold >= 0.90 for
227-
Qwen3 to preserve quality.
200+
Threshold Exit Rate Unique Tokens Quality
201+
========= ========= ============= ======================================
202+
1.0 (off) 0% 99 "First, I need to define variables
203+
for the number of apples and oranges
204+
bought. Let's let a represent the
205+
number of apples..."
206+
207+
0.85 98.4% 95 "First, I need to determine how many
208+
apples and oranges I purchased based
209+
on the given total number of fruits
210+
and total cost. Let..."
211+
212+
0.70 99.2% 95 (same as 0.85 — stable)
213+
214+
0.50 99.6% 95 (same — output is robust)
228215
```
229216

230-
**Key finding**: DeepSeek R1 Distill maintains quality at 100% exit rate.
231-
Qwen3 is more sensitive — use a higher threshold (0.90+) to preserve output quality.
217+
**98-99% of decode tokens exit early** while maintaining 95+ unique tokens and
218+
coherent step-by-step reasoning. The model correctly sets up the system of
219+
equations in all cases.
232220

233221
### Convergence Analysis
234222

235-
Calibrated on 2000 WikiText samples, cosine similarity > 0.98 with final layer:
236-
237223
```
238-
Model Layers Convergence per Checkpoint
239-
======================= ====== =======================================
240-
Qwen3 8B 36 L3-L31: 0% L35: 100%
241-
DeepSeek R1 Distill 8B 32 L3-L27: 0% L31: 100%
242-
LLaMA 3.1 8B 32 L3-L27: 0% L31: 100%
243-
GPT-2 (124M) 12 L3: 0% L7: 0% L11: 100%
224+
Model Layers Tokens Analyzed Last-Layer Convergence
225+
===================== ====== ============== ======================
226+
DeepSeek R1 Distill 8B 32 339,853 L31: 100%
227+
Qwen3 8B 36 314,530 L35: 100%
228+
GPT-2 (124M) 12 78,843 L11: 100%
244229
```
245230

246-
The strict threshold (0.98) means most tokens converge at the penultimate checkpoint.
247-
Lower `convergence_threshold` during calibration (e.g., 0.95) enables earlier exits.
231+
Every model shows 100% convergence at the penultimate checkpoint — the last
232+
few layers contribute negligible change to the hidden state for most tokens.
248233

249234
## Tuning the Threshold
250235

modal_setup/benchmark_reasoning.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
"""Benchmark TIDE on reasoning models with tuned convergence thresholds."""
2+
3+
import modal
4+
from modal_setup.image import build_tide_image
5+
from modal_setup.volumes import VOLUME_MOUNTS
6+
7+
app = modal.App("TIDE-bench-reasoning")
8+
tide_image = build_tide_image(include_bench_deps=False)
9+
10+
11+
@app.function(image=tide_image, gpu="A100", volumes=VOLUME_MOUNTS, timeout=7200)
12+
def benchmark():
13+
import time, torch
14+
from transformers import AutoModelForCausalLM, AutoTokenizer
15+
from TIDE import TIDE as TIDERuntime, TIDEConfig, calibrate
16+
17+
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
18+
19+
print(f"Loading {model_name}...")
20+
model = AutoModelForCausalLM.from_pretrained(
21+
model_name, torch_dtype=torch.bfloat16, device_map="auto",
22+
cache_dir="/root/models",
23+
)
24+
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/root/models")
25+
if tokenizer.pad_token is None:
26+
tokenizer.pad_token = tokenizer.eos_token
27+
28+
n_layers = model.config.num_hidden_layers
29+
print(f" {n_layers} layers, hidden={model.config.hidden_size}")
30+
print(f" GPU: {torch.cuda.get_device_name()}")
31+
32+
# Reasoning prompts that trigger long chain-of-thought
33+
reasoning_prompts = [
34+
"Solve step by step: If a train travels at 60 mph for 2.5 hours, then at 80 mph for 1.5 hours, what is the total distance?",
35+
"Think through this carefully: What is 17 * 23 + 45 - 12 * 3?",
36+
"Reason about this: A farmer has 3 fields. Field A produces 2x wheat as Field B. Field C produces half of Field A. If total wheat is 900 tons, how much does each field produce?",
37+
"Solve: In a class of 40 students, 25 play football, 20 play basketball, and 10 play both. How many play neither?",
38+
"Step by step: What is the derivative of f(x) = 3x^4 - 2x^3 + 5x - 7?",
39+
"Think carefully: If you have a 3-gallon jug and a 5-gallon jug, how do you measure exactly 4 gallons?",
40+
"Reason through: A sequence starts 2, 6, 18, 54. What is the 8th term?",
41+
"Solve step by step: Two cars start 300 miles apart driving toward each other at 50 mph and 70 mph. When do they meet?",
42+
]
43+
44+
general_prompts = [
45+
"Explain the theory of general relativity in simple terms.",
46+
"Write a Python function that implements quicksort.",
47+
"What are the main causes of climate change?",
48+
"Describe how neural networks learn through backpropagation.",
49+
"Compare TCP and UDP protocols.",
50+
"What is the significance of the Pythagorean theorem?",
51+
"How does encryption work to protect data?",
52+
"Explain supply and demand in economics.",
53+
]
54+
55+
all_prompts = reasoning_prompts + general_prompts
56+
57+
# ==== Test different convergence thresholds during calibration ====
58+
print(f"\n{'='*70}")
59+
print("EXPERIMENT: Convergence Threshold Impact on Exit Distribution")
60+
print(f"{'='*70}")
61+
62+
for conv_thresh in [0.98, 0.95, 0.90, 0.85]:
63+
safe = model_name.replace("/", "_")
64+
rpath = f"/tmp/router_conv{conv_thresh}.pt"
65+
66+
print(f"\n--- Calibration with convergence_threshold={conv_thresh} ---")
67+
cfg = TIDEConfig(
68+
calibration_samples=2000,
69+
checkpoint_interval=4,
70+
convergence_threshold=conv_thresh,
71+
)
72+
t0 = time.time()
73+
ckpt = calibrate(model, tokenizer, config=cfg, save_path=rpath)
74+
print(f" Calibrated in {time.time()-t0:.0f}s")
75+
76+
# Prefill exit rates at threshold=0.85
77+
print(f"\n Prefill exits (exit_threshold=0.85, 16 prompts):")
78+
print(f" {'Threshold':>10} {'Exit%':>7} {'Layer Distribution':>45}")
79+
80+
for exit_thresh in [0.90, 0.85, 0.70, 0.50]:
81+
engine = TIDERuntime(model, rpath,
82+
config=TIDEConfig(exit_threshold=exit_thresh, min_layers=8))
83+
tot, exited, layers = 0, 0, {}
84+
for p in all_prompts:
85+
inp = tokenizer(p, return_tensors="pt", truncation=True, max_length=512).to(model.device)
86+
engine(inp.input_ids, attention_mask=inp.attention_mask)
87+
s = engine.last_stats
88+
tot += s.total_tokens
89+
exited += s.total_exited
90+
for l, c in s.exits_per_layer.items():
91+
layers[l] = layers.get(l, 0) + c
92+
93+
rate = exited / tot if tot > 0 else 0
94+
ldist = " ".join(f"L{l}:{c}" for l, c in sorted(layers.items()))
95+
print(f" {exit_thresh:>10.2f} {rate:>6.1%} {ldist:>45}")
96+
97+
# ==== Best config: conv=0.90, sweep exit thresholds ====
98+
best_conv = 0.90
99+
best_rpath = f"/tmp/router_conv{best_conv}.pt"
100+
101+
print(f"\n{'='*70}")
102+
print(f"BENCHMARK: DeepSeek R1 Distill 8B (conv={best_conv})")
103+
print(f"{'='*70}")
104+
105+
# Latency
106+
print(f"\n--- Prefill Latency ---")
107+
test_inp = tokenizer(reasoning_prompts[0], return_tensors="pt", truncation=True, max_length=512).to(model.device)
108+
109+
for _ in range(3):
110+
model(test_inp.input_ids, attention_mask=test_inp.attention_mask)
111+
torch.cuda.synchronize()
112+
t0 = time.perf_counter()
113+
for _ in range(20):
114+
model(test_inp.input_ids, attention_mask=test_inp.attention_mask)
115+
torch.cuda.synchronize()
116+
baseline_ms = (time.perf_counter() - t0) / 20 * 1000
117+
print(f" Baseline: {baseline_ms:.2f}ms")
118+
119+
for et in [0.85, 0.70, 0.50]:
120+
engine = TIDERuntime(model, best_rpath,
121+
config=TIDEConfig(exit_threshold=et, min_layers=8))
122+
for _ in range(3):
123+
engine(test_inp.input_ids, attention_mask=test_inp.attention_mask)
124+
torch.cuda.synchronize()
125+
t0 = time.perf_counter()
126+
for _ in range(20):
127+
engine(test_inp.input_ids, attention_mask=test_inp.attention_mask)
128+
torch.cuda.synchronize()
129+
tide_ms = (time.perf_counter() - t0) / 20 * 1000
130+
overhead = (tide_ms - baseline_ms) / baseline_ms * 100
131+
er = engine.last_stats.exit_rate
132+
print(f" TIDE (t={et}): {tide_ms:.2f}ms ({overhead:+.1f}%, exit={er:.0%})")
133+
134+
# Throughput
135+
print(f"\n--- Batch Throughput ---")
136+
for bs in [1, 4, 8]:
137+
batch = (all_prompts * 4)[:bs]
138+
binp = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=256).to(model.device)
139+
ntok = binp.input_ids.numel()
140+
141+
for _ in range(3):
142+
model(**binp)
143+
torch.cuda.synchronize()
144+
t0 = time.perf_counter()
145+
for _ in range(10):
146+
model(**binp)
147+
torch.cuda.synchronize()
148+
base_tps = ntok * 10 / (time.perf_counter() - t0)
149+
150+
engine = TIDERuntime(model, best_rpath,
151+
config=TIDEConfig(exit_threshold=0.85, min_layers=8))
152+
for _ in range(3):
153+
engine(binp.input_ids, attention_mask=binp.attention_mask)
154+
torch.cuda.synchronize()
155+
t0 = time.perf_counter()
156+
for _ in range(10):
157+
engine(binp.input_ids, attention_mask=binp.attention_mask)
158+
torch.cuda.synchronize()
159+
tide_tps = ntok * 10 / (time.perf_counter() - t0)
160+
imp = (tide_tps - base_tps) / base_tps * 100
161+
er = engine.last_stats.exit_rate
162+
print(f" BS={bs}: baseline={base_tps:,.0f} t/s, TIDE={tide_tps:,.0f} t/s ({imp:+.1f}%, exit={er:.0%})")
163+
164+
# Generation with 256 tokens
165+
print(f"\n--- Generation Quality (256 tokens, temp=0) ---")
166+
gen_prompt = "Solve step by step: A store sells apples for $2 each and oranges for $3 each. If I buy a total of 10 fruits and spend $24, how many of each did I buy?"
167+
gen_inp = tokenizer(gen_prompt, return_tensors="pt").to(model.device)
168+
169+
for et in [1.0, 0.85, 0.70, 0.50]:
170+
engine = TIDERuntime(model, best_rpath,
171+
config=TIDEConfig(exit_threshold=et, min_layers=8))
172+
torch.cuda.synchronize()
173+
t0 = time.perf_counter()
174+
out = engine.generate(gen_inp.input_ids, max_new_tokens=256, temperature=0)
175+
torch.cuda.synchronize()
176+
gen_time = time.perf_counter() - t0
177+
178+
text = tokenizer.decode(out[0], skip_special_tokens=True)
179+
stats = engine.last_stats
180+
gen_ids = out[0][gen_inp.input_ids.shape[1]:]
181+
unique = len(set(gen_ids.tolist()))
182+
label = "baseline" if et == 1.0 else f"t={et}"
183+
print(f"\n [{label}] {gen_time:.1f}s, {len(gen_ids)} tokens, exit={stats.exit_rate:.0%}, unique={unique}")
184+
if stats.exits_per_layer:
185+
print(f" Exits: {dict(sorted(stats.exits_per_layer.items()))}")
186+
print(f" Output: {text[:250]}")
187+
188+
print(f"\n{'='*70}")
189+
print("DONE")
190+
print(f"{'='*70}")
191+
192+
193+
@app.local_entrypoint()
194+
def main():
195+
benchmark.remote()

0 commit comments

Comments
 (0)