LLM Reinforcement Learning in Pure CUDA — From Kernels to GRPO
Zero PyTorch. Full RL loop. 1.37x faster than TRL with vLLM.
Quick Start • Kernels • Inference Engine • Training • Benchmarks • Architecture • License • Claude Code
GRPO training on DeepMath-103K (Qwen3-0.6B, same GPU) — 1.37x faster wall-clock, matching reward
A from-scratch implementation of the complete LLM RL pipeline — hand-written CUDA kernels, a vLLM-style inference engine with continuous batching, and GRPO training.
| Layer | What |
|---|---|
| CUDA Kernels | FlashAttention-2, RMSNorm, RoPE, SwiGLU, Embedding, Sampler, AdamW, GRPO loss — all with forward AND backward passes |
| Model | Qwen3-0.6B full forward + backward pass, safetensors weight loading |
| KV Cache | Paged KV cache with block manager (same design as vLLM) |
| Inference Engine | Continuous batching, CUDA graph capture, two-phase scheduling |
| Training | SFT + GRPO with gradient checkpointing, mixed-precision AdamW |
- CUDA Toolkit >= 12.0
- CMake >= 3.18
- GPU: Ampere or newer (sm_80+)
- Python 3.8+ (only for downloading model weights / preparing data)
git clone https://github.com/KJLdefeated/RL.cu.git && cd RL.cu
# Fetch third-party headers and CUTLASS (see "Third-party dependencies" below)
mkdir -p include/third_party third_party
curl -L -o include/third_party/json.hpp \
https://github.com/nlohmann/json/releases/download/v3.11.3/json.hpp
curl -L -o include/third_party/xxhash.h \
https://raw.githubusercontent.com/Cyan4973/xxHash/v0.8.1/xxhash.h
git clone --depth 1 https://github.com/NVIDIA/cutlass.git third_party/cutlass
# Download model weights
pip install huggingface_hub
python scripts/download_model.py Qwen/Qwen3-0.6B model_weights/Qwen3-0.6B
# Build all targets (default: sm_120; override with ARCH=90 for H100, etc.)
make
# Or build a specific target
make build/test_llmengine
Third-party dependencies: these are header-only / source dependencies that aren't checked into the repo (they're in .gitignore). Drop them at the paths shown; that's where #include "third_party/..." and CMakeLists.txt expect to find them.
| Library | Version used | Path | Source |
|---|---|---|---|
| nlohmann/json | 3.11.3 | include/third_party/json.hpp | https://github.com/nlohmann/json |
| xxHash | 0.8.1 | include/third_party/xxhash.h | https://github.com/Cyan4973/xxHash |
| NVIDIA CUTLASS | main (header-only includes) | third_party/cutlass/ | https://github.com/NVIDIA/cutlass |
The curl/git clone commands in the Build block above fetch all three. CUTLASS is large (~250 MB); --depth 1 keeps the clone shallow.
| Command | Description |
|---|---|
| make | Configure + build all targets |
| make build/<target> | Build a specific binary |
| make test_rmsnorm | Build + run a single test |
| make tests | Build + run all kernel & model tests |
| make train_grpo | Build + run GRPO training (default args) |
| make train_sft | Build + run SFT training (default args) |
| make clean | Remove build directory |
| make ARCH=90 | Target a specific GPU arch (e.g. 80=A100, 89=RTX4090, 90=H100) |
| make BUILD_TYPE=Debug | Debug build |
# Correctness tests + throughput benchmark
./build/test_llmengine model_weights/Qwen3-0.6B
# Prepare dataset (downloads DeepMath-103K as JSONL)
pip install datasets
python scripts/prepare_data.py --mode grpo-text \
--dataset trl-lib/DeepMath-103K --output data/deepmath-103k.jsonl
# Run with default settings (8 prompts x 8 gens, 100 steps)
./build/train_grpo
# Or customize (see --help for all options)
./build/train_grpo \
--model model_weights/Qwen3-0.6B \
--data data/deepmath-103k.jsonl \
--batch-size 64 --num-gens 8 \
--lr 1e-6 --total-steps 500 \
--save-dir checkpoints/my_run
# Prepare dataset
python scripts/prepare_data.py --mode sft --output data/sft_train.bin
# Run with default settings
./build/train_sft
# Or customize
./build/train_sft \
--model model_weights/Qwen3-0.6B \
--data data/sft_train.bin \
--batch-size 8 --seq-len 2048 \
--lr 1e-5 --total-steps 5000
Every kernel is written with FP16 I/O and FP32 accumulation. Each has a standalone test comparing against a CPU/PyTorch reference.
| Kernel | File | Lines | Key Design |
|---|---|---|---|
| FlashAttention-2 | src/kernels/attention.cu | 1,367 | Forward + backward, GQA, online softmax, tiled WMMA, causal masking |
| RMSNorm | src/kernels/rmsnorm.cu | 211 | Forward + backward, warp-level reduction |
| RoPE | src/kernels/rope.cu | 160 | Forward + backward, NeoX split-half, FP32 sin/cos tables |
| SwiGLU | src/kernels/swiglu.cu | 77 | Forward + backward, fused SiLU * up |
| Embedding | src/kernels/embedding.cu | 86 | Forward gather + backward scatter-add |
| Fused Norm+Linear | src/kernels/fused_norm_linear.cu | 188 | RMSNorm fused with cuBLAS GEMM (saves one HBM round-trip) |
| Sampler | src/kernels/sampler.cu | 147 | Top-k, top-p, temperature, Gumbel-max single-pass |
| AdamW | src/kernels/adamw.cu | 108 | Mixed-precision (FP16 params, FP32 moments), fused update |
| Softmax | src/kernels/softmax.cu | 115 | Numerically stable, warp-level |
| Linear | src/kernels/linear.cu | 103 | cuBLAS FP16 GEMM + backward (dX, dW) |
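The kernel tests compare GPU output against a CPU reference. As a rough illustration of what such a reference can look like for RMSNorm, here is a minimal host-side sketch. It is not the project's test code: `rmsnorm_reference`, `max_abs_diff`, and the `eps` default are hypothetical, and `float` stands in for the FP16 I/O the real kernels use.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for RMSNorm over a single row:
//   y[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i]
// The sum of squares is accumulated in FP32, mirroring the FP32-accumulation
// convention stated above. The eps default is an assumption.
std::vector<float> rmsnorm_reference(const std::vector<float>& x,
                                     const std::vector<float>& weight,
                                     float eps = 1e-6f) {
    float sum_sq = 0.0f;
    for (float v : x) sum_sq += v * v;
    const float inv_rms =
        1.0f / std::sqrt(sum_sq / static_cast<float>(x.size()) + eps);

    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = x[i] * inv_rms * weight[i];
    }
    return y;
}

// Largest element-wise absolute difference between two equally sized buffers.
// A test would compare this against a tolerance chosen for FP16 output.
float max_abs_diff(const std::vector<float>& a, const std::vector<float>& b) {
    float worst = 0.0f;
    for (std::size_t i = 0; i < a.size(); ++i) {
        worst = std::fmax(worst, std::fabs(a[i] - b[i]));
    }
    return worst;
}
```

A test would copy the kernel output back to the host, convert it to `float`, and check that `max_abs_diff(gpu_out, rmsnorm_reference(x, w))` stays under a tolerance.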
Build and run any individual kernel test:
make build/test_attention && ./build/test_attention
make build/test_rmsnorm && ./build/test_rmsnorm
# ... etc
A vLLM-style engine with continuous batching, paged KV cache, and CUDA graph acceleration.
Features:
- Continuous batching — two-phase decode + prefill per step, new requests start immediately
- Paged KV cache — block-level memory management, no wasted pre-allocation
- CUDA graphs — decode captured at bucket sizes (1, 2, 4, 8, ..., 256)
- Fused projections — QKV and gate+up projections fused into single GEMMs
- Train-inference mismatch = 0: the GRPO policy ratio needs log_probs_old (rollout) and log_probs_new (training). Both come from the same qwen3_forward(), with the same kernels and the same FP reduction order.
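To make the CUDA graph bucket idea concrete: a decode batch has to be mapped to one of the captured sizes. The sketch below assumes the buckets are the powers of two from 1 to 256 and that a batch is padded up to the nearest bucket. Both are assumptions, and `select_graph_bucket` and `kMaxGraphBatch` are hypothetical names, not the engine's API.

```cpp
#include <cstddef>

// Largest decode batch size with a captured graph (taken from the bucket list above).
const std::size_t kMaxGraphBatch = 256;

// Hypothetical helper: returns the smallest power-of-two bucket >= batch_size.
// Returns 0 when no bucket applies (empty batch, or larger than the biggest
// bucket); a caller would then fall back to launching kernels without a graph.
inline std::size_t select_graph_bucket(std::size_t batch_size) {
    if (batch_size == 0 || batch_size > kMaxGraphBatch) return 0;
    std::size_t bucket = 1;
    while (bucket < batch_size) bucket <<= 1;
    return bucket;
}
```

Under these assumptions, a batch of 3 sequences replays the graph captured for 4, with one padded slot whose output is discarded.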
Engine step:
schedule_decode() → prefill new seqs → sample → postprocess
└── continuous: finished slots instantly reused by waiting requests
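Slot reuse in continuous batching depends on the paged KV cache handing blocks back as soon as a sequence finishes. The following is a minimal free-list sketch of that idea, not the code in block_manager.h. The `BlockAllocator` class, its method names, and the block size used in the usage note are all hypothetical.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical minimal block manager: a free list of fixed-size KV blocks.
class BlockAllocator {
public:
    BlockAllocator(std::size_t num_blocks, std::size_t block_size)
        : block_size_(block_size) {
        // Push ids in reverse so block 0 is handed out first.
        for (std::size_t i = num_blocks; i > 0; --i) free_.push_back(i - 1);
    }

    // Blocks needed to hold num_tokens (ceiling division).
    std::size_t blocks_needed(std::size_t num_tokens) const {
        return (num_tokens + block_size_ - 1) / block_size_;
    }

    // Appends block ids for num_tokens to block_table. Returns false and
    // allocates nothing if the pool cannot satisfy the whole request.
    bool allocate(std::size_t num_tokens, std::vector<std::size_t>& block_table) {
        const std::size_t n = blocks_needed(num_tokens);
        if (n > free_.size()) return false;
        for (std::size_t i = 0; i < n; ++i) {
            block_table.push_back(free_.back());
            free_.pop_back();
        }
        return true;
    }

    // Returns a finished sequence's blocks to the pool for immediate reuse.
    void release(std::vector<std::size_t>& block_table) {
        for (std::size_t b : block_table) free_.push_back(b);
        block_table.clear();
    }

    std::size_t num_free() const { return free_.size(); }

private:
    std::size_t block_size_;
    std::vector<std::size_t> free_;
};
```

With a block size of 16 tokens (an illustrative value), a 33-token sequence takes 3 blocks, and releasing it makes those blocks available to the next waiting request in the same step.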
SFT: standard cross-entropy training with chunked lm_head to control memory.
GRPO: full RL training loop:
For each step:
1. Generate G completions per prompt using the inference engine
2. Score with reward function (e.g., boxed-answer matching)
3. Compute GRPO advantages (group-relative normalization)
4. Forward pass with gradient checkpointing
5. Backward pass (recompute activations per layer)
6. AdamW update with gradient accumulation + clipping
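Step 3 (group-relative normalization) can be sketched on the host as follows. This is an illustration of the usual GRPO formulation, not the project's kernel. `grpo_advantages`, the `eps` default, and the choice of population variance are assumptions.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// rewards: the G scalar rewards for one prompt's group of completions.
// Returns A_i = (r_i - mean) / (std + eps) for each completion in the group.
// eps (an assumed value) guards against division by zero when every
// completion in the group receives the same reward.
std::vector<float> grpo_advantages(const std::vector<float>& rewards,
                                   float eps = 1e-4f) {
    const std::size_t g = rewards.size();
    std::vector<float> adv(g, 0.0f);
    if (g == 0) return adv;

    float mean = 0.0f;
    for (float r : rewards) mean += r;
    mean /= static_cast<float>(g);

    float var = 0.0f;
    for (float r : rewards) var += (r - mean) * (r - mean);
    var /= static_cast<float>(g);  // population variance; sample variance is also used in practice

    const float inv = 1.0f / (std::sqrt(var) + eps);
    for (std::size_t i = 0; i < g; ++i) adv[i] = (rewards[i] - mean) * inv;
    return adv;
}
```

A group where every completion gets the same reward yields all-zero advantages, so it contributes no gradient.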
Gradient checkpointing saves only per-layer input residuals + FlashAttention LSE, recomputing all other activations during backward.
Sleep/wakeup lifecycle: KV cache pools are freed during training and re-allocated before generation, so the same GPU memory is shared between inference and training phases.
No weight transfer needed: the inference and training phases share a single copy of the weights, so no weight sync is required and inference-training mismatch is avoided.
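To show why matching log-probs matter, here is a minimal per-token sketch of the policy ratio and a PPO-style clipped surrogate, as commonly used with GRPO. Whether the project's GRPO loss kernel clips this way is not stated here. `policy_ratio`, `clipped_surrogate_loss`, and the `clip_eps` default are assumptions for illustration.

```cpp
#include <algorithm>
#include <cmath>

// ratio = pi_new(token) / pi_old(token), computed from log-probabilities.
inline float policy_ratio(float logp_new, float logp_old) {
    return std::exp(logp_new - logp_old);
}

// Per-token clipped surrogate loss (to be minimized):
//   loss = -min(ratio * A, clamp(ratio, 1 - eps, 1 + eps) * A)
// clip_eps = 0.2 is a placeholder value, not the project's setting.
inline float clipped_surrogate_loss(float logp_new, float logp_old,
                                    float advantage, float clip_eps = 0.2f) {
    const float ratio = policy_ratio(logp_new, logp_old);
    const float clipped =
        std::min(std::max(ratio, 1.0f - clip_eps), 1.0f + clip_eps);
    return -std::min(ratio * advantage, clipped * advantage);
}
```

When rollout and training log-probs come from the same forward code and the weights have not yet changed, `logp_new == logp_old` and the ratio is exactly 1. Any deviation from 1 then reflects a real policy update rather than numerical drift between two engines.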
| Batch Size | Throughput (tok/s) |
|---|---|
| 256 | 6,963 |
94% of nano-vllm throughput (7,411 tok/s). See docs/ENGINE.md for the full optimization journey.
RL.cu vs TRL (w/ vLLM) on the same task, same GPU (RTX PRO 6000):
| Metric | RL.cu | TRL (vLLM backend) |
|---|---|---|
| Reward (last 100 steps) | 0.307 | 0.312 |
| Step time | 33.7s | 46.3s |
| Generation throughput | 2,992 tok/s | 2,602 tok/s |
| Wall time (903 steps) | 8.5h | 11.6h |
| Time to reward = 0.3 | 0.8h | 2.9h |
Why RL.cu is 1.37x faster (wall-clock) for the same number of steps:
- 15% faster generation (2,992 vs 2,602 tok/s at matched token counts) — CUDA graphs, fused QKV/gate-up projections, and zero Python overhead eliminate per-step launch costs
- No weight transfer — RL.cu runs inference and training in the same process with shared weights; TRL must sync weights between the training model and vLLM's inference copy every step
- Shorter completions over training — RL.cu's completions shrink from 1,889 → 968 tokens as the model learns concise answers, reducing generation work by ~50%; TRL stays at ~1,840 tokens throughout
RL.cu
├── src/kernels/ # Hand-written CUDA kernels (fwd + bwd)
│ ├── attention.cu # FlashAttention-2 with GQA
│ ├── rmsnorm.cu # RMSNorm
│ ├── rope.cu # Rotary Position Embedding (NeoX)
│ ├── swiglu.cu # SwiGLU activation
│ ├── embedding.cu # Token embedding
│ ├── sampler.cu # Top-k/p sampling
│ ├── adamw.cu # Mixed-precision optimizer
│ └── ...
├── src/model/
│ ├── qwen3.cu # Full forward + backward pass (1,432 lines)
│ └── kv_cache.cu # Paged KV cache operations
├── include/engine/ # Inference engine
│ ├── llm_engine.h # LLMEngine: top-level API
│ ├── scheduler.h # Continuous batching scheduler
│ ├── block_manager.h # Paged KV block allocation
│ └── model_runner.cuh # Model execution + CUDA graphs
├── include/training/ # Training infrastructure
│ ├── GRPO_trainer.h # GRPO training loop
│ ├── SFT_trainer.h # SFT training loop
│ ├── optimizer.h # AdamW with flat buffer
│ └── lr_scheduler.h # Cosine + warmup
├── include/model/
│ ├── tokenizer.h # BPE tokenizer (reads HF tokenizer.json)
│ ├── config.h # Model + engine config
│ └── weights.h # Safetensors loader (mmap)
└── tests/
  ├── kernels/ # Unit tests for individual CUDA kernels
  ├── models/ # End-to-end model tests (Qwen3, LLMEngine)
  └── training/ # Training loop tests (SFT, GRPO)
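The tree lists lr_scheduler.h as "Cosine + warmup". For readers unfamiliar with that schedule, here is a generic sketch of linear warmup followed by cosine decay. It is not the contents of lr_scheduler.h: `lr_at_step`, its parameters, and the linear warmup shape are assumptions.

```cpp
#include <cmath>

// Hypothetical schedule: linear warmup to peak_lr over warmup_steps,
// then cosine decay from peak_lr to min_lr at total_steps.
inline float lr_at_step(int step, int warmup_steps, int total_steps,
                        float peak_lr, float min_lr = 0.0f) {
    if (warmup_steps > 0 && step < warmup_steps) {
        // step is 0-based, so the last warmup step reaches peak_lr.
        return peak_lr * static_cast<float>(step + 1) /
               static_cast<float>(warmup_steps);
    }
    if (step >= total_steps || total_steps <= warmup_steps) return min_lr;

    const float kPi = 3.14159265358979f;
    const float progress = static_cast<float>(step - warmup_steps) /
                           static_cast<float>(total_steps - warmup_steps);
    return min_lr + 0.5f * (peak_lr - min_lr) * (1.0f + std::cos(kPi * progress));
}
```

With 10 warmup steps out of 100, the rate ramps up over steps 0 to 9, sits at the peak at step 10, passes through half the peak midway through the decay, and reaches `min_lr` at step 100.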
| Feature | llm.c | vLLM | TRL | RL.cu |
|---|---|---|---|---|
| Language | C/CUDA | Python + CUDA | Python + PyTorch | C++/CUDA |
| Inference engine | - | Yes | via vLLM | Yes |
| Continuous batching | - | Yes | via vLLM | Yes |
| Paged KV cache | - | Yes | via vLLM | Yes |
| CUDA graphs | - | Yes | - | Yes |
| SFT training | Yes (GPT-2) | - | Yes | Yes (Qwen3) |
| RL training (GRPO) | - | - | Yes | Yes |
| Unified inference + training | N/A | N/A | Bridge needed | Yes |
| Zero train-infer mismatch | N/A | N/A | Requires mitigation | By design |
| Runtime dependencies | None | Python + PyTorch | Python + PyTorch | None |
Tests are organized into three directories:
| Directory | Contents |
|---|---|
| tests/kernels/ | Unit tests for individual CUDA kernels (forward + backward) |
| tests/models/ | End-to-end model tests (Qwen3, LLMEngine integration) |
| tests/training/ | Training loop tests (SFT, GRPO) |
# Run all tests
make tests
# Run a single test
make test_attention # FlashAttention-2 fwd+bwd
make test_llmengine # Full engine (11 integration tests)
# Profile a kernel with Nsight Systems
nsys profile --trace=cuda,cublas --stats=true ./build/test_attention
See tests/README.md for the full profiling guide.
This project is built to work with Claude Code. The repo includes project-level docs and skills so Claude can contribute effectively from the first message.
.claude/CLAUDE.md — 270-line project memory that gives Claude full context:
- GPU architecture (sm_120, CUDA 12.8), model dimensions (Qwen3-0.6B), build system
- Every kernel's API, grid/block config, and known pitfalls
- Scheduler design, KV cache layout, attention kernel gotchas
- All bugs we've fixed and why (so Claude doesn't reintroduce them)
- Coding conventions: FP16 I/O, FP32 accumulation, #pragma unroll, test patterns
.claude/skills/add_new_kernel/ — Step-by-step skill for implementing new CUDA kernels:
- Header in include/kernels/*.cuh, source in src/kernels/*.cu, test in tests/test_*.cu
- Includes full examples (FlashAttention-2, RMSNorm) showing the kernel → launcher → test pattern
- Warp-level reduction, shared memory tiling, vectorized loads, bounds checking
- CPU reference + tolerance comparison template
# Install Claude Code
npm install -g @anthropic-ai/claude-code
# Start working — Claude already knows the project
cd RL.cu
claude
# Examples of what Claude can do with the project context:
# "Add a flash decoding kernel for paged attention"
# "Why is my attention kernel producing NaN for S > 16?"
# "Optimize the RMSNorm backward kernel with warp shuffle"
# "Add INT8 quantization support for linear layers"
Claude understands the full architecture (kernel APIs, memory layouts, known bugs, and conventions), so it can write production-quality CUDA code that fits the project from the start.
Contributions welcome! Some areas where help would be great:
- Flash Decoding (split-K) — the decode attention kernel is currently single-threaded per (seq, head); a proper split-K implementation would give 5-10x speedup
- Multi-GPU support — tensor parallelism for larger models
- More model architectures — Llama, Gemma, etc.
- Speculative decoding — draft model + verification
- Quantization — INT8/INT4 weight quantization
This project is licensed under the MIT License.
