- (🔥 New) [2025/9/29] We released the Jet-Nemotron models and inference code.
- (🔥 New) [2025/9/18] Jet-Nemotron has been accepted to NeurIPS 2025! 🎉🎉🎉 See you in San Diego!
- [2025/8/22] We released the Jet-Nemotron technical report on arXiv.
Jet-Nemotron is a new family of hybrid-architecture language models that surpass state-of-the-art open-source full-attention language models such as Qwen3, Qwen2.5, Gemma3, and Llama3.2, while achieving significant efficiency gains: up to 53.6× speedup in generation throughput on H100 GPUs (256K context length, maximum batch size). It is built upon two core innovations:
- Post Neural Architecture Search, an efficient post-training architecture exploration and adaptation pipeline applicable to arbitrary pre-trained transformer models;
- JetBlock, a novel linear attention block that significantly outperforms previous designs such as Mamba2.
Unlike prior methods that train from scratch to explore new model architectures, PostNAS builds on a pre-trained transformer model while enabling flexible exploration of attention block designs, greatly reducing the cost and risk of developing new language model architectures.
- PostNAS first identifies the optimal placement of full-attention layers, then searches for improved attention block designs.
- In a pre-trained transformer, not all attention layers contribute equally; PostNAS reveals which attention layers are important.
- KV cache size is the most critical factor influencing long-context and long-generation throughput. The PostNAS hardware-aware search discovers architectures that deliver similar generation throughput while using more parameters and achieving better accuracy. A conceptual sketch of the overall PostNAS flow follows below.
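The sketch below expresses this flow as Python pseudocode. It is illustrative only and does not correspond to the released pipeline; the placement search, adaptation step, and accuracy/throughput evaluators are hypothetical callables supplied by the caller.

```python
from typing import Callable, Iterable, Tuple


def post_nas_sketch(
    search_placement: Callable[[], object],         # step 1: which layers remain full attention
    candidate_blocks: Iterable[object],             # step 2: attention block designs to try
    adapt: Callable[[object, object], object],      # adapts the pre-trained model to (placement, block)
    throughput: Callable[[object], float],          # hardware-aware metric (dominated by KV cache size)
    accuracy: Callable[[object], float],
    throughput_budget: float,
) -> Tuple[object, object]:
    """Illustrative PostNAS-style selection: the most accurate design that meets the throughput budget."""
    placement = search_placement()                  # step 1: placement of full-attention layers
    best_block, best_acc = None, float("-inf")
    for block in candidate_blocks:                  # step 2: explore attention block designs
        model = adapt(placement, block)             # built on pre-trained weights, not trained from scratch
        if throughput(model) >= throughput_budget:  # step 3: hardware-aware constraint
            acc = accuracy(model)
            if acc > best_acc:
                best_block, best_acc = block, acc
    return placement, best_block
```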
With PostNAS, we introduce JetBlock, a novel linear attention module that integrates dynamic convolution with hardware-aware architecture search, delivering substantial accuracy gains over previous designs while maintaining similar training and inference throughput. Below, we present an apples-to-apples comparison between the Mamba2 block and JetBlock, using identical training data and training recipes.
Jet-Nemotron-2B and Jet-Nemotron-4B match or surpass the accuracy of leading efficient language models (e.g., Qwen3) across a comprehensive benchmark suite while running significantly faster: 21× and 47× faster than Qwen3-1.7B-Base, respectively.
- Setup Environments
- Models
- Generate with Jet-Nemotron
- Evaluation on Benchmarks
- Measure Throughput
- Build Your Own JetBlock
- Contact
- License
- Bibtex
git clone https://github.com/NVlabs/Jet-Nemotron
cd Jet-Nemotron
pip3 install -e .
NOTE: To install flash-attn properly, you may need to install a specific release version or build it from source.
(Optional) To support throughput measurement or chunked prefilling when eval_batch_size > 1, please install a modified version of transformers==4.52.0:
pip3 install -U transformers@git+https://github.com/jet-ai-projects/transformers.git@jetai
- Jet-Nemotron-2B: jet-ai/Jet-Nemotron-2B
- Jet-Nemotron-4B: jet-ai/Jet-Nemotron-4B
Load the model with
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "jet-ai/Jet-Nemotron-2B",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
NOTE: The kernels in Jet-Nemotron currently do not support CPU execution; running the model on a CPU may produce unexpected results.
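Since the kernels require a GPU, a simple guard like the following (not part of the repository) can be used to fail fast before loading the model:

```python
import torch

# Fail fast if no CUDA device is available; the Jet-Nemotron kernels are GPU-only.
assert torch.cuda.is_available(), "Jet-Nemotron requires a CUDA-capable GPU."
```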
To use or contribute to the model definition files in this repo (jetai/modeling/hf), you can first download or soft-link the model weights and model config to jetai/modeling/hf/:
hf download jet-ai/Jet-Nemotron-2B --local-dir jetai/modeling/hf --include "*safetensors*" --include "config.json"
Then you can load the model with
model = AutoModelForCausalLM.from_pretrained(
    "jetai/modeling/hf",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name_or_path = "jet-ai/Jet-Nemotron-2B"
# For local testing, you can use the following path.
# NOTE: Be sure to download or soft-link the model weights to `jetai/modeling/hf`
# model_name_or_path = "jetai/modeling/hf/"
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = model.eval().cuda()
input_str = "Hello, I'm Jet-Nemotron from NVIDIA."
input_ids = tokenizer(input_str, return_tensors="pt").input_ids.cuda()
output = model.generate(input_ids, max_new_tokens=50, do_sample=False)
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
print(output_str)
or
python3 jetai/inference/generate.py --model_name_or_path ${PATH_TO_YOUR_MODEL}
Run evaluation for the MMLU, MMLU-Pro, BBH, Commonsense, Math, Code, Retrieval, and LongBench tasks.
bash scripts/eval/2B/mmlu.sh
bash scripts/eval/2B/mmlu_pro.sh
bash scripts/eval/2B/bbh.sh
bash scripts/eval/2B/commonsense.sh
bash scripts/eval/2B/math.sh
bash scripts/eval/2B/code.sh
bash scripts/eval/2B/retrieval.sh
bash scripts/eval/2B/longbench.sh
You can use the first command-line argument to specify model_name_or_path:
bash scripts/eval/2B/mmlu.sh ${PATH_TO_YOUR_MODEL}
NOTE: The evaluation code will use the .parquet version of the social_i_qa, mathqa, and longbench data from our repo because their official repos do not support loading with datasets >= 4.0.0.
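If you want to inspect these files yourself, the generic parquet builder in datasets can load them directly. A minimal sketch (the file path below is hypothetical; substitute the actual file shipped in this repo):

```python
from datasets import load_dataset

# Load a local .parquet split with the generic "parquet" builder,
# which works with datasets >= 4.0.0.
ds = load_dataset("parquet", data_files={"validation": "path/to/social_i_qa/validation.parquet"})
print(ds["validation"][0])
```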
python3 jetai/inference/measure_throuput.py --model_name_or_path jet-ai/Jet-Nemotron-2B
python3 jetai/inference/measure_throuput.py --model_name_or_path jet-ai/Jet-Nemotron-4B --batch_size 64 --prefill_chunk_size 1024
The following code is a minimal example to build your own JetBlock.
import torch

from jetai.modeling.hf.jet_block import (
    JetBlock,
    JetBlockConfig,
)

# Configure a JetBlock: 6 heads with head dimension 256, value expansion factor 2.0,
# and a convolution kernel size of 4.
jet_block_config = JetBlockConfig(
    expand_v=2.0,
    num_heads=6,
    head_dim=256,
    conv_size=4,
)

jet_block = JetBlock(
    hidden_size=1536,
    initializer_range=0.02,
    jet_block_config=jet_block_config,
).cuda().to(torch.bfloat16)

# Random input: batch size 16, sequence length 4096, hidden size 1536
# (the last dimension must match the block's hidden_size).
hidden_states = torch.randn(16, 4096, 1536).cuda().to(torch.bfloat16)
hidden_states, _ = jet_block(
    hidden_states=hidden_states,
)
print(hidden_states)
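With this configuration, the block should return hidden states with the same (batch, sequence length, hidden size) shape as its input, so it can serve as a drop-in replacement for an attention block inside a transformer layer.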
@article{gu2025jet,
title={Jet-Nemotron: Efficient Language Model with Post Neural Architecture Search},
author={Gu, Yuxian and Hu, Qinghao and Yang, Shang and Xi, Haocheng and Chen, Junyu and Han, Song and Cai, Han},
journal={arXiv preprint arXiv:2508.15884},
year={2025}
}