
Feat/chunked prefill #1086 (Draft)

dacorvo wants to merge 26 commits into feat/chunked-prefill-base from feat/chunked-prefill

Conversation


@dacorvo dacorvo commented Mar 9, 2026

What does this PR do?

Base pull request used to review Claude Code changes without triggering the CI

dacorvo and others added 22 commits March 4, 2026 14:34
The first change defers annotation evaluation to class instantiation via
`from __future__ import annotations`; the second imports
sentence_transformers symbols only when the package is installed.
Loading a compiled model onto NeuronCores permanently sets the Neuron
runtime's global communicator world_size for the lifetime of the process.
Tests that use build_module (tp_degree=1) poison the runtime so that
later tests requiring tp_degree=2 fail with "World size of neff N is
greater than world size of global communicator M".

Add @subprocess_test decorator that re-invokes each test via pytest in a
fresh subprocess, preventing NRT state from leaking into the parent
session.

Also change _save_checkpoint() to use the gloo backend instead of xla,
avoiding premature NRT initialization during checkpoint creation.

Add conftest.py with pytest_collection_modifyitems hook to reorder
@subprocess_test tests before session-scoped fixtures that hold
NeuronCores.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
When running `python tests/fixtures/llm/export_models.py` directly, show
a compact live display with a scrolling 10-line log window, the current
model configuration being compiled, and an overall progress bar.

Falls back to the plain loop if rich is not installed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- `--list`: print configuration names with model IDs and exit
- positional glob pattern filters which configs to export
  (e.g. `'gemma*'`, `'*-1x8192'`)
- no-match prints available names and exits with error

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
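The glob filtering and no-match behavior could look roughly like this; the registry contents and function name are illustrative, not taken from the script:

```python
import fnmatch
import sys

# Hypothetical registry standing in for the script's real configurations.
CONFIGURATIONS = {
    "llama-1x4096": "unsloth/Llama-3.2-1B-Instruct",
    "gemma-1x8192": "google/gemma-3-1b-it",
}


def select_configs(pattern: str) -> list[str]:
    """Return configuration names matching a glob pattern such as 'gemma*'.
    On no match, print the available names and exit with an error."""
    names = sorted(fnmatch.filter(CONFIGURATIONS, pattern))
    if not names:
        print(f"No configuration matches {pattern!r}. Available:")
        for name in sorted(CONFIGURATIONS):
            print(f"  {name}")
        sys.exit(1)
    return names
```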
…ron XLA

Move mask handling inside manual_softmax: masked positions are excluded
from max/exp/sum and receive zero probability. This avoids feeding
extreme fill values (finfo.min) through exp() which produces NaN on
Neuron XLA's vectorized implementation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace finfo(bf16).min with -1e9 as the masked-fill value in
scaled_qk(): Neuron XLA's exp() produces NaN for inputs near the
dtype minimum, while exp(-1e9) safely underflows to 0.0 in f32.
Verify that manual_softmax produces equivalent results on NeuronCores
versus CPU, using dimensions representative of Llama-3.2-1B (8 heads,
head_dim=64) and Gemma3-1B (4 heads, head_dim=256, sliding_window=512).

Tests both the masked path (with boolean prior_mask) and the unmasked
legacy path, each parametrized over both model configurations.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds support for splitting long prompts into fixed-size chunks processed
sequentially through the KV-cache scatter path, eliminating the need for
a large context-encoding compilation graph.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
When a chunk has fewer real tokens than chunk_size, the remaining
positions are padded with repeated position_ids. The old lower-
triangular active_mask let padded tokens attend to more context
than the real last token, producing different KV values. With
torch.scatter writing duplicate indices (all padded tokens → same
cache position), the non-deterministic write order on Neuron
corrupted the KV cache.

Derive an is_real column mask from position_ids (strictly increasing
→ real token) and AND it with the causal mask so padded tokens attend
to exactly the same set as the real last token.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
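The column-mask derivation described above can be sketched as follows (function name is illustrative; the commit later renames the mask to actual_tokens_mask):

```python
import torch


def actual_tokens_mask(position_ids: torch.Tensor) -> torch.Tensor:
    """A column holds a real token while position_ids are still strictly
    increasing; padded columns repeat the last real position and are
    therefore marked False. Sketch of the derivation described above."""
    mask = torch.ones_like(position_ids, dtype=torch.bool)
    mask[:, 1:] = position_ids[:, 1:] > position_ids[:, :-1]
    return mask


# ANDing this column mask with the lower-triangular causal mask makes padded
# tokens attend to exactly the same set as the real last token.
```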
Two tests sharing one module-scoped fixture (compiles both models once):

- test_chunked_prefill_generates_same_tokens[short_context]: 32-token
  prompt (< CHUNK_SIZE=64), exercises the repeat-last-token padding path
  in NxDDecoderWrapperForCausalLM.
- test_chunked_prefill_generates_same_tokens[long_context]: 256-token
  prompt (= 4 × CHUNK_SIZE), exercises multi-chunk KV accumulation.

Both parametrizations assert exact greedy-token equality between chunked
prefill and standard context encoding. On failure the test decodes both
outputs via AutoTokenizer so the divergence is human-readable.

- test_chunked_prefill_graph_structure: asserts the chunked model has
  chunked_prefill_model (not context_encoding_model) alongside
  token_generation_model.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
When the loaded model has prefill_chunk_size > 0, the runner now processes
new prompts chunk-by-chunk instead of in a single full-sequence
context-encoding pass. Each sequence is processed independently, one at a
time: this is not only simpler but also faster and consumes less device
memory than batching multiple sequences (benchmarked on trn1.32xlarge
with Llama 3.1 8B).

model_loader.py — OptimumNeuronModelForCausalLM:
- Add prefill_chunk_vllm(input_ids, position_ids, seq_ids, sampling_params)
  which delegates to NxDModelForCausalLM.prefill_chunk() and squeezes the
  [1, 1, vocab] output to [1, vocab] for the sampler.

runner.py — OptimumNeuronModelRunnerForCausalLM:
- execute_model(): when prefill_chunk_size > 0 and there are new prompt
  requests, dispatch to _execute_chunked_prefill() instead of _prepare_prompt()
  + get_next_tokens().
- _execute_chunked_prefill(): registers each new request in the batch, then
  processes each sequence independently one at a time.  For each sequence,
  iterates over ceil(seq_len / chunk_size) rounds, building [1, chunk_size]
  input_ids / position_ids tensors (padded with repeat-last-token when
  shorter than chunk_size).  Calls prefill_chunk_vllm() each round and
  CPU-samples from the final round's logits via NeuronSampler.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
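The per-sequence loop described above can be sketched as follows; `prefill_fn` stands in for prefill_chunk_vllm, and the helper itself is hypothetical:

```python
import math

import torch


def chunked_prefill(input_ids: torch.Tensor, chunk_size: int, prefill_fn):
    """Sketch of the per-sequence chunk loop: input_ids is [1, seq_len];
    prefill_fn takes [1, chunk_size] input_ids / position_ids and returns
    [1, vocab] logits. Short final chunks are padded by repeating the last
    real token and position."""
    seq_len = input_ids.shape[1]
    position_ids = torch.arange(seq_len).unsqueeze(0)
    logits = None
    for r in range(math.ceil(seq_len / chunk_size)):
        ids = input_ids[:, r * chunk_size : (r + 1) * chunk_size]
        pos = position_ids[:, r * chunk_size : (r + 1) * chunk_size]
        pad = chunk_size - ids.shape[1]
        if pad > 0:
            ids = torch.cat([ids, ids[:, -1:].expand(-1, pad)], dim=-1)
            pos = torch.cat([pos, pos[:, -1:].expand(-1, pad)], dim=-1)
        logits = prefill_fn(ids, pos)
    # Only the final round's logits are sampled (on CPU, via the sampler).
    return logits
```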
vLLM V1 defaults enable_chunked_prefill=True which sets
max_num_batched_tokens to 2048. The engine core later disables chunked
prefill (empty kv_cache_groups — Neuron manages KV cache internally),
leaving a 2048-token budget with no chunking. Any prompt >2048 tokens
is permanently stuck in the scheduler waiting queue.

Override to max_model_len: both standard and chunked prefill process
one sequence at a time, so the budget should be at least max_model_len.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds a shared session fixture and two test modules that verify the
chunked-prefill vLLM integration end-to-end.

fixtures/llm/export_models.py:
- Add CHUNKED_PREFILL_MODEL_ID / _SEQUENCE_LENGTH / _CHUNK_SIZE / _BATCH_SIZE
  constants.
- Add chunked_prefill_llm_config session fixture: exports and hub-caches both
  a standard (context-encoding) and a chunked-prefill Llama-3.2-1B model,
  each with batch_size=2, sequence_length=4096, chunk_size=512,
  on_device_sampling=False, fused_qkv=True.

tests/vllm/engine/test_vllm_engine_chunked_prefill.py:
- Module-scoped fixture loads both models as LLM instances.
- test_chunked_prefill_engine_generates_same_tokens[short]: prompt =
  chunk_size // 2 tokens, exercises the padding path in
  _execute_chunked_prefill.
- test_chunked_prefill_engine_generates_same_tokens[long]: prompt =
  chunk_size * 2 tokens, exercises two-chunk KV accumulation.
- test_chunked_prefill_engine_batch: sends [short, long] in one call,
  exercising the "repeat last real token for exhausted sequence" path
  in the second chunk round.

tests/vllm/service/test_vllm_service_chunked_prefill.py:
- Module-scoped fixtures launch the standard and chunked models as
  separate vLLM HTTP services via vllm_launcher.
- test_vllm_service_chunked_prefill_generation: verifies the chunked
  service starts and returns the expected number of tokens.
- test_vllm_service_std_vs_chunked_prefill: verifies that greedy text
  output is identical between the two services, confirming the runner's
  _execute_chunked_prefill path matches the standard context-encoding
  path end-to-end through the HTTP API.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Both models exported with identical params (BS=32, SL=4096, TP=8,
on_device_sampling=false) — std from main (0.4.5.dev2), chunked from
feat branch (chunk_size=1024). Benchmarked with 32 concurrent users.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
On-device sampling (ODS) halves inter-token latency by running the
sampler on NeuronCores instead of CPU, but it was previously disabled
for all chunked prefill models because the prefill graph must return
logits for CPU sampling.

Hybrid ODS solves this by compiling two graphs with different sampling
modes:
- Token generation graph: ODS enabled (returns sampled tokens)
- Chunked prefill graph: ODS disabled (returns logits for CPU sampling)

Each graph builder already creates its own NxDDecoder instance with a
separate neuron_config, so the change is minimal:
- _create_chunked_prefill_config() forces on_device_sampling=False
- The incompatibility check is removed from NxDNeuronConfig
- The vLLM runner creates a CPU sampler when chunked prefill is active
- ODS output is unsqueezed to [batch, 1] to match CPU sampler shape

Benchmark results (Llama 3.1 8B, BS=32, SL=4096, TP=8, 32 users):
  Config A (std CE, ODS=true):   426 tok/s, 59ms ITL
  Config C (chunked, ODS=false): 190 tok/s, 119ms ITL
  Config D (chunked, hybrid):    440 tok/s, 56ms ITL  (+132% vs C)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
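The config split behind hybrid ODS can be sketched with a stub config class (the real NxDNeuronConfig has many more fields; the stub and function below are illustrative):

```python
import copy


class NeuronConfigStub:
    """Stand-in for NxDNeuronConfig, keeping only the two relevant fields."""
    def __init__(self, on_device_sampling=True, prefill_chunk_size=1024):
        self.on_device_sampling = on_device_sampling
        self.prefill_chunk_size = prefill_chunk_size


def create_chunked_prefill_config(neuron_config):
    """Each graph builder gets its own config copy, so the chunked prefill
    graph can force ODS off (it must return logits for CPU sampling) while
    the token generation graph keeps ODS on."""
    cp_config = copy.deepcopy(neuron_config)
    cp_config.on_device_sampling = False
    return cp_config


token_gen_config = NeuronConfigStub()
prefill_config = create_chunked_prefill_config(token_gen_config)
```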
Add benchmark results for the full 4-way comparison matrix:
  A: std CE, ODS=true  (426 tok/s, 59ms ITL)
  D: chunked, hybrid   (440 tok/s, 56ms ITL)

Config D (hybrid ODS) matches or exceeds the production baseline (A)
while providing chunked prefill's memory advantages.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…euse

When processing multiple sequences sequentially, the NxD runtime reuses
output buffers between calls. Without .clone(), previously stored logits
are silently overwritten, causing token shift corruption in batched output.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
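The buffer-reuse hazard can be reproduced in miniature; this sketch fakes the runtime with a shared tensor and is not the NxD API:

```python
import torch

# If the runtime returns views into a reused output buffer, logits stored
# from earlier calls are silently overwritten by later ones; .clone()
# detaches each result from the shared buffer.
_shared_buffer = torch.zeros(1, 4)


def runtime_call(value: float) -> torch.Tensor:
    _shared_buffer.fill_(value)  # fake runtime writes into the same buffer
    return _shared_buffer        # and returns a view of it


aliased = [runtime_call(v) for v in (1.0, 2.0)]            # corrupted
cloned = [runtime_call(v).clone() for v in (1.0, 2.0)]     # safe
```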
@dacorvo dacorvo left a comment

This looks good to me, but I would like some of the commits to be amended to make it easier to maintain.

# scatter is a no-op overwrite with the same value — no corruption.
last_input_id = input_ids[:, -1:] # [batch, 1]
last_pos = position_ids[:, -1:] # [batch, 1]
input_ids = torch.cat([input_ids, last_input_id.expand(-1, pad_length)], dim=-1)

You should reuse pad_to_max_context_length for consistency

f"Adjusting num_key_value_heads from {config.num_key_value_heads} to {num_key_value_heads} for TP {neuron_config.tp_degree}."
)
self.kv_mgr = KVCacheManager(config, neuron_config, actual_num_key_value_heads=num_key_value_heads)
self.is_chunked_prefill_graph = neuron_config.prefill_chunk_size > 0

I'd rather not use a flag set at init. All other graphs are identified when relevant based on the neuron_config or input shapes

sampling_params (torch.FloatTensor): Sampling parameters.
"""
is_for_context_encoding = self._is_context_encoding(input_ids)
is_for_speculation = self._is_for_speculation(input_ids)

Here you should insert a flag for chunked_prefill. Ideally it should not be incompatible with is_for_context_encoding


hidden_size = hidden_states.shape[-1]
if is_for_context_encoding:
if is_for_context_encoding or self.is_chunked_prefill_graph:

I'd rather have is_for_context_encoding be true for chunked prefill also

# (see kv_cache_manager.update_cache), so max_batch_size is preserved
# for KV cache sizing while the forward-pass batch dim is just 1.
cp_config = copy.deepcopy(neuron_config)
cp_config.batch_size = neuron_config.effective_prefill_batch_size

I'd rather reuse here ctx_batch_size

# repeated position is a true no-op (avoids non-deterministic scatter
# behaviour with duplicate indices corrupting the cache).
is_real = torch.ones(1, chunk_size, dtype=torch.bool, device=device)
is_real[:, 1:] = position_ids[:1, 1:] > position_ids[:1, :-1]

Doesn't it assume that we are using a batch_size of 1 here? I would have expected something like:
is_real[:, 1:] = position_ids[:, 1:] > position_ids[:, :-1]
Also add a comment explaining that we attend to tokens whose position_ids are still increasing, and rename is_real to actual_tokens_mask. I am also wondering if this is fully consistent, because of the view you need to apply afterwards.

# limitations under the License.
"""Tests verifying that chunked prefill produces equivalent output to standard context encoding."""

import os

I'd rather use the export fixture here, adding two more configs for llama

# Then process new prompt requests.
if n_prompt_reqs > 0:
(requests, input_ids, position_ids, seq_ids, sampling_params) = self._prepare_prompt(scheduler_output)
chunk_size = self.model.model.neuron_config.prefill_chunk_size

I would like to have the chunked_prefill reuse more code from the standard prefill path

_get_neuron_model_for_config(config_name, model_config, neuron_model_path)
progress.advance(task_id)
progress.update(task_id, description="[green]All models exported")
CHUNKED_PREFILL_MODEL_ID = "unsloth/Llama-3.2-1B-Instruct"

I don't want an extra fixture here. You should reuse the existing fixture and extend it so that it supports chunked prefill. It could for instance be parameterized as <batch_size>x<num_chunks>x<chunk_size> so that std models are represented as <batch_size>x1x<sequence_length>

from vllm import LLM, SamplingParams


# Must match CHUNKED_PREFILL_CHUNK_SIZE in tests/fixtures/llm/export_models.py

Don't do that: sharing implicit knowledge is bad practice. Reuse the value from the fixture config instead.

dacorvo and others added 4 commits March 9, 2026 17:39
- Remove `effective_prefill_batch_size` property, use `ctx_batch_size`
- Replace init-time `is_chunked_prefill_graph` flag with input-shape
  detection via `_is_chunked_prefill(input_ids)`
- Keep `is_for_context_encoding=True` for chunked prefill
- Rename `is_real` to `actual_tokens_mask`, fix position_ids indexing
  to handle all batch elements
- Move chunked prefill scatter under CE branch in KV cache manager
- Extract `_flat_scatter_update` helper method

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Extract `_pad_to_max_context_length` and `_pad_to_chunk_size` methods
  in decoder_wrappers.py
- Refactor `get_next_tokens` into `sample_next_tokens` — a pure sampling
  closure with forward calls made explicit in each execution mode
  (standard CE, chunked prefill, decode)
- Handle hybrid ODS: initialize `fused_logits_warper` when chunked
  prefill is used, pass `is_ods=False` for chunked prefill outputs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Extract shared request collection logic into `_collect_new_requests`
helper, called by both `_prepare_prompt` and `_execute_chunked_prefill`.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add chunked prefill configs to GENERATE_LLM_MODEL_CONFIGURATIONS
- Remove chunked_prefill_llm_config fixture
- Decoder tests: use neuron_llm_config with indirect=True, compare
  chunked model against HF model with real prompts
- vLLM engine tests: use neuron_llm_config with indirect=True, test
  short/long/batch generation and batch consistency
- vLLM service test: use neuron_llm_config with indirect=True, verify
  service startup and generation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copilot AI left a comment

Pull request overview

This PR introduces chunked-prefill support across Optimum Neuron’s decoder backend and vLLM integration, along with new tests and export/benchmark utilities intended to validate and measure the behavior.

Changes:

  • Add chunked-prefill graph support (new config knob prefill_chunk_size) and route prompt prefill through chunk-by-chunk execution in both the Neuron generation path and vLLM runner.
  • Update attention masking/softmax to avoid bf16 exp() NaNs and support masked softmax for chunked/speculation-style masking.
  • Add new decoder + vLLM engine/service tests and extend test model export fixtures with chunked-prefill configurations.

Reviewed changes

Copilot reviewed 38 out of 42 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
optimum/neuron/vllm/runner.py Executes prompt prefill via chunked-prefill path when enabled; aligns token tensor shapes for ODS vs CPU sampling.
optimum/neuron/vllm/platform.py Adjusts scheduler token budget to avoid >2048-token prompt deadlocks on Neuron platform.
optimum/neuron/vllm/model_loader.py Adds a vLLM-facing prefill_chunk_vllm helper for chunk-by-chunk prefill logits.
optimum/neuron/models/inference/backend/config.py Introduces prefill_chunk_size and enforces incompatibility with speculation.
optimum/neuron/models/inference/backend/modules/decoder/* Adds chunked-prefill wrapper/tag, chunked graph builder, chunked KV scatter update path, and prefill_chunk() API.
optimum/neuron/models/inference/backend/modules/generation/generation_utils.py Implements chunked prefill loop and “hybrid ODS” sampling behavior.
optimum/neuron/models/inference/backend/modules/attention/* Adds masked manual_softmax and replaces finfo(min) masking with -1e9 to prevent NaNs.
optimum/exporters/neuron/__main__.py, optimum/commands/export/neuronx.py Plumbs prefill_chunk_size through export CLI and exporter entrypoint.
tests/fixtures/llm/export_models.py Adds std vs chunked model configs and improves export script UX.
tests/decoder/*, tests/vllm/* Adds/adjusts tests for chunked prefill and Neuron subprocess isolation for Neuron runtime state.
benchmark/vllm/* Adds benchmark results and improves benchmark runner script.
AGENTS.md, optimum/neuron/version.py Documentation update and dev version bump.

Comment on lines +1 to +35
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Engine-level tests verifying that a chunked-prefill model generates correct
output through the vLLM engine.

Each test exercises a distinct code path relative to the model's chunk_size:

short — chunk_size // 2 tokens → padding path in _execute_chunked_prefill
long — chunk_size * 2 tokens → KV accumulation across two complete chunks
batch — [short, long] in one call → verifies consistency with individual runs
"""

from typing import Any

import pytest


# Do not collect tests from this file if vllm is not installed
pytest.importorskip("vllm")


from vllm import LLM, SamplingParams

Copilot AI Mar 9, 2026

Ruff will likely flag E402 here because from vllm import ... appears after the pytest.importorskip("vllm") side-effect statement. Other vLLM engine tests silence this with a file-level # ruff: noqa: E402 header; please add the same (or otherwise restructure to satisfy E402 while keeping the importorskip behavior).
