This repository accompanies the paper TokenButler: Token Importance is Predictable. It contains the code and reproduction
scripts for the long-context decoding experiments in the paper. The Llama-3.1-8B-Instruct predictor is
released on Hugging Face at
Alwahsh/Meta-Llama-3.1-8B-Instruct-Butler
and is fetched by scripts/download_weights.sh (see Installation).
TokenButler is a lightweight (≈1% parameter overhead) query-aware predictor that identifies the small subset of KV-cache tokens needed by each decoding step. By preserving the full KV cache and selecting tokens at fine granularity, TokenButler avoids the failure modes of permanent eviction (H2O, SnapKV) and coarse paging (Quest), and matches or surpasses prior KV-sparsity methods on long-context benchmarks. With prediction-interval and neighbor fetching, the predictor cost is amortized to deliver up to ≈1.6× on-GPU speedup over Dense Attention while staying within ≈1.1% accuracy of every-step prediction, and up to ≈7.6× speedup over Dense under CPU offloading at 1M context.
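The selection idea in the paragraph above can be illustrated with a minimal, dependency-free sketch. It is not the repository's implementation: the function names `select_tokens` and `sparse_attention` are hypothetical, and the importance scores are passed in as plain numbers, whereas in TokenButler they come from the learned predictor. The point it shows is that the full cache stays intact and only the per-step selection changes.

```python
import heapq
import math


def select_tokens(scores, budget):
    """Return the indices of the `budget` highest-scoring cached tokens, in position order."""
    top = heapq.nlargest(budget, range(len(scores)), key=scores.__getitem__)
    return sorted(top)


def sparse_attention(query, keys, values, selected):
    """Softmax attention restricted to the selected token indices.

    The full `keys` / `values` lists are never modified, so a token skipped at
    this step can still be selected at a later step (no permanent eviction).
    """
    scale = 1.0 / math.sqrt(len(query))
    logits = [scale * sum(q * k for q, k in zip(query, keys[i])) for i in selected]
    peak = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(logit - peak) for logit in logits]
    total = sum(weights)
    dim = len(values[0])
    return [
        sum(w * values[i][d] for w, i in zip(weights, selected)) / total
        for d in range(dim)
    ]


# Hypothetical predicted importance scores for a 4-token cache.
chosen = select_tokens([0.1, 0.9, 0.05, 0.7], budget=2)
```

In the experiments described here the budget is 8192 tokens out of contexts of 32K and longer, so the attention cost per step depends on the budget rather than on the full context length.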
| Paper artifact | Script |
|---|---|
| Table 4 — RULER 64K accuracy on Llama-3.1-8B-Instruct (budget 8K, 10 tasks) | examples/tokenbutler_accuracy.sh |
| Table 6 — RULER 64K predict-interval ablation (budget 8K, i ∈ {1,2,4,8,16} with neighbor fetch for i>1) | examples/tokenbutler_accuracy_intervals.sh |
| Figure 4 — decode latency vs context (GPU 32K–128K + CPU offload 256K–1M; Dense, TokenButler intervals, Oracle) | examples/tokenbutler_efficiency.sh then test/plot_decoding_efficiency.py |
| Per-operation timing breakdown (QKV / RoPE / Attention / Predictor / Scoring / Selection / KV gather) | test/benchmark_combined_figure.py then test/plot_timing_breakdown.py |
| All of the above end-to-end | examples/reproduce_all.sh |
Per-stage instructions are in Running stages individually; the one-shot path is in One-shot reproduction.
- Python 3.11
- CUDA-capable GPU. The paper used a single NVIDIA A6000 (48 GB). 64K-context RULER evaluation comfortably fits in 48 GB; 128K-context efficiency runs need similar memory; ≥256K runs use CPU offloading and require sufficient host RAM (≈64 GB at 1M).
- Access to `meta-llama/Meta-Llama-3.1-8B-Instruct` on Hugging Face (gated; `huggingface-cli login` once, then it is cached).
- For the CUDA kernels: NVCC matching your PyTorch build.
# 1. Create the environment (uv recommended, but plain venv also works)
uv venv --python 3.11
source .venv/bin/activate
uv pip install --upgrade pip
# 2. Install Python dependencies
uv pip install -r requirements.txt
uv pip install flash-attn==2.7.4.post1 --no-build-isolation
# 3. Pull CUTLASS (required to build the CUDA kernels)
mkdir -p 3rdparty
git clone https://github.com/NVIDIA/cutlass.git 3rdparty/cutlass
# 4. Build the CUDA kernels in-place
python setup.py build_ext --inplace
# 5. Download the Llama-3.1-8B-Instruct TokenButler predictor weights
bash scripts/download_weights.sh

The download step fetches L3_8Bi_d16_i512_pf4.pt from Hugging Face
and places it at the repo root, where every script expects it. The predictor uses a producer every G=4 layers,
interaction dimension d′=16, and a two-layer MLP with hidden size 512
(matching Appendix A in the paper).
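The checkpoint name appears to encode these hyperparameters (`d16`, `i512`, `pf4`), which line up with the `--dDash`, `--tokenbutler_intermediate_dim`, and `--producer_frequency` flags of `test/eval_acc.py`. The sketch below parses the name under that assumption; the naming scheme is inferred from this one file, and `parse_predictor_name` is a hypothetical helper, not part of the repository.

```python
import re


def parse_predictor_name(filename):
    """Extract predictor hyperparameters from a name like L3_8Bi_d16_i512_pf4.pt.

    Assumption: `d` is the interaction dimension, `i` the MLP hidden size, and
    `pf` the producer frequency. This mapping is inferred, not documented.
    """
    match = re.search(r"_d(\d+)_i(\d+)_pf(\d+)\.pt$", filename)
    if match is None:
        raise ValueError(f"unrecognized predictor filename: {filename}")
    d_dash, hidden, frequency = (int(group) for group in match.groups())
    return {
        "dDash": d_dash,
        "tokenbutler_intermediate_dim": hidden,
        "producer_frequency": frequency,
    }


config = parse_predictor_name("L3_8Bi_d16_i512_pf4.pt")
```

A check like this can catch a mismatch between a checkpoint and the flags passed on the command line before a long evaluation starts.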
To run the entire reproducible portion of the paper end-to-end:
bash examples/reproduce_all.sh

Stages and outputs:
| # | Stage | Output |
|---|---|---|
| 1 | Build RULER benchmark data | data/ruler/data/llama-3/{32768,65536,131072,262144}/... |
| 2 | RULER accuracy at 64K, budget 8K | archive/Meta-Llama-3.1-8B-Instruct/ruler/*.jsonl |
| 3 | RULER accuracy with predict-interval i ∈ {1,2,4,8,16}, budget 8K | archive/Meta-Llama-3.1-8B-Instruct/ruler/*.jsonl |
| 4 | Decoding-efficiency sweep, 32K → 1M | test/output/efficiency_budget8K_1M/decoding_time_vs_context.csv |
| 5 | Per-operation timing breakdown | paper_plots/combined_timing.csv, paper_plots/combined_timing_figure.{pdf,png} |
| 6 | Render figures | paper_plots/decoding_performance.pdf (GPU), paper_plots/decoding_performance_cpu.pdf (CPU offload), paper_plots/timing_breakdown.pdf |
Each accuracy JSONL stores per-sample predictions; the final line records
avg_score, the metric reported in the paper. The efficiency CSV is appended
to incrementally so a re-run picks up where it left off — already-completed
(label, context_length) pairs are skipped.
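To pull the reported metric out of an accuracy file, a small reader along these lines may be enough. It assumes the final non-empty line is a JSON object with an `avg_score` key, as described above; the per-sample field names in the demo (`pred`, `score`) and the values are placeholders, not output from the repository or results from the paper.

```python
import json
import os
import tempfile


def read_avg_score(jsonl_path):
    """Return avg_score from the last non-empty line of a results JSONL file."""
    last = None
    with open(jsonl_path, encoding="utf-8") as handle:
        for line in handle:
            if line.strip():
                last = line
    if last is None:
        raise ValueError(f"{jsonl_path} is empty")
    return json.loads(last)["avg_score"]


# Demo on a throwaway file with placeholder records.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.jsonl")
    with open(demo_path, "w", encoding="utf-8") as handle:
        handle.write(json.dumps({"pred": "a", "score": 1.0}) + "\n")
        handle.write(json.dumps({"pred": "b", "score": 0.0}) + "\n")
        handle.write(json.dumps({"avg_score": 50.0}) + "\n")
    demo_score = read_avg_score(demo_path)
```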
bash examples/build_ruler_data.sh

Generates the synthetic RULER samples (NIAH variants, VT, FWE, QA-1, QA-2) at
sequence lengths 32K, 64K, 128K, and 256K under
data/ruler/data/llama-3/. This is required before any accuracy run.
bash examples/tokenbutler_accuracy.sh

Evaluates Llama-3.1-8B-Instruct + TokenButler at 64K context with sparse budget
8192 across RULER tasks (niah_single_{1,2}, niah_multikey_{1,2},
niah_multiquery, niah_multivalue, vt, fwe, qa_1, qa_2). Results
land at
archive/Meta-Llama-3.1-8B-Instruct/ruler/<task>_65536_TokenButler_b8192_*.jsonl.
bash examples/tokenbutler_accuracy_intervals.sh

Sweeps the prediction interval i ∈ {1, 2, 4, 8, 16} at 64K context with sparse budget 8192. i=1 runs the predictor every step (the no-amortization baseline); i>1 enables neighbor fetching to amortize the predictor across i decode steps with minimal accuracy loss.
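The amortization behind the prediction interval can be sketched as follows. This is a simplified model, not the repository's scheduling code: it assumes the predictor runs on decode step 0 and then on every step that is a multiple of i, with the most recent selection reused in between, and it ignores any extra cost from neighbor fetching. Both function names are hypothetical.

```python
def predictor_steps(num_decode_steps, predict_interval):
    """Return the 0-indexed decode steps on which the predictor is invoked."""
    if predict_interval < 1:
        raise ValueError("predict_interval must be at least 1")
    return [step for step in range(num_decode_steps) if step % predict_interval == 0]


def amortized_predictor_cost(predictor_ms, predict_interval):
    """Average predictor cost per decode step when it runs once every i steps."""
    return predictor_ms / predict_interval


# With i=4, a 16-step decode calls the predictor 4 times instead of 16.
calls = predictor_steps(16, 4)
```

Under this model the predictor's share of per-step latency shrinks in proportion to 1/i, while the selection in use can be up to i-1 steps old, which is the staleness that neighbor fetching is meant to offset.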
A more detailed variant of this ablation, which adds paired efficiency measurements and a combined accuracy/latency plot at the same 64K / sparse-budget-8192 settings and extends the interval set to i ∈ {1, 2, 4, 8, 16, 32}, is also available:
python test/ablation_predict_interval.py # full run, all 11 RULER tasks at 64K
python test/ablation_predict_interval.py --quick # 15 samples, 4 datasets (sanity check)

bash examples/tokenbutler_efficiency.sh

Measures per-token decode latency on Llama-3.1-8B-Instruct at sparse budget 8192, across:
- GPU: contexts 32K, 64K, 128K (gen_length=1024)
- CPU offload: contexts 256K, 512K, 1M (gen_length=128)
For each context, it benchmarks Dense, TokenButler with i ∈ {1, 2, 4, 8, 16}
(plus neighbor fetch for i>1), and Oracle baselines (random / contiguous /
random-with-i=16) as zero-overhead lower bounds. Each configuration runs in an
isolated subprocess — see test/run_missing_configs.py — so OOMs or timeouts
do not poison the rest of the sweep. Results stream into
test/output/efficiency_budget8K_1M/decoding_time_vs_context.csv.
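The resume behavior (skipping pairs already present in the CSV) can be approximated with a short sketch. It assumes the CSV has header columns named `label` and `context_length`; the `latency_ms` column, the label strings, and the numbers in the demo are placeholders chosen for illustration, and the real orchestrator in test/run_missing_configs.py may work differently.

```python
import csv
import io


def completed_pairs(csv_file):
    """Collect the (label, context_length) pairs already recorded in the CSV."""
    reader = csv.DictReader(csv_file)
    return {(row["label"], int(row["context_length"])) for row in reader}


def pending_configs(all_configs, done):
    """Return the configurations that still need to run, preserving order."""
    return [config for config in all_configs if config not in done]


# Placeholder CSV contents standing in for a partially completed sweep.
demo_csv = io.StringIO(
    "label,context_length,latency_ms\n"
    "Dense,32768,1.0\n"
    "TokenButler_i4,32768,1.0\n"
)
done = completed_pairs(demo_csv)
todo = pending_configs(
    [("Dense", 32768), ("TokenButler_i4", 32768), ("Dense", 65536)],
    done,
)
```

Because each finished configuration is appended as its own row, a crash or OOM in one run costs only that run when the sweep is restarted.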
To plot the GPU and CPU-offload curves separately:
python test/plot_decoding_efficiency.py \
--csv test/output/efficiency_budget8K_1M/decoding_time_vs_context.csv \
--output-dir paper_plots

python test/benchmark_combined_figure.py --predictor-path L3_8Bi_d16_i512_pf4.pt
python test/plot_timing_breakdown.py --csv paper_plots/combined_timing.csv \
--output paper_plots/timing_breakdown.pdf

benchmark_combined_figure.py profiles every component (QKV proj, RoPE,
flash-attn, MLP, predictor forward, score computation, top-K selection,
key/value gather) at multiple K values, writes a tidy CSV, and emits a combined
figure. plot_timing_breakdown.py re-renders the stacked-area figure used in
the paper from that CSV.
For quick experimentation on one dataset:
python test/eval_acc.py \
--model_name meta-llama/Meta-Llama-3.1-8B-Instruct \
--method TokenButler \
--datalen 65536 \
--dataset_name ruler/niah_single_1 \
--sparse_budget 8192 \
--chunk_size 8 \
--rank 160 \
--predictor_path L3_8Bi_d16_i512_pf4.pt \
--dDash 16 \
--producer_frequency 4 \
--tokenbutler_intermediate_dim 512

Useful flags for the predict-interval optimization:
- `--predict_interval N`: invoke the predictor every N decode steps (default 1).
- `--enable_neighbor_fetch`: when `predict_interval > 1`, also fetch neighbors of selected tokens to mitigate stale selections.
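One way to picture neighbor fetching is as widening each selected index into a small window of adjacent positions. The sketch below is an assumption-laden illustration only: the symmetric window, the `radius` parameter, and the function name `expand_with_neighbors` are all hypothetical, and the repository may define neighbors differently.

```python
def expand_with_neighbors(selected, cache_len, radius=1):
    """Add tokens within `radius` positions of each selected index.

    Indices are clamped to [0, cache_len - 1] and de-duplicated, so windows
    that overlap do not fetch the same token twice.
    """
    expanded = set()
    for index in selected:
        low = max(0, index - radius)
        high = min(cache_len - 1, index + radius)
        expanded.update(range(low, high + 1))
    return sorted(expanded)


# Three selected tokens in an 8-token cache grow to six fetched positions.
fetched = expand_with_neighbors([0, 5, 6], cache_len=8)
```

The intuition is that when a selection is reused for several steps, the tokens that matter tend to sit close to those chosen earlier, so fetching a few neighbors can cover for a slightly stale prediction.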
.
├── models/
│ ├── tokenbutler_predictor.py # the lightweight predictor (paper §3.1)
│ ├── kv_cache_tokenbutler.py # GPU-resident sparse KV cache
│ ├── kv_cache_tokenbutler_cpu.py # CPU-offload sparse KV cache
│ ├── kv_cache_oracle{,_cpu}.py # Oracle baselines for the efficiency study
│ ├── kv_cache_dsa.py # DeepSeek-V3.2-style indexer baseline
│ ├── kv_cache.py # Dense + ShadowKV references
│ ├── llama.py / glm.py / qwen.py / phi3.py
│ └── base.py # LLM wrapper + cache factory
├── kernels/
│ ├── int8_score_fused.py
│ └── *.cu # CUTLASS-backed prefill kernels
├── data/
│ ├── dataset.py # unified dataset interface
│ ├── metrics.py
│ └── ruler/ # RULER generation (adapted from RULER repo)
├── test/
│ ├── eval_acc.py # accuracy evaluation entry point
│ ├── evaluator.py
│ ├── benchmark_tokenbutler.py # decode-latency micro-benchmark
│ ├── run_missing_configs.py # efficiency-sweep orchestrator (OOM-safe)
│ ├── benchmark_combined_figure.py # per-operation timing breakdown
│ ├── plot_decoding_efficiency.py # Figure 4 renderer
│ ├── plot_timing_breakdown.py # timing-breakdown renderer
│ ├── ablation_predict_interval.py # Table-6-style ablation with paired plot
│ └── benchmark_accuracy.py # automated comparison driver (optional)
├── examples/
│ ├── reproduce_all.sh
│ ├── build_ruler_data.sh
│ ├── tokenbutler_accuracy.sh
│ ├── tokenbutler_accuracy_intervals.sh
│ └── tokenbutler_efficiency.sh
├── scripts/
│ └── download_weights.sh
├── 3rdparty/cutlass/ # populated by `git clone` step above
├── requirements.txt
├── setup.py # builds the CUTLASS-backed CUDA kernels
└── TokenButler.pdf
The released predictor (downloaded from
Alwahsh/Meta-Llama-3.1-8B-Instruct-Butler)
targets Llama-3.1-8B-Instruct and reproduces the Llama rows of Tables 4
and 6 and the entirety of Figure 4. Architecture support exists in
models/{glm,qwen,phi3}.py for evaluating with predictors trained on those
backbones.
If you use TokenButler in your work, please cite the paper:
@misc{akhauri2025tokenbutlertokenimportancepredictable,
title={TokenButler: Token Importance is Predictable},
author={Yash Akhauri and Ahmed F AbouElhamayed and Yifei Gao and Chi-Chih Chang and Nilesh Jain and Mohamed S. Abdelfattah},
year={2025},
eprint={2503.07518},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2503.07518},
}

Apache 2.0. See LICENSE.
This repository builds directly on:
- xKV (Chang et al., 2025): provided the starting point for this codebase.
- RULER (Hsieh et al., 2024): synthetic long-context benchmark; the generation code under data/ruler/ is adapted from the original repository.
The TokenButler-specific contributions in this release — the predictor module, the GPU and CPU-offload sparse KV caches, and the prediction-interval / neighbor-fetching scheme — are described in the accompanying paper.