
Commit 8eec8cc

blackwell + new models
1 parent 553ae47 commit 8eec8cc

4 files changed

Lines changed: 298 additions & 100 deletions

File tree

README.md

Lines changed: 79 additions & 66 deletions
@@ -108,12 +108,14 @@ TIDE auto-probes your model's architecture. No adapter code needed.

| Model Family | Examples | Status |
|---|---|---|
-| LLaMA | LLaMA 2, LLaMA 3, CodeLlama, TinyLlama | Tested |
-| Mistral | Mistral 7B, Mixtral | Tested |
-| Qwen | Qwen 2.5 series | Tested |
+| LLaMA | LLaMA 3.3, LLaMA 4 Scout/Maverick | Benchmarked |
+| DeepSeek | DeepSeek R1, R1 Distill 8B/32B/70B | Benchmarked |
+| Qwen | Qwen3 8B/32B, Qwen 2.5 | Benchmarked |
+| Mistral | Mistral Small 3.1, Mixtral | Supported |
+| Gemma | Gemma 3 12B/27B | Supported |
| GPT-2 | GPT-2, DistilGPT-2 | Tested |
| GPT-NeoX | Pythia, GPT-NeoX-20B | Supported |
-| Phi | Phi-2, Phi-3 | Supported |
+| Phi | Phi-3, Phi-4 | Supported |
| Falcon | Falcon 7B/40B | Supported |
| OPT | OPT-1.3B through OPT-30B | Supported |
| **Anything else** | Any `AutoModelForCausalLM` | Auto-probed |
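
The last row is the point of the table: models outside the listed families go through the auto-probe path. A minimal sketch of that flow, built around the `TIDE.TIDE(model, "router.pt")` entry point visible in the next hunk's context; the checkpoint ID and router filename here are placeholders, not a confirmed recipe:

```python
# Minimal sketch of the auto-probe path. The checkpoint ID and the
# "router.pt" filename are illustrative placeholders.
import torch
from transformers import AutoModelForCausalLM

import TIDE

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",              # any AutoModelForCausalLM checkpoint
    torch_dtype=torch.bfloat16,
)
engine = TIDE.TIDE(model, "router.pt")  # UniversalAdapter probes the layers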
@@ -130,108 +132,119 @@ engine = TIDE.TIDE(model, "router.pt") # UniversalAdapter handles it

GPU architecture is auto-detected at install time.

-| GPU | Status | Notes |
+| GPU | Arch | Status |
|---|---|---|
-| V100 | Supported | sm_70 |
-| T4 | Supported | sm_75, great for cost-efficient inference |
-| A100 | Supported | sm_80 |
-| A10G | Tested in CI | sm_86, Modal/AWS default |
-| L4 | Supported | sm_89 |
-| H100 | Supported | sm_90 |
+| V100 | sm_70 | Supported |
+| T4 | sm_75 | Supported |
+| A100 | sm_80 | Benchmarked |
+| A10G | sm_86 | Tested in CI |
+| L4 / L40S | sm_89 | Supported |
+| H100 / H200 | sm_90 | Supported |
+| B100 / B200 | sm_100 | Supported |
+| GB200 / GB300 | sm_120 | Supported (PTX fallback) |

Override: `TORCH_CUDA_ARCH_LIST="8.6" pip install .`

No GPU? TIDE works in pure PyTorch (CPU fallback, no CUDA kernels needed).

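For reference, what install-time detection amounts to, sketched on top of PyTorch's device query rather than TIDE's actual setup code; the function name is illustrative:

```python
# Sketch: how an install step might pick the sm_* target, mirroring the
# table above. Illustrative only; this is not TIDE's setup.py logic.
import os
import torch

def detect_cuda_arch() -> str:
    override = os.environ.get("TORCH_CUDA_ARCH_LIST")
    if override:
        return override                      # manual override, as documented
    if not torch.cuda.is_available():
        return ""                            # CPU fallback: no kernels built
    major, minor = torch.cuda.get_device_capability(0)
    return f"{major}.{minor}"                # e.g. "8.0" on A100, "9.0" on H100

print(detect_cuda_arch() or "no GPU: pure-PyTorch fallback")
```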
## Benchmark Results

-Tested on **LLaMA 3.1 8B Instruct** (32 layers, 4096 hidden) on NVIDIA A100-SXM4-40GB.
-Calibrated with 2000 WikiText samples. CUDA kernels compiled for sm_80.
+All benchmarks on **NVIDIA A100-SXM4-40GB**, bf16 precision, 2000 WikiText calibration samples.
+16 real text prompts (science, code, history, economics).

### Prefill Exit Rates

-16 real text prompts (science, code, history), evaluated at different thresholds:
-
```
-Threshold  Exit Rate  Where Exits Happen
-=========  =========  ==================
-0.95       98.9%      L11: 16 tokens, L31: 158 tokens
-0.90       100.0%     L11: 16 tokens, L31: 160 tokens
-0.85       100.0%     L11: 16 tokens, L31: 160 tokens
-0.70       100.0%     L11: 16 tokens, L31: 160 tokens
-0.50       100.0%     L11: 16 tokens, L31: 160 tokens
+Model                      Layers Threshold Exit Rate Where Exits Happen
+========================== ====== ========= ========= ==========================
+Qwen3 8B                       36      0.95    100.0% L35: 155 tokens
+Qwen3 8B                       36      0.85    100.0% L35: 155 tokens
+Qwen3 8B                       36      0.50    100.0% L11:11 L23:5 L35:139
+DeepSeek R1 Distill 8B         32      0.95    100.0% L31: 176 tokens
+DeepSeek R1 Distill 8B         32      0.85    100.0% L11:16 L31:160
+DeepSeek R1 Distill 8B         32      0.50    100.0% L11:16 L31:160
```

-100% of tokens converge by Layer 31 (the last checkpoint before the final layer).
-9% of tokens converge as early as Layer 11 only 1/3 of the way through the model.
+100% of tokens converge by the last checkpoint. At lower thresholds, earlier exits
+appear — up to 10% of tokens exit at Layer 11, only 1/3 of the way through the model.

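The "up to 10%" figure comes straight from the exit counts in the table; a quick check of the arithmetic, using only numbers shown above:

```python
# Arithmetic behind "up to 10% of tokens exit at Layer 11", taken from the
# DeepSeek R1 Distill 8B row at threshold 0.85 in the table above.
exits = {11: 16, 31: 160}          # layer -> tokens exiting there
total = sum(exits.values())        # 176 prefill tokens
print(f"L11 share: {exits[11] / total:.1%}")  # 9.1%, i.e. roughly 10%
```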
### Prefill Latency

Single prompt, 20 runs averaged:

```
-Configuration          Latency  vs Baseline
-====================== ======= ===========
-Baseline (no TIDE)     54.04ms --
-TIDE (threshold=0.95)  50.94ms -5.7%
-TIDE (threshold=0.85)  50.52ms -6.5%
-TIDE (threshold=0.50)  50.21ms -7.1%
+Model                      Baseline   TIDE (t=0.85) Change
+========================== ========== ============= ======
+Qwen3 8B (36 layers)       46.82ms    44.14ms        -5.7%
+DeepSeek R1 Distill 8B     31.66ms    31.89ms        +0.7%
```

-TIDE is **faster than baseline** even in frozen-token mode (all layers still run)
-because the router evaluation + early output selection avoids redundant final-layer
-normalization for exited tokens.
-
### Batch Throughput

```
-Batch Size Baseline (tok/s) TIDE (tok/s) Improvement
-========== ================ ============ ===========
-1          231              252          +9.1%
-4          834              902          +8.2%
-8          1,618            1,773        +9.6%
+Model                      BS Baseline (tok/s) TIDE (tok/s) Change
+========================== == ================ ============ ======
+Qwen3 8B                    1              258          271  +5.0%
+Qwen3 8B                    4              923          961  +4.2%
+Qwen3 8B                    8            1,781        1,926  +8.1%
+DeepSeek R1 Distill 8B      1              403          403  +0.0%
+DeepSeek R1 Distill 8B      8            2,997        2,833  -5.5%
```

+Qwen3 (36 layers) shows consistent improvement. DeepSeek R1 Distill (32 layers,
+already optimized via distillation) has minimal headroom.
+
### Generation Quality

-100 tokens generated with `temperature=0` on the same prompt:
+100 tokens, `temperature=0`, same prompt across thresholds:

```
-Threshold   Exit Rate  Output
-=========   =========  =============================================
-1.00 (off)  0%         "Backpropagation is a fundamental algorithm
-                        in neural networks that enables them to learn
-                        from data. Here's a step-by-step guide on
-                        how it works: 1. Forward pass: The input..."
-
-0.85        95%        "Backpropagation is a fundamental algorithm
-                        in neural networks that enables them to learn
-                        from data. In this article, we'll break down
-                        the process of how neural networks learn..."
-
-0.50        96%        (same as 0.85 — stable)
+Model                  Threshold Exit Rate Output
+====================== ========= ========= ==============================
+DeepSeek R1 Distill 8B 1.0 (off)        0% "Transformers are a type of
+                                            neural network architecture
+                                            that uses self-attention
+                                            mechanisms to capture long-
+                                            range dependencies..."
+
+DeepSeek R1 Distill 8B 0.85          100% "Transformers are neural
+                                            networks that use self-
+                                            attention mechanisms to
+                                            process sequential data.
+                                            They are particularly
+                                            effective for tasks like
+                                            machine translation..."
+
+Qwen3 8B               1.0 (off)        0% "...the basic principles,
+                                            the role of the core, the
+                                            function of the windings,
+                                            and the importance of the
+                                            magnetic field..."
+
+Qwen3 8B               0.85          100% Quality degrades at 100%
+                                            exit rate (10 unique tokens).
+                                            Use threshold >= 0.90 for
+                                            Qwen3 to preserve quality.
```

-95% of decode tokens exit at Layer 31 — the output diverges slightly in phrasing
-("Here's a step-by-step guide" vs "In this article, we'll break down") but
-remains equally coherent and factually correct.
+**Key finding**: DeepSeek R1 Distill maintains quality at 100% exit rate.
+Qwen3 is more sensitive — use a higher threshold (0.90+) to preserve output quality.

### Convergence Analysis

-Layer-by-layer convergence (cosine similarity > 0.98 with final layer):
+Calibrated on 2000 WikiText samples, cosine similarity > 0.98 with final layer:

```
-Model              Layers  Convergence per Checkpoint Layer
-=================  ======  ===========================================
-LLaMA 3.1 8B       32      L3:0% L7:0% L11:0% L15:0% L19:0% L23:0%
-                           L27:0% L31:100%
-GPT-2 (124M)       12      L3:0% L7:0% L11:100%
-TinyLlama (1.1B)   22      L3:0% L7:0% L11:0% L15:0% L19:0%
+Model                   Layers Convergence per Checkpoint
+======================= ====== =======================================
+Qwen3 8B                    36 L3-L31: 0%  L35: 100%
+DeepSeek R1 Distill 8B      32 L3-L27: 0%  L31: 100%
+LLaMA 3.1 8B                32 L3-L27: 0%  L31: 100%
+GPT-2 (124M)                12 L3: 0%  L7: 0%  L11: 100%
```

-The convergence threshold (0.98) is strict — most tokens converge at the last
-checkpoint. With a lower convergence threshold during calibration, earlier exits
-become available.
+The strict threshold (0.98) means most tokens converge only at the last checkpoint.
+Lower `convergence_threshold` during calibration (e.g., 0.95) enables earlier exits.

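The signal behind these numbers is plain cosine similarity between each checkpoint's hidden state and the final layer's. A standalone sketch of that measurement using Hugging Face's `output_hidden_states` (independent of TIDE's internal calibration code); the prompt text is arbitrary:

```python
# Sketch: per-layer convergence measurement (cosine similarity vs. the
# final layer), the signal behind the table above. Not TIDE internals.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "gpt2"  # small model from the table, for a quick check
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

ids = tok("TIDE measures layerwise convergence.", return_tensors="pt").input_ids
with torch.no_grad():
    # hidden_states: tuple of (embeddings + one tensor per layer)
    hs = model(ids, output_hidden_states=True).hidden_states

final = hs[-1]
for layer in (3, 7, 11):  # GPT-2 checkpoint layers from the table
    sim = torch.nn.functional.cosine_similarity(hs[layer], final, dim=-1)
    converged = (sim > 0.98).float().mean().item()
    print(f"L{layer}: {converged:.0%} of tokens converged")
```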
## Tuning the Threshold

Lines changed: 35 additions & 32 deletions
@@ -1,55 +1,58 @@
# Model registry for TIDE benchmarks
-# Organized by phase (implementation priority)

phase_1:
-  # Small/medium models for initial validation
-  - name: "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-    short: "tinyllama-1.1b"
-    gpu: "A10G"
-    dtype: "float16"
-  - name: "meta-llama/Llama-3.1-8B-Instruct"
-    short: "llama3.1-8b"
+  # Small/medium models for validation
+  - name: "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+    short: "llama4-scout-17b"
    gpu: "A100"
-    dtype: "float16"
-  - name: "mistralai/Mistral-7B-Instruct-v0.3"
-    short: "mistral-7b"
+    dtype: "bfloat16"
+  - name: "meta-llama/Llama-3.3-70B-Instruct"
+    short: "llama3.3-70b"
+    gpu: "H100:2"
+    dtype: "bfloat16"
+  - name: "Qwen/Qwen3-8B"
+    short: "qwen3-8b"
+    gpu: "A100"
+    dtype: "bfloat16"
+  - name: "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
+    short: "mistral-small-3.1"
    gpu: "A100"
-    dtype: "float16"
-  - name: "Qwen/Qwen2.5-7B-Instruct"
-    short: "qwen2.5-7b"
+    dtype: "bfloat16"
+  - name: "google/gemma-3-12b-it"
+    short: "gemma3-12b"
    gpu: "A100"
-    dtype: "float16"
+    dtype: "bfloat16"

phase_2:
-  # Medium models
-  - name: "meta-llama/Llama-3.1-70B-Instruct"
-    short: "llama3.1-70b"
-    gpu: "H100:2"
-    dtype: "float16"
-  - name: "Qwen/Qwen2.5-72B-Instruct"
-    short: "qwen2.5-72b"
+  # Large models
+  - name: "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
+    short: "llama4-maverick-17b"
    gpu: "H100:2"
-    dtype: "float16"
+    dtype: "bfloat16"
+  - name: "Qwen/Qwen3-32B"
+    short: "qwen3-32b"
+    gpu: "H100"
+    dtype: "bfloat16"
+  - name: "google/gemma-3-27b-it"
+    short: "gemma3-27b"
+    gpu: "H100"
+    dtype: "bfloat16"

phase_3:
-  # Reasoning models (key TIDE targets)
+  # Reasoning models (key TIDE targets — long decode, many easy tokens)
  - name: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    short: "r1-distill-8b"
    gpu: "A100"
-    dtype: "float16"
+    dtype: "bfloat16"
  - name: "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
    short: "r1-distill-32b"
    gpu: "H100"
-    dtype: "float16"
+    dtype: "bfloat16"
  - name: "Qwen/QwQ-32B"
    short: "qwq-32b"
    gpu: "H100"
-    dtype: "float16"
+    dtype: "bfloat16"
  - name: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
    short: "r1-distill-70b"
    gpu: "H100:2"
-    dtype: "float16"
-  - name: "deepseek-ai/DeepSeek-R1"
-    short: "r1-671b"
-    gpu: "H100:4"
-    dtype: "float16"
+    dtype: "bfloat16"
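
Consuming this registry is a few lines of PyYAML. A sketch, assuming the file is saved as `models.yaml` (its actual path is not shown in this diff):

```python
# Sketch: walking the benchmark registry. The "models.yaml" filename is
# an assumption; this diff does not show the file's actual path.
import yaml

with open("models.yaml") as f:
    registry = yaml.safe_load(f)

for phase, models in registry.items():   # phase_1, phase_2, phase_3
    for m in models:
        print(f"{phase}: {m['short']} -> {m['gpu']} ({m['dtype']})")
```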
