Add TensorRT-LLM support for inference backend #121
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ac2cd6db34
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
    assert pp_size == 1, f"trtllm_pp_size must be 1, got {pp_size}"

    if self.base_gpu_id is not None:
        self.local_gpu_id = self.setup_gpu(self.base_gpu_id)
Set CUDA visibility before initializing CUDA
When inference_num_gpus_per_engine > 1 (including the added sample config), this calls setup_gpu(), which runs torch.cuda.set_device() while Ray still exposes only the actor's single scheduled GPU. _init_engine() expands CUDA_VISIBLE_DEVICES later, but CUDA has already been initialized, so TensorRT-LLM's multi-GPU constructor/device-count check can still see only one GPU and fail or start TP workers without the full assigned set. Set the engine's full GPU visibility before touching torch.cuda.
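A minimal sketch of the ordering this comment asks for, assuming the engine owns a contiguous block of GPUs starting at base_gpu_id. The helper name `expose_engine_gpus` is hypothetical and not part of the PR; the only point is that CUDA_VISIBLE_DEVICES is written before anything calls into torch.cuda.

```python
import os


def expose_engine_gpus(base_gpu_id: int, num_gpus: int, env=os.environ) -> str:
    """Hypothetical helper: expose the engine's full contiguous GPU block.

    Must run before the first CUDA call in the process, because CUDA reads
    CUDA_VISIBLE_DEVICES when it initializes and ignores later changes.
    """
    visible = ",".join(str(base_gpu_id + i) for i in range(num_gpus))
    env["CUDA_VISIBLE_DEVICES"] = visible
    return visible


# Illustrative call order inside an actor's init (names are assumptions):
#   expose_engine_gpus(base_gpu_id=2, num_gpus=2)   # 1. widen visibility
#   torch.cuda.set_device(0)                        # 2. only now touch CUDA
```

Passing `env` explicitly keeps the helper easy to check without mutating the real process environment.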
    self._engine = LLM(
        model=self.args.target_model_path,
        backend="pytorch",
        tensor_parallel_size=tp_size,
        trust_remote_code=getattr(self.args, "trust_remote_code", True),
        speculative_config=spec_config,
        **engine_kwargs,
Protect trust_remote_code from extra_args
If a user puts trust_remote_code in inference.trtllm.extra_args (a normal LLM kwarg and common when porting backend configs), it is not removed by _PROTECTED_ENGINE_KEYS, so this call passes trust_remote_code both explicitly and via **engine_kwargs and raises TypeError: got multiple values for keyword argument 'trust_remote_code' during engine initialization. Either add it to the protected set or move the explicit value into engine_kwargs before applying extras.
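The collision and the first suggested fix can be reproduced without TensorRT-LLM. The sketch below uses a stand-in constructor (`fake_llm`) and an illustrative protected set; neither is the PR's exact code.

```python
# Keys the engine passes explicitly; this set is illustrative, not the PR's exact list.
PROTECTED_ENGINE_KEYS = {"model", "backend", "tensor_parallel_size", "trust_remote_code"}


def filter_extra_args(extra_args: dict) -> dict:
    """Drop user-supplied kwargs that would collide with explicit ones."""
    return {k: v for k, v in extra_args.items() if k not in PROTECTED_ENGINE_KEYS}


def fake_llm(model, trust_remote_code=True, **kwargs):
    """Stand-in for the real constructor, used only to show the collision."""
    return {"model": model, "trust_remote_code": trust_remote_code, **kwargs}


extra = {"trust_remote_code": False, "max_batch_size": 8}

try:
    fake_llm(model="m", trust_remote_code=True, **extra)  # duplicate keyword
except TypeError as exc:
    print(f"unfiltered: {exc}")

# With the filter applied, the explicit config-sourced value wins.
engine = fake_llm(model="m", trust_remote_code=True, **filter_extra_args(extra))
```

Filtering rather than merging means a value in extra_args can never silently override the config-sourced one.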
d8bc0d5 to 9342b09
23a9ed1 to 38bb53c
    echo "Extra args: $*"
    echo "=============================================="

    # TODO: unify tp_size config across sglang/vllm backends
Previously, vLLM hardcoded this config in its YAML (configs/vllm_qwen3_8b.yaml).
    **vLLM**

    ```bash
    ./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml
There is no --config option in run.sh. See the run.sh docstring.
    TorchSpec explicitly uses are listed; any other ``LLM`` kwarg can be passed
    via ``extra_args``.

    Single-node tensor parallelism only (nnodes must be 1); multi-node TP is
💡 Codex Review
Reviewed commit: 38bb53c003
    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.llmapi import KvCacheConfig, SaveHiddenStatesDecodingConfig
Defer TensorRT-LLM imports until backend selection
When tensorrt-llm is installed in an environment where the NVIDIA driver/libcuda is not available, importing this module at package import time can raise a non-ImportError load failure; torchspec.inference.engine.__init__ imports it opportunistically for every backend and only catches ImportError, so unrelated HF/vLLM/SGLang imports can fail before TRT-LLM is selected. Move these imports into TrtllmEngine initialization or catch the CUDA loader failure in the optional import path.
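Both options in this comment can be sketched with the standard library alone. The names `load_optional_backend` and `LazyEngine` are hypothetical; the sketch assumes a loader failure surfaces as some Exception subclass other than ImportError (for example an OSError from a missing shared library).

```python
import importlib


def load_optional_backend(module_name: str):
    """Import an optional backend lazily; return None if it cannot load.

    Catching Exception (not just ImportError) also covers loader failures
    raised while the package's __init__ runs.
    """
    try:
        return importlib.import_module(module_name)
    except Exception as exc:  # deliberately broad for optional backends
        print(f"{module_name} unavailable: {type(exc).__name__}")
        return None


class LazyEngine:
    """Hypothetical engine that imports its backend only when constructed."""

    def __init__(self, module_name: str):
        self.backend = load_optional_backend(module_name)
        if self.backend is None:
            raise RuntimeError(f"backend {module_name!r} was selected but failed to load")
```

With the import deferred to construction time, a broken optional backend only fails the code path that actually selects it.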
    fi

    # Derive the tp_size override block from the config's engine type ("sgl" -> "sglang").
    ENGINE_TYPE=$(grep -oE "inference_engine_type:[[:space:]]*[a-zA-Z]+" "$CONFIG_FILE" | awk '{print $2}')
Make engine-type probing non-fatal
With set -euo pipefail, any custom config that omits a literal unquoted inference_engine_type in the YAML, or supplies it via the extra CLI args, makes grep return 1 and exits the launcher before the ${ENGINE_TYPE:-sglang} fallback can run. This regresses the script's [CONFIG_FILE] [EXTRA_ARGS...] path for otherwise usable configs; make the probe tolerate no match or derive the value from the resolved config.
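The launcher itself is a shell script, so the following is only a Python rendering of the tolerant probe described here: a missing match is treated as "use the default" rather than as an error. The function name `tp_size_block` is hypothetical.

```python
import re

# "sgl" selects the "sglang" block; other engine types map to themselves (per the PR).
BLOCK_ALIASES = {"sgl": "sglang"}


def tp_size_block(config_text: str, default: str = "sglang") -> str:
    """Return the config block name for the tp_size override.

    A missing inference_engine_type is not an error: fall back to default.
    """
    match = re.search(r"inference_engine_type:[ \t]*([a-zA-Z]+)", config_text)
    engine_type = match.group(1) if match else default
    return BLOCK_ALIASES.get(engine_type, engine_type)
```

Like the grep pattern, this only sees a literal unquoted value in the file; deriving the value from the resolved config would also cover values supplied through extra CLI args.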
    + key=key,
    + hidden_states=aux_hidden_states,
    + input_ids=input_ids,
    + last_hidden_states=last_hidden_states,
Honor disabled last-hidden-state storage
When inference.store_last_hidden_states=false, TrtllmEngine._get_tensor_shapes() omits last_hidden_states, so the data fetcher later removes only _hs and _ids; however this patched resource manager still always passes last_hidden_states to store.put() here. In that configuration every TRT-LLM sample leaves its _lhs object behind in Mooncake and long runs can fill the segment, so either skip storing last_hidden_states in this mode or include it in the returned metadata.
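A minimal sketch of the first option (skip the write), using a plain dict in place of the Mooncake store. The function name `put_sample` and its signature are assumptions for illustration; only the _hs/_ids/_lhs suffixes come from the comment above.

```python
def put_sample(store: dict, key: str, hidden_states, input_ids,
               last_hidden_states, store_last_hidden_states: bool) -> list:
    """Write one sample and return the keys written.

    `store` is a plain dict standing in for the Mooncake store.
    """
    written = {f"{key}_hs": hidden_states, f"{key}_ids": input_ids}
    if store_last_hidden_states:
        # Only write what the fetcher will later fetch and delete.
        written[f"{key}_lhs"] = last_hidden_states
    store.update(written)
    return list(written)
```

Returning the written keys makes it straightforward to check that the writer and the fetcher agree on which objects exist for a sample.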
Signed-off-by: chungen04 <b09901027@ntu.edu.tw>
tensorrt_llm's package init loads libcuda.so.1, which is absent during the image build. Apply the patch in the fixed python3.12 dist-packages path and verify it landed by grepping the patched file instead of importing the module. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <b09901027@ntu.edu.tw>
TRT-LLM's MpiPoolSession._start_mpi_pool only forwards env vars prefixed
TRTLLM/TLLM to its spawned MPI workers. The SaveHiddenStates Mooncake
redirect gates on TORCHSPEC_TRTLLM_MOONCAKE and reads the MOONCAKE_*
connection params, none of which reached the workers, so hidden-state
capture silently fell back to disk mode and the trainer timed out waiting
for keys ("batch_get_buffer missing keys").
Add mpi_session.patch forwarding TORCHSPEC*/MOONCAKE*/MC_* prefixes
(applied by the existing patch loop) plus a build-time validation grep.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <b09901027@ntu.edu.tw>
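The prefix filter this commit describes can be sketched as follows. This is not the patched TRT-LLM code; `env_for_workers` is a hypothetical name, and the variable names in the usage note are made up for illustration.

```python
# Prefixes from the commit message: the two upstream ones plus the three added.
FORWARDED_PREFIXES = ("TRTLLM", "TLLM", "TORCHSPEC", "MOONCAKE", "MC_")


def env_for_workers(env: dict, prefixes: tuple = FORWARDED_PREFIXES) -> dict:
    """Select the environment variables to forward to spawned workers."""
    return {k: v for k, v in env.items() if k.startswith(prefixes)}
```

For example, given `{"TORCHSPEC_TRTLLM_MOONCAKE": "1", "MC_EXAMPLE": "x", "PATH": "/bin"}`, the first two entries are forwarded and PATH is not.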
The patch stored hidden states under request.py_request_id, the backend runtime request id which is bumped by JIT-warmup and internal requests. But the TorchSpec engine reconstructs Mooncake keys from RequestOutput.request_id == GenerationRequest.id, which TRT-LLM sets to the client_id assigned by the proxy (_get_next_client_id). The two id spaces are mapped via _client_id_to_request_id and differ, so the stored keys were systematically offset from the keys the trainer requested, surfacing as "Size mismatch for hidden_states" (data shifted between keys). Key on request.py_client_id, falling back to py_request_id if unset. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <b09901027@ntu.edu.tw>
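The fallback described above amounts to the following, sketched against a stand-in request object. The attribute names py_client_id and py_request_id come from the commit message; `storage_request_id` is a hypothetical helper name.

```python
from types import SimpleNamespace


def storage_request_id(request) -> int:
    """Prefer the client-facing id; fall back to the runtime id if unset."""
    client_id = getattr(request, "py_client_id", None)
    # Compare against None explicitly so a client id of 0 is still honored.
    return client_id if client_id is not None else request.py_request_id


# After warmup the runtime id has been bumped, but the client id has not.
request = SimpleNamespace(py_client_id=0, py_request_id=5)
key_id = storage_request_id(request)
```

The explicit None check matters here: a truthiness test would discard the first client id if ids start at 0.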
SaveHiddenStates only captures hidden states for tokens that actually run a forward pass. With enable_block_reuse on (the TRT-LLM default), prompts that share a prefix reuse cached KV blocks and skip the forward for the shared tokens, so fewer tokens are captured than the prompt length. Always construct KvCacheConfig with enable_block_reuse=False and enable_partial_reuse=False (previously the config was only built when a mem fraction was set, leaving reuse on otherwise). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <b09901027@ntu.edu.tw>
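A sketch of the "always construct" change, written as a kwargs builder so it runs without tensorrt_llm installed. `kv_cache_kwargs` is a hypothetical helper, and mapping the "mem fraction" setting to free_gpu_memory_fraction is an assumption about which KvCacheConfig field the engine uses.

```python
def kv_cache_kwargs(mem_fraction=None) -> dict:
    """Build KvCacheConfig keyword arguments with reuse always disabled."""
    kwargs = {"enable_block_reuse": False, "enable_partial_reuse": False}
    if mem_fraction is not None:
        # Assumed field name for the memory-fraction setting.
        kwargs["free_gpu_memory_fraction"] = mem_fraction
    return kwargs


# Intended use, assuming tensorrt_llm is installed:
#   from tensorrt_llm.llmapi import KvCacheConfig
#   kv_cache_config = KvCacheConfig(**kv_cache_kwargs(0.8))
```

The reuse flags are set unconditionally, so leaving the memory fraction unset no longer leaves block reuse on.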
SaveHiddenStates runs each prompt as a single prefill (chunked prefill is
disabled), so a prompt longer than max_num_tokens is rejected outright
("sum of prompt length ... should not exceed max_num_tokens") and the
sample is dropped. With max_num_tokens=8192 but max_seq_length=16384, samples
in that range were silently lost during training. Raise to 16384.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <b09901027@ntu.edu.tw>
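One way to catch this class of misconfiguration early is a startup check like the one below. It is not part of the PR; `check_prefill_budget` is a hypothetical helper that encodes the constraint stated in the commit message.

```python
def check_prefill_budget(max_num_tokens: int, max_seq_length: int) -> None:
    """Raise if a full-length prompt could not fit in one prefill."""
    if max_num_tokens < max_seq_length:
        raise ValueError(
            f"max_num_tokens={max_num_tokens} < max_seq_length={max_seq_length}: "
            "prompts in between would be rejected and the sample dropped"
        )


check_prefill_budget(max_num_tokens=16384, max_seq_length=16384)  # passes
```

With the earlier values (8192 and 16384) the same call raises instead of silently losing samples during training.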
Collapse a NotImplementedError call onto one line (fits in the 100-char limit) so `ruff format --check` passes. No behavior change. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <b09901027@ntu.edu.tw>
Merge the standalone mpi_session.patch (forwarding TORCHSPEC*/MOONCAKE*/MC_* env to MPI workers) into the main trtllm.patch and drop the separate file. Also relax the upstream max_batch_size=1 cap that TRT-LLM forces whenever SaveHiddenStatesDecodingConfig is active: when TORCHSPEC_TRTLLM_MOONCAKE is set, keep the configured max_batch_size so batched prefill capture works. The Mooncake resource manager de-interleaves the packed multi-request buffer, so the offset-0 single-request constraint no longer applies. Overlap scheduler and CUDA graphs stay disabled (shared capture buffer is overwritten every forward and must be read first). Signed-off-by: chungen04 <b09901027@ntu.edu.tw>
Derive the per-backend config block from inference_engine_type instead of hardcoding inference.sglang.tp_size. "sgl" maps to the "sglang" block; vllm/trtllm map 1:1. The same launcher now drives sglang, vllm, and trtllm with an identical inference layout for fair side-by-side runs. Signed-off-by: chungen04 <b09901027@ntu.edu.tw>
Switch the inference layout to 1 GPU / tp=1, drop the eval dataset, shrink the Mooncake global_segment_size to 16GB, and namespace cache_dir per example. Point the usage comment at the unified run.sh launcher. Signed-off-by: chungen04 <b09901027@ntu.edu.tw>
Adopt the positional run.sh launch form in the README (vLLM + TRT-LLM examples) and drop the patch-verify grep steps from the trtllm Dockerfile, matching origin/chungen/trtllm. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <b09901027@ntu.edu.tw>
SaveHiddenStates' Mooncake path sliced per-request hidden states in py_batch_idx (= py_seq_slot, an arbitrary KV slot) order, but the forward packs tokens in scheduled_requests.context_requests native order. Under multi-request batching the orders differ, so per-request slices read the wrong hidden states and corrupt a subset of captured samples. Iterate context_requests in native order so the slices align. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
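The misalignment is easy to reproduce with plain lists. The sketch below is not the patch; it assumes only that the forward pass packs tokens request by request, and `split_packed` is a hypothetical name.

```python
def split_packed(packed: list, requests: list) -> dict:
    """Slice a packed token buffer back into per-request chunks.

    `requests` is a list of (request_id, num_tokens) pairs; offsets are
    cumulative over that order, so it must match the packing order.
    """
    out, offset = {}, 0
    for request_id, num_tokens in requests:
        out[request_id] = packed[offset:offset + num_tokens]
        offset += num_tokens
    if offset != len(packed):
        raise ValueError(f"packed buffer has {len(packed)} tokens, expected {offset}")
    return out


# Two requests packed in native order: "a" (2 tokens) then "b" (3 tokens).
packed = ["a0", "a1", "b0", "b1", "b2"]
native = [("a", 2), ("b", 3)]
by_slot = [("b", 3), ("a", 2)]  # a different (slot) order misattributes tokens

good = split_packed(packed, native)
bad = split_packed(packed, by_slot)
```

In `bad`, request "b" receives two of "a"'s tokens even though the total length still adds up, which is why the corruption only shows in a subset of samples and not as an outright failure.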
TRT-LLM's MPI tp-workers bind GPU ordinals 0..tp_size-1 regardless of the Ray placement group, so under the default training_first they collide with training on the low GPUs and OOM. inference_first assigns inference those low indices. SGLang honors base_gpu_id and is unaffected, so the override is TRT-config-only. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
trust_remote_code is passed explicitly to LLM() but was missing from _PROTECTED_ENGINE_KEYS. A user setting it in inference.trtllm.extra_args (a standard LLM kwarg) would have it survive the filter and be passed twice, raising "TypeError: got multiple values for keyword argument 'trust_remote_code'" at engine init. Add it to the protected set so the explicit config-sourced value wins, matching model/backend/tp_size. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
Tighten TRT-only comments to the project's lean style and align with the sglang reference where applicable. No functional change. - Dockerfile: drop the known-issue CUDA-13 wheel comment; condense the patch-step comment to the one non-obvious gotcha (don't import tensorrt_llm at build time -- libcuda absent until runtime). - run.sh: collapse the tp_size-block derivation comment to one line. - inference_config.py: drop the redundant init_timeout comment (sglang/ vllm declare the same field without one). - trtllm_qwen3_8b.yaml: add the commented eval_data_path/eval_interval placeholders to match the sglang config; consolidate the max_num_tokens note to one line. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
When store_last_hidden_states=false (e.g. the DFlash configs), the engine
omits last_hidden_states from the returned tensor metadata, so the fetcher
never fetches or deletes the {key}_lhs object -- but the patched resource
manager still always wrote it, orphaning one _lhs per sample in Mooncake
and filling the segment over long runs. Gate storage on a new
TORCHSPEC_TRTLLM_STORE_LAST_HIDDEN env (exported by the engine before LLM
construction, propagated to MPI workers) so _lhs is written only when the
flag is set. Default path (true) is unchanged.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <cho322@gatech.edu>
tensorrt_llm's package __init__ loads libcuda.so.1 on import. With the LLM/SamplingParams/KvCacheConfig/SaveHiddenStatesDecodingConfig imports at module scope, importing torchspec.inference.engine on a host without a CUDA driver raised a non-ImportError loader failure; engine/__init__.py only catches ImportError, so the whole engine package (HF/vLLM/SGLang included) failed before TRT-LLM was ever selected. Move the imports into the methods that use them, matching vllm_engine.py's lazy `from vllm import LLM`. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
The grep that derives the tp_size override block returns 1 when a config
has no literal inference_engine_type line (e.g. set via CLI extra args);
under set -euo pipefail that exited the launcher before the
${ENGINE_TYPE:-sglang} fallback could run, regressing the
[CONFIG_FILE] [EXTRA_ARGS...] path. Append `|| true` so a no-match yields
empty and falls back.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <cho322@gatech.edu>
…ent + per-engine keys)
Enables N independent single-GPU (tp=1) TRT-LLM inference engines on distinct
GPUs, needed for data-parallel inference (e.g. DFlash's 4 tp=1 replicas).
Two co-dependent pieces:
- GPU placement: single-GPU engines drop the NOSET override so Ray scopes
CUDA_VISIBLE_DEVICES per actor to its own GPU; the engine no longer manually
pins CVD for tp=1 (TRT's MPI worker maps cuda:0 to the scoped GPU). CVD is
set before any CUDA/tensorrt_llm init (also fixes premature-CUDA-init). TP>1
engines keep the prior NOSET + manual contiguous-block pin.
- Mooncake keys: each engine has its own request_id counter, so keys collided
in the shared store across data-parallel engines. Prefix keys per engine
(e{rank}_) on both the write side (patch) and the read side (engine).
Without these, 4 tp=1 engines all bound GPU 0 (OOM) and clobbered each other's
Mooncake keys (shape mismatch). Validated: DFlash trtllm 10k steps clean.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <cho322@gatech.edu>
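The key collision and the per-engine prefix can be illustrated in a few lines. Only the e{rank}_ prefix comes from the commit message; the rest of the key format and the name `mooncake_key` are assumptions.

```python
def mooncake_key(engine_rank: int, request_id: int) -> str:
    """Build a store key that is unique across data-parallel engines."""
    return f"e{engine_rank}_{request_id}"


# Each engine has its own counter, so both start at request id 0.
unprefixed = {str(0) for _rank in range(4)}
prefixed = {mooncake_key(rank, 0) for rank in range(4)}
```

Without the prefix the four engines share a single key; with it each engine writes under its own. The write side and the read side have to build the key the same way.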
Relax the DFlash engine guard in train_entry.py from sgl-only to
{'sgl', 'trtllm'}, and add configs/trtllm_qwen3_8b_dflash.yaml (mirrors the
sglang DFlash config; trtllm inference block + inference_first placement).
DFlash's aux-layer wiring and the trainer are engine-agnostic, so no other
changes are needed for the trtllm path.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <cho322@gatech.edu>
dba14e9 to dde1128
Drop the stale sgl-only guard NOTE from the trtllm dflash config (the guard now permits trtllm), remove a redundant CVD-scoping comment in trtllm_engine, and rewrite the factory GPU-assignment comment to explain the tp==1 vs tp>1 split without the dangling "Approach 1" reference that assumed prior context. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
💡 Codex Review
Reviewed commit: d29cd75573
    + if store is None:
    +     return
Fail when Mooncake storage is unavailable
When the patched resource manager cannot create a Mooncake store (for example, bad Mooncake env, an incompatible wheel, or a setup failure), these lines return without storing any tensors, but TrtllmEngine.generate() still reconstructs and returns Mooncake keys from the request id. The training fetcher later tries to read missing _hs/_ids/_lhs objects, so the actual failure is delayed until after bogus samples have been queued; surface the storage error here instead of returning success.
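A sketch of what "surface the storage error" could look like, assuming the store handle is None when setup failed. `require_store` is a hypothetical helper, not code from the patch.

```python
def require_store(store):
    """Return the store, or raise if it could not be created."""
    if store is None:
        raise RuntimeError(
            "Mooncake store is unavailable; refusing to drop hidden states silently"
        )
    return store
```

Raising at the write site puts the failure next to its cause instead of at the later fetch of keys that were never stored.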
Mooncake is required for this framework.
    # Installation:
    # Use the TensorRT-LLM docker image (docker/trtllm/v1.3.0rc18/Dockerfile),
    # which ships tensorrt_llm patched for Mooncake hidden-state capture.
    # For a local install: pip install -e ".[trtllm]"
Do not document an unpatched local install
When users follow this local install path, pip install -e ".[trtllm]" only installs the TensorRT-LLM dependency; the Mooncake redirect patch under patches/trtllm/v1.3.0rc18/ is only applied by the Dockerfile. With an unpatched local TensorRT-LLM, TORCHSPEC_TRTLLM_MOONCAKE is ignored and hidden states are written to disk while TrtllmEngine returns Mooncake keys that were never stored, so either make the local install apply the patch or remove this instruction.
Signed-off-by: chungen04 <b09901027@ntu.edu.tw>
@yubofredwang This PR integrates the TensorRT-LLM engine as an inference backend. Please review; I look forward to suggestions, split-PR plans, etc. I have not included tests yet. Please let me know whether you would prefer tests in this PR or in a separate one. Thank you!
thanks for the great work! I will run it end to end to verify |
Summary
Integrates TensorRT-LLM (v1.3.0rc18) as an inference backend for hidden-state capture, alongside the existing vLLM and SGLang backends. Validated end-to-end against SGLang on Qwen3-8B (EAGLE-3 / DFlash).
Approach
TRT-LLM already ships the EAGLE3 capture machinery natively (SaveHiddenStatesDecodingConfig / SaveHiddenStatesResourceManager), so this integration is thin. It builds on that mode and adds a small patch that redirects captured aux + final hidden states to Mooncake instead of writing .pt files to disk.
Changes
- patches/trtllm/v1.3.0rc18/trtllm.patch:
  - Env-gated (TORCHSPEC_TRTLLM_MOONCAKE) Mooncake redirect in save_hidden_state.py.
  - Iterates context_requests in native order to match how the executor flattens tokens.
  - Relaxes the max_batch_size=1 cap for hidden state capture.
- torchspec/inference/engine/trtllm_engine.py:
  - TrtllmEngine: Ray actor wrapping the PyTorch-backend LLM.
  - Maps RequestOutput.request_id to a Mooncake key using the same sanitizer as the patch.
- factory.py / config: "trtllm" dispatch, TrtllmConfig, trtllm_ flatten prefix.
- configs/trtllm_qwen3_8b.yaml.
- docker/trtllm/v1.3.0rc18/Dockerfile + justfile branch + [trtllm] extra.
- configs/trtllm_qwen3_8b.yaml + README entry + qwen3-8b-single-node example wired to the shared launcher.
Performance
Setting: Qwen3-8B, single node (4xB300, 2 inference (TP=2) + 2 training FSDP), open-perfectblend dataset (not re-generated), global batch 8, 10k steps, 512 samples split for eval. Each backend evaluates on its own captured eval cache.
Quality is verified to be identical (eval accuracy matches at every checkpoint):
For E2E Speed, TRT-LLM is ~3.6× faster (training-loop wall-clock):
SGLang is inference-bound. TRT-LLM speeds up inference and becomes training-bound under this setting.
A DFlash example is also added, mirroring SGLang's example (for SGLang's training, see #126 for a hang I encountered). The training results show that quality remains equivalent and wall-clock time is also similar, since the pipeline was training-bound with both engines.
E2E speed for DFlash:
Training results for DFlash:
Test
Note
Multimodal input is not yet supported. I plan to split the multimodal input support and the base TRTLLM support into two PRs.