Conversation
The first change defers annotation evaluation to class instantiation by importing annotations; the second imports sentence_transformers symbols only when the package is installed.
Loading a compiled model onto NeuronCores permanently sets the Neuron runtime's global communicator world_size for the lifetime of the process. Tests that use build_module (tp_degree=1) poison the runtime so that later tests requiring tp_degree=2 fail with "World size of neff N is greater than world size of global communicator M".

Add a @subprocess_test decorator that re-invokes each test via pytest in a fresh subprocess, preventing NRT state from leaking into the parent session. Also change _save_checkpoint() to use the gloo backend instead of xla, avoiding premature NRT initialization during checkpoint creation. Add conftest.py with a pytest_collection_modifyitems hook to reorder @subprocess_test tests before session-scoped fixtures that hold NeuronCores.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
When running `python tests/fixtures/llm/export_models.py` directly, show a compact live display with a scrolling 10-line log window, the current model configuration being compiled, and an overall progress bar. It falls back to the plain loop if rich is not installed. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- `--list`: print configuration names with model IDs and exit
- positional glob pattern filters which configs to export (e.g. `'gemma*'`, `'*-1x8192'`)
- no match prints available names and exits with an error

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ron XLA

Move mask handling inside manual_softmax: masked positions are excluded from max/exp/sum and receive zero probability. This avoids feeding extreme fill values (finfo.min) through exp(), which produces NaN on Neuron XLA's vectorized implementation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace finfo(bf16).min with -1e9 as the masked-fill value in scaled_qk(): Neuron XLA's exp() produces NaN for inputs near the dtype minimum, while exp(-1e9) safely underflows to 0.0 in f32.
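The fix can be illustrated with a small CPU sketch in plain PyTorch (illustrative names and shapes, not the actual `scaled_qk` implementation): filling masked positions with -1e9 keeps the subsequent softmax finite, since exp() of a large negative f32 value underflows cleanly to zero.

```python
import torch

# Hedged sketch of the masking strategy described above: fill masked scores
# with -1e9 so that exp() underflows cleanly to 0.0 in f32, instead of
# risking NaN with values near finfo(bf16).min.
def masked_scores(query, key, mask, scale):
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    return scores.masked_fill(~mask, -1e9)

q = torch.randn(1, 4, 8)
k = torch.randn(1, 4, 8)
causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))
probs = torch.softmax(masked_scores(q, k, causal, 8**-0.5).float(), dim=-1)
# Masked (future) positions receive exactly zero probability, with no NaNs.
assert not torch.isnan(probs).any()
assert probs[0, 0, 1:].sum().item() == 0.0
```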
Verify that manual_softmax produces equivalent results on NeuronCores versus CPU, using dimensions representative of Llama-3.2-1B (8 heads, head_dim=64) and Gemma3-1B (4 heads, head_dim=256, sliding_window=512). Tests both the masked path (with boolean prior_mask) and the unmasked legacy path, each parametrized over both model configurations. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds support for splitting long prompts into fixed-size chunks processed sequentially through the KV-cache scatter path, eliminating the need for a large context-encoding compilation graph. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
When a chunk has fewer real tokens than chunk_size, the remaining positions are padded with repeated position_ids. The old lower- triangular active_mask let padded tokens attend to more context than the real last token, producing different KV values. With torch.scatter writing duplicate indices (all padded tokens → same cache position), the non-deterministic write order on Neuron corrupted the KV cache. Derive an is_real column mask from position_ids (strictly increasing → real token) and AND it with the causal mask so padded tokens attend to exactly the same set as the real last token. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
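The mask derivation can be sketched in a few lines (an illustrative stand-in, not the actual kernel code): a token is real exactly when its position id is strictly greater than its predecessor's, because padding repeats the last real position id.

```python
import torch

def real_token_mask(position_ids):
    # Column 0 is always real; a later column is real only if its position id
    # strictly increased, because padding repeats the last real position id.
    mask = torch.ones_like(position_ids, dtype=torch.bool)
    mask[:, 1:] = position_ids[:, 1:] > position_ids[:, :-1]
    return mask

# Chunk of size 6 holding 4 real tokens; the last position is repeated twice.
pos = torch.tensor([[0, 1, 2, 3, 3, 3]])
m = real_token_mask(pos)
assert m.tolist() == [[True, True, True, True, False, False]]

# ANDing with the causal mask restricts padded rows to exactly the same
# attention set as the real last token, so the duplicate-index scatter
# writes the same KV values regardless of write order.
causal = torch.tril(torch.ones(6, 6, dtype=torch.bool))
combined = causal & m
assert combined[4].tolist() == combined[3].tolist()
```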
Two tests sharing one module-scoped fixture (compiles both models once):

- test_chunked_prefill_generates_same_tokens[short_context]: 32-token prompt (< CHUNK_SIZE=64), exercises the repeat-last-token padding path in NxDDecoderWrapperForCausalLM.
- test_chunked_prefill_generates_same_tokens[long_context]: 256-token prompt (= 4 × CHUNK_SIZE), exercises multi-chunk KV accumulation.
- test_chunked_prefill_graph_structure: asserts the chunked model has chunked_prefill_model (not context_encoding_model) alongside token_generation_model.

Both parametrizations of the first test assert exact greedy-token equality between chunked prefill and standard context encoding. On failure the test decodes both outputs via AutoTokenizer so the divergence is human-readable.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
When the loaded model has prefill_chunk_size > 0, the runner now processes new prompts chunk-by-chunk instead of in a single full-sequence context-encoding pass. Each sequence is processed independently, one at a time: this is not only simpler but also faster and consumes less device memory than batching multiple sequences (benchmarked on trn1.32xlarge with Llama 3.1 8B).

model_loader.py — OptimumNeuronModelForCausalLM:

- Add prefill_chunk_vllm(input_ids, position_ids, seq_ids, sampling_params), which delegates to NxDModelForCausalLM.prefill_chunk() and squeezes the [1, 1, vocab] output to [1, vocab] for the sampler.

runner.py — OptimumNeuronModelRunnerForCausalLM:

- execute_model(): when prefill_chunk_size > 0 and there are new prompt requests, dispatch to _execute_chunked_prefill() instead of _prepare_prompt() + get_next_tokens().
- _execute_chunked_prefill(): registers each new request in the batch, then processes each sequence independently, one at a time. For each sequence, iterates over ceil(seq_len / chunk_size) rounds, building [1, chunk_size] input_ids / position_ids tensors (padded with repeat-last-token when shorter than chunk_size). Calls prefill_chunk_vllm() each round and CPU-samples from the final round's logits via NeuronSampler.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
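The per-sequence loop described for _execute_chunked_prefill() can be sketched as follows (a minimal stand-in: `prefill_fn` plays the role of `prefill_chunk_vllm()`, and shapes are illustrative):

```python
import torch

def chunked_prefill(input_ids, chunk_size, prefill_fn):
    # Split one [1, seq_len] prompt into ceil(seq_len / chunk_size) rounds.
    seq_len = input_ids.shape[1]
    logits = None
    for start in range(0, seq_len, chunk_size):
        chunk = input_ids[:, start : start + chunk_size]
        positions = torch.arange(start, start + chunk.shape[1]).unsqueeze(0)
        pad = chunk_size - chunk.shape[1]
        if pad > 0:
            # Last short chunk: pad by repeating the final real token/position.
            chunk = torch.cat([chunk, chunk[:, -1:].expand(-1, pad)], dim=-1)
            positions = torch.cat([positions, positions[:, -1:].expand(-1, pad)], dim=-1)
        logits = prefill_fn(chunk, positions)
    return logits  # the first generated token is sampled from the final round

rounds = []
out = chunked_prefill(
    torch.arange(10).unsqueeze(0),
    chunk_size=4,
    prefill_fn=lambda ids, pos: rounds.append(ids.shape) or ids.float(),
)
assert len(rounds) == 3  # ceil(10 / 4) rounds, each of shape [1, 4]
```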
vLLM V1 defaults enable_chunked_prefill=True which sets max_num_batched_tokens to 2048. The engine core later disables chunked prefill (empty kv_cache_groups — Neuron manages KV cache internally), leaving a 2048-token budget with no chunking. Any prompt >2048 tokens is permanently stuck in the scheduler waiting queue. Override to max_model_len: both standard and chunked prefill process one sequence at a time, so the budget should be at least max_model_len. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
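A minimal sketch of the override logic (helper name and shape are assumed, not the actual platform.py code):

```python
# Hypothetical helper mirroring the budget fix described above: since both
# standard and chunked prefill process one sequence at a time, the scheduler's
# token budget must be able to hold a full-length prompt.
def resolve_token_budget(max_num_batched_tokens: int, max_model_len: int) -> int:
    return max(max_num_batched_tokens, max_model_len)

# With vLLM V1's default 2048-token budget, a 4096-token model would otherwise
# strand any prompt longer than 2048 tokens in the scheduler waiting queue.
assert resolve_token_budget(2048, 4096) == 4096
assert resolve_token_budget(8192, 4096) == 8192
```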
Adds a shared session fixture and two test modules that verify the chunked-prefill vLLM integration end-to-end.

fixtures/llm/export_models.py:

- Add CHUNKED_PREFILL_MODEL_ID / _SEQUENCE_LENGTH / _CHUNK_SIZE / _BATCH_SIZE constants.
- Add chunked_prefill_llm_config session fixture: exports and hub-caches both a standard (context-encoding) and a chunked-prefill Llama-3.2-1B model, each with batch_size=2, sequence_length=4096, chunk_size=512, on_device_sampling=False, fused_qkv=True.

tests/vllm/engine/test_vllm_engine_chunked_prefill.py:

- Module-scoped fixture loads both models as LLM instances.
- test_chunked_prefill_engine_generates_same_tokens[short]: prompt = chunk_size // 2 tokens, exercises the padding path in _execute_chunked_prefill.
- test_chunked_prefill_engine_generates_same_tokens[long]: prompt = chunk_size * 2 tokens, exercises two-chunk KV accumulation.
- test_chunked_prefill_engine_batch: sends [short, long] in one call, exercising the "repeat last real token for exhausted sequence" path in the second chunk round.

tests/vllm/service/test_vllm_service_chunked_prefill.py:

- Module-scoped fixtures launch the standard and chunked models as separate vLLM HTTP services via vllm_launcher.
- test_vllm_service_chunked_prefill_generation: verifies the chunked service starts and returns the expected number of tokens.
- test_vllm_service_std_vs_chunked_prefill: verifies that greedy text output is identical between the two services, confirming the runner's _execute_chunked_prefill path matches the standard context-encoding path end-to-end through the HTTP API.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Both models exported with identical params (BS=32, SL=4096, TP=8, on_device_sampling=false) — std from main (0.4.5.dev2), chunked from feat branch (chunk_size=1024). Benchmarked with 32 concurrent users. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
On-device sampling (ODS) halves inter-token latency by running the sampler on NeuronCores instead of CPU, but it was previously disabled for all chunked prefill models because the prefill graph must return logits for CPU sampling. Hybrid ODS solves this by compiling two graphs with different sampling modes:

- Token generation graph: ODS enabled (returns sampled tokens)
- Chunked prefill graph: ODS disabled (returns logits for CPU sampling)

Each graph builder already creates its own NxDDecoder instance with a separate neuron_config, so the change is minimal:

- _create_chunked_prefill_config() forces on_device_sampling=False
- The incompatibility check is removed from NxDNeuronConfig
- The vLLM runner creates a CPU sampler when chunked prefill is active
- ODS output is unsqueezed to [batch, 1] to match CPU sampler shape

Benchmark results (Llama 3.1 8B, BS=32, SL=4096, TP=8, 32 users):

Config A (std CE, ODS=true): 426 tok/s, 59ms ITL
Config C (chunked, ODS=false): 190 tok/s, 119ms ITL
Config D (chunked, hybrid): 440 tok/s, 56ms ITL (+132% vs C)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
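The config split can be sketched like this (plain dicts stand in for NxDNeuronConfig; all names are illustrative except `_create_chunked_prefill_config`, which the summary above mentions):

```python
import copy

def create_chunked_prefill_config(neuron_config: dict) -> dict:
    # The chunked prefill graph gets its own config copy with on-device
    # sampling forced off, so it returns logits for CPU sampling, while the
    # token generation graph keeps ODS enabled and returns sampled tokens.
    cp_config = copy.deepcopy(neuron_config)
    cp_config["on_device_sampling"] = False
    return cp_config

token_gen_config = {"on_device_sampling": True, "prefill_chunk_size": 1024}
cp_config = create_chunked_prefill_config(token_gen_config)
assert cp_config["on_device_sampling"] is False
assert token_gen_config["on_device_sampling"] is True  # untouched
```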
Add benchmark results for the full 4-way comparison matrix: A: std CE, ODS=true (426 tok/s, 59ms ITL) D: chunked, hybrid (440 tok/s, 56ms ITL) Config D (hybrid ODS) matches or exceeds the production baseline (A) while providing chunked prefill's memory advantages. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…euse

When processing multiple sequences sequentially, the NxD runtime reuses output buffers between calls. Without .clone(), previously stored logits are silently overwritten, causing token shift corruption in batched output.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
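The hazard is easy to reproduce with a toy stand-in for the runtime (illustrative only, not the NxD API):

```python
import torch

# A fake runtime that, like the buffer reuse described above, hands back the
# same output tensor on every call instead of allocating a fresh one.
_buffer = torch.zeros(1, 3)

def fake_runtime_call(value: float) -> torch.Tensor:
    _buffer.fill_(value)
    return _buffer

aliased = [fake_runtime_call(v) for v in (1.0, 2.0)]         # no .clone()
cloned = [fake_runtime_call(v).clone() for v in (1.0, 2.0)]  # defensive copy

assert aliased[0][0, 0].item() == 2.0  # first result silently overwritten
assert cloned[0][0, 0].item() == 1.0   # .clone() preserved it
```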
dacorvo
left a comment
This looks good to me, but I would like some of the commits to be amended to make it easier to maintain.
# scatter is a no-op overwrite with the same value — no corruption.
last_input_id = input_ids[:, -1:]  # [batch, 1]
last_pos = position_ids[:, -1:]  # [batch, 1]
input_ids = torch.cat([input_ids, last_input_id.expand(-1, pad_length)], dim=-1)
You should reuse pad_to_max_context_length for consistency
    f"Adjusting num_key_value_heads from {config.num_key_value_heads} to {num_key_value_heads} for TP {neuron_config.tp_degree}."
)
self.kv_mgr = KVCacheManager(config, neuron_config, actual_num_key_value_heads=num_key_value_heads)
self.is_chunked_prefill_graph = neuron_config.prefill_chunk_size > 0
I'd rather not use a flag set at init. All other graphs are identified when relevant based on the neuron_config or input shapes
    sampling_params (torch.FloatTensor): Sampling parameters.
"""
is_for_context_encoding = self._is_context_encoding(input_ids)
is_for_speculation = self._is_for_speculation(input_ids)
Here you should insert a flag for chunked_prefill. Ideally it should not be incompatible with is_for_context_encoding
hidden_size = hidden_states.shape[-1]
- if is_for_context_encoding:
+ if is_for_context_encoding or self.is_chunked_prefill_graph:
I'd rather have is_for_context_encoding be true for chunked prefill also
# (see kv_cache_manager.update_cache), so max_batch_size is preserved
# for KV cache sizing while the forward-pass batch dim is just 1.
cp_config = copy.deepcopy(neuron_config)
cp_config.batch_size = neuron_config.effective_prefill_batch_size
I'd rather reuse here ctx_batch_size
# repeated position is a true no-op (avoids non-deterministic scatter
# behaviour with duplicate indices corrupting the cache).
is_real = torch.ones(1, chunk_size, dtype=torch.bool, device=device)
is_real[:, 1:] = position_ids[:1, 1:] > position_ids[:1, :-1]
Doesn't it assume that we are using a batch_size of 1 here? I would have expected something like:

is_real[:, 1:] = position_ids[:, 1:] > position_ids[:, :-1]

Also add a comment saying that we attend to tokens whose position_ids are still increasing. Rename is_real to actual_tokens_mask. I am also wondering if this is fully consistent because of the view you need to apply afterwards.
# limitations under the License.
"""Tests verifying that chunked prefill produces equivalent output to standard context encoding."""

import os
I'd rather use the export fixture here, adding two more configs for llama
# Then process new prompt requests.
if n_prompt_reqs > 0:
    (requests, input_ids, position_ids, seq_ids, sampling_params) = self._prepare_prompt(scheduler_output)
    chunk_size = self.model.model.neuron_config.prefill_chunk_size
I would like to have the chunked_prefill reuse more code from the standard prefill path
tests/fixtures/llm/export_models.py
_get_neuron_model_for_config(config_name, model_config, neuron_model_path)
progress.advance(task_id)
progress.update(task_id, description="[green]All models exported")
CHUNKED_PREFILL_MODEL_ID = "unsloth/Llama-3.2-1B-Instruct"
I don't want an extra fixture here. You should reuse the existing fixture and extend it so that it supports chunked prefill. It could for instance be parameterized as <batch_size>x<num_chunks>x<chunk_size> so that std models are represented as <batch_size>x1x<sequence_length>
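One way to read the proposed naming scheme (my interpretation of the reviewer's suggestion; the mapping of sequence_length to num_chunks * chunk_size is an assumption, as is the parser name):

```python
def parse_llm_config_name(name: str) -> dict:
    # "<batch_size>x<num_chunks>x<chunk_size>"; standard models are encoded
    # as "<batch_size>x1x<sequence_length>" (a single full-length "chunk").
    batch_size, num_chunks, chunk_size = (int(part) for part in name.split("x"))
    return {
        "batch_size": batch_size,
        "sequence_length": num_chunks * chunk_size,  # assumed relationship
        "prefill_chunk_size": chunk_size if num_chunks > 1 else 0,
    }

assert parse_llm_config_name("32x1x4096") == {
    "batch_size": 32, "sequence_length": 4096, "prefill_chunk_size": 0,
}
assert parse_llm_config_name("2x8x512") == {
    "batch_size": 2, "sequence_length": 4096, "prefill_chunk_size": 512,
}
```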
from vllm import LLM, SamplingParams

# Must match CHUNKED_PREFILL_CHUNK_SIZE in tests/fixtures/llm/export_models.py
Don't do that: sharing implicit knowledge is a bad practice. Reuse the value from the fixture config.
- Remove `effective_prefill_batch_size` property, use `ctx_batch_size`
- Replace init-time `is_chunked_prefill_graph` flag with input-shape detection via `_is_chunked_prefill(input_ids)`
- Keep `is_for_context_encoding=True` for chunked prefill
- Rename `is_real` to `actual_tokens_mask`, fix position_ids indexing to handle all batch elements
- Move chunked prefill scatter under CE branch in KV cache manager
- Extract `_flat_scatter_update` helper method

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Extract `_pad_to_max_context_length` and `_pad_to_chunk_size` methods in decoder_wrappers.py
- Refactor `get_next_tokens` into `sample_next_tokens`, a pure sampling closure with forward calls made explicit in each execution mode (standard CE, chunked prefill, decode)
- Handle hybrid ODS: initialize `fused_logits_warper` when chunked prefill is used, pass `is_ods=False` for chunked prefill outputs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Extract shared request collection logic into `_collect_new_requests` helper, called by both `_prepare_prompt` and `_execute_chunked_prefill`. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add chunked prefill configs to GENERATE_LLM_MODEL_CONFIGURATIONS
- Remove chunked_prefill_llm_config fixture
- Decoder tests: use neuron_llm_config with indirect=True, compare chunked model against HF model with real prompts
- vLLM engine tests: use neuron_llm_config with indirect=True, test short/long/batch generation and batch consistency
- vLLM service test: use neuron_llm_config with indirect=True, verify service startup and generation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pull request overview
This PR introduces chunked-prefill support across Optimum Neuron’s decoder backend and vLLM integration, along with new tests and export/benchmark utilities intended to validate and measure the behavior.
Changes:

- Add chunked-prefill graph support (new config knob `prefill_chunk_size`) and route prompt prefill through chunk-by-chunk execution in both the Neuron generation path and vLLM runner.
- Update attention masking/softmax to avoid bf16 `exp()` NaNs and support masked softmax for chunked/speculation-style masking.
- Add new decoder + vLLM engine/service tests and extend test model export fixtures with chunked-prefill configurations.
Reviewed changes
Copilot reviewed 38 out of 42 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| optimum/neuron/vllm/runner.py | Executes prompt prefill via chunked-prefill path when enabled; aligns token tensor shapes for ODS vs CPU sampling. |
| optimum/neuron/vllm/platform.py | Adjusts scheduler token budget to avoid >2048-token prompt deadlocks on Neuron platform. |
| optimum/neuron/vllm/model_loader.py | Adds a vLLM-facing prefill_chunk_vllm helper for chunk-by-chunk prefill logits. |
| optimum/neuron/models/inference/backend/config.py | Introduces prefill_chunk_size and enforces incompatibility with speculation. |
| optimum/neuron/models/inference/backend/modules/decoder/* | Adds chunked-prefill wrapper/tag, chunked graph builder, chunked KV scatter update path, and prefill_chunk() API. |
| optimum/neuron/models/inference/backend/modules/generation/generation_utils.py | Implements chunked prefill loop and "hybrid ODS" sampling behavior. |
| optimum/neuron/models/inference/backend/modules/attention/* | Adds masked manual_softmax and replaces finfo(min) masking with -1e9 to prevent NaNs. |
| optimum/exporters/neuron/__main__.py, optimum/commands/export/neuronx.py | Plumbs prefill_chunk_size through export CLI and exporter entrypoint. |
| tests/fixtures/llm/export_models.py | Adds std vs chunked model configs and improves export script UX. |
| tests/decoder/*, tests/vllm/* | Adds/adjusts tests for chunked prefill and Neuron subprocess isolation for Neuron runtime state. |
| benchmark/vllm/* | Adds benchmark results and improves benchmark runner script. |
| AGENTS.md, optimum/neuron/version.py | Documentation update and dev version bump. |
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Engine-level tests verifying that a chunked-prefill model generates correct
output through the vLLM engine.

Each test exercises a distinct code path relative to the model's chunk_size:

short — chunk_size // 2 tokens → padding path in _execute_chunked_prefill
long — chunk_size * 2 tokens → KV accumulation across two complete chunks
batch — [short, long] in one call → verifies consistency with individual runs
"""

from typing import Any

import pytest

# Do not collect tests from this file if vllm is not installed
pytest.importorskip("vllm")

from vllm import LLM, SamplingParams
Ruff will likely flag E402 here because `from vllm import ...` appears after the `pytest.importorskip("vllm")` side-effect statement. Other vLLM engine tests silence this with a file-level `# ruff: noqa: E402` header; please add the same (or otherwise restructure to satisfy E402 while keeping the importorskip behavior).
What does this PR do?
Base pull request to review Claude Code output without triggering the CI.