From 76ccb7928fb65f28dc7032fce7244bdb25912684 Mon Sep 17 00:00:00 2001 From: chungen04 Date: Mon, 15 Jun 2026 23:39:17 +0000 Subject: [PATCH 01/28] Patches for trt Signed-off-by: chungen04 --- patches/trtllm/v1.3.0rc18/trtllm.patch | 166 +++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 patches/trtllm/v1.3.0rc18/trtllm.patch diff --git a/patches/trtllm/v1.3.0rc18/trtllm.patch b/patches/trtllm/v1.3.0rc18/trtllm.patch new file mode 100644 index 0000000..be5f682 --- /dev/null +++ b/patches/trtllm/v1.3.0rc18/trtllm.patch @@ -0,0 +1,166 @@ +diff --git a/tensorrt_llm/_torch/speculative/save_hidden_state.py b/tensorrt_llm/_torch/speculative/save_hidden_state.py +index 05cf2610..b8e8145a 100644 +--- a/tensorrt_llm/_torch/speculative/save_hidden_state.py ++++ b/tensorrt_llm/_torch/speculative/save_hidden_state.py +@@ -1,10 +1,12 @@ + import os ++import re + from dataclasses import dataclass + from typing import TYPE_CHECKING, List, Optional, Set + + import torch + + from tensorrt_llm._utils import local_mpi_rank ++from tensorrt_llm.logger import logger + + from ..pyexecutor.llm_request import LlmRequest + from ..pyexecutor.resource_manager import BaseResourceManager +@@ -15,6 +17,25 @@ if TYPE_CHECKING: + from ...llmapi.llm_args import SaveHiddenStatesDecodingConfig + + ++# TorchSpec integration: when this env var is set, SaveHiddenStates mode writes ++# captured hidden states directly to a Mooncake store (keyed by request id) ++# instead of accumulating .pt files on disk. This mirrors the sglang/vllm ++# TorchSpec backends so the EAGLE3 draft trainer can pull tensors over RDMA. ++_TORCHSPEC_MOONCAKE_ENV = "TORCHSPEC_TRTLLM_MOONCAKE" ++ ++ ++def _sanitize_mooncake_key(key: str) -> str: ++ """Match the key sanitization the TorchSpec engine uses to reconstruct keys. ++ ++ Keys must be reversible from ``RequestOutput.request_id`` on the engine ++ side, so this MUST stay in sync with ``TrtllmEngine``. ++ """ ++ sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", key) ++ if sanitized and sanitized[0].isdigit(): ++ sanitized = "k" + sanitized ++ return sanitized ++ ++ + class SaveHiddenStatesResourceManager(BaseResourceManager): + """ + Resource manager for SaveHiddenStates mode. +@@ -45,7 +66,13 @@ class SaveHiddenStatesResourceManager(BaseResourceManager): + dtype=dtype, + device='cuda') + +- os.makedirs(self._output_directory, exist_ok=True) ++ # TorchSpec Mooncake redirect (opt-in via env var). When enabled we do ++ # not touch the output directory; tensors stream to Mooncake instead. ++ self._mooncake_mode = bool(os.environ.get(_TORCHSPEC_MOONCAKE_ENV)) ++ self._mooncake_store = None ++ self._mooncake_setup_done = False ++ if not self._mooncake_mode: ++ os.makedirs(self._output_directory, exist_ok=True) + + def prepare_resources(self, scheduled_batch: ScheduledRequests): + for req in scheduled_batch.all_requests(): +@@ -73,6 +100,9 @@ class SaveHiddenStatesResourceManager(BaseResourceManager): + scheduled_requests: The scheduled requests for this iteration + spec_metadata: The spec metadata containing layers_to_capture info + """ ++ if self._mooncake_mode: ++ self._process_and_store_mooncake(scheduled_requests, spec_metadata) ++ return + for request in sorted( + scheduled_requests.context_requests, + key=lambda r: +@@ -83,6 +113,97 @@ class SaveHiddenStatesResourceManager(BaseResourceManager): + self._write_to_file() + self._iter += 1 + ++ def _ensure_mooncake_store(self): ++ """Lazily create the TorchSpec EagleMooncakeStore on local rank 0. ++ ++ Returns the store, or None if unavailable (callers then skip storing). ++ Mooncake connection params come from env vars exported by the ++ TorchSpec engine before the LLM is constructed. ++ """ ++ if self._mooncake_setup_done: ++ return self._mooncake_store ++ self._mooncake_setup_done = True ++ try: ++ from torchspec.config.mooncake_config import MooncakeConfig ++ from torchspec.transfer.mooncake.eagle_store import \ ++ EagleMooncakeStore ++ ++ config = MooncakeConfig.from_env() ++ store = EagleMooncakeStore(config) ++ device = None ++ if torch.cuda.is_initialized(): ++ device = torch.device(f"cuda:{torch.cuda.current_device()}") ++ store.setup(device=device) ++ self._mooncake_store = store ++ logger.info( ++ f"SaveHiddenStates: Mooncake store initialized " ++ f"(master={config.master_server_address})") ++ except Exception: ++ logger.error( ++ "SaveHiddenStates: failed to init Mooncake store; hidden " ++ "states will NOT be stored.", ++ exc_info=True) ++ self._mooncake_store = None ++ return self._mooncake_store ++ ++ def _process_and_store_mooncake( ++ self, scheduled_requests: ScheduledRequests, ++ spec_metadata: "SaveHiddenStatesSpecMetadata") -> None: ++ """Stream per-request hidden states to Mooncake instead of disk. ++ ++ The capture buffer holds every context token of this forward packed ++ from offset 0 in scheduled order, so we walk the requests in the same ++ order and slice ``[token_offset : token_offset + num_tokens]`` for ++ each. (Upstream's disk path assumes a single request and always reads ++ from offset 0 — incorrect for the batched prefill TorchSpec runs.) ++ ++ NOTE: assumes the forward packs context-request tokens in the same ++ order as this sort (by ``py_batch_idx``). Validate against real ++ batched runs during engine bring-up. ++ """ ++ if local_mpi_rank() != 0: ++ return ++ store = self._ensure_mooncake_store() ++ if store is None: ++ return ++ ++ token_offset = 0 ++ for request in sorted( ++ scheduled_requests.context_requests, ++ key=lambda r: ++ (r.py_batch_idx is None, r.py_batch_idx or r.request_id), ++ ): ++ token_ids = list(request.get_tokens(0)) ++ num_tokens = len(token_ids) ++ buf = self.hidden_states[token_offset:token_offset + num_tokens] ++ token_offset += num_tokens ++ ++ # Final post-norm state is the last captured layer (the spec ++ # metadata moves layer -1 to the end of the capture buffer); ++ # everything before it is the concatenated aux layers. ++ last_hidden_states = buf[:, -self.hidden_size:].contiguous() ++ aux_hidden_states = buf[:, :-self.hidden_size].contiguous() ++ input_ids = torch.tensor(token_ids, ++ dtype=torch.long, ++ device=buf.device) ++ ++ key = _sanitize_mooncake_key(str(request.py_request_id)) ++ try: ++ store.put( ++ key=key, ++ hidden_states=aux_hidden_states, ++ input_ids=input_ids, ++ last_hidden_states=last_hidden_states, ++ target=None, ++ ) ++ except Exception: ++ logger.error( ++ f"SaveHiddenStates: failed to store request " ++ f"{request.py_request_id} to Mooncake", ++ exc_info=True) ++ ++ store.flush() ++ + def _process_request(self, request: LlmRequest, + spec_metadata: "SaveHiddenStatesSpecMetadata") -> None: + if local_mpi_rank() != 0: From 071394f71a0c3940eca5f69049faffbf79aadaf3 Mon Sep 17 00:00:00 2001 From: chungen04 Date: Mon, 15 Jun 2026 23:56:01 +0000 Subject: [PATCH 02/28] Exported trtllm as engine class Signed-off-by: chungen04 --- torchspec/inference/engine/__init__.py | 9 + torchspec/inference/engine/trtllm_engine.py | 484 ++++++++++++++++++++ 2 files changed, 493 insertions(+) create mode 100644 torchspec/inference/engine/trtllm_engine.py diff --git a/torchspec/inference/engine/__init__.py b/torchspec/inference/engine/__init__.py index 2582441..635b4f6 100644 --- a/torchspec/inference/engine/__init__.py +++ b/torchspec/inference/engine/__init__.py @@ -55,3 +55,12 @@ from torchspec.utils.logging import logger as _logger _logger.debug("VllmEngine not available: %s", _e) + +try: + from torchspec.inference.engine.trtllm_engine import TrtllmEngine # noqa: F401 + + __all__.append("TrtllmEngine") +except ImportError as _e: + from torchspec.utils.logging import logger as _logger + + _logger.debug("TrtllmEngine not available: %s", _e) diff --git a/torchspec/inference/engine/trtllm_engine.py b/torchspec/inference/engine/trtllm_engine.py new file mode 100644 index 0000000..5d58dfd --- /dev/null +++ b/torchspec/inference/engine/trtllm_engine.py @@ -0,0 +1,484 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +""" +TensorRT-LLM Ray actor engine for distributed deployment. + +Wraps TensorRT-LLM's PyTorch-backend ``LLM`` running in ``SaveHiddenStates`` +speculative mode. TRT-LLM natively captures EAGLE3 aux-layer + final post-norm +hidden states into a per-forward buffer; TorchSpec's patch +(``patches/trtllm/.../trtllm.patch``) redirects that buffer to Mooncake instead +of ``.pt`` files when ``TORCHSPEC_TRTLLM_MOONCAKE`` is set. + +This engine therefore only has to: + 1. configure ``SaveHiddenStatesDecodingConfig`` with the right capture layers, + 2. flip the env flag and export the Mooncake connection params, and + 3. map each ``RequestOutput.request_id`` back to the Mooncake key the patch + wrote (using the SAME sanitization as the patch). + +Scope: single-node tensor parallelism. TRT-LLM spawns its own MPI workers, so +multi-node TP needs additional orchestration and is intentionally deferred. +""" + +import gc +import os +import re +import socket +import tempfile +from typing import Any + +import ray +import torch +from omegaconf import DictConfig, OmegaConf +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi import KvCacheConfig, SaveHiddenStatesDecodingConfig + +from torchspec.inference.engine.base import InferenceEngine +from torchspec.ray.ray_actor import RayActor +from torchspec.transfer.mooncake.eagle_store import HIDDEN_STATES_STORAGE_DTYPE +from torchspec.utils.logging import logger, setup_file_logging +from torchspec.utils.misc import get_default_eagle3_aux_layer_ids + +# Keys managed internally by TorchSpec — ignored if present in trtllm_extra_args. +_PROTECTED_ENGINE_KEYS = frozenset( + { + "model", + "backend", + "tensor_parallel_size", + "pipeline_parallel_size", + "speculative_config", + "kv_cache_config", + "disable_overlap_scheduler", + "enable_chunked_prefill", + } +) + +# Env flag the patched SaveHiddenStatesResourceManager gates on. +_TORCHSPEC_MOONCAKE_ENV = "TORCHSPEC_TRTLLM_MOONCAKE" + + +def _sanitize_mooncake_key(key: str) -> str: + """Reconstruct the Mooncake key the patch wrote for a request. + + MUST stay in sync with ``_sanitize_mooncake_key`` in + ``patches/trtllm//trtllm.patch``. + """ + sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", key) + if sanitized and sanitized[0].isdigit(): + sanitized = "k" + sanitized + return sanitized + + +class TrtllmEngine(InferenceEngine, RayActor): + """Ray actor wrapping TensorRT-LLM's PyTorch ``LLM`` in SaveHiddenStates mode. + + Accepts pre-tokenized input_ids or formatted prompt strings, runs prefill, + and returns Mooncake keys + tensor metadata (the hidden states themselves are + streamed to Mooncake by the patched resource manager). + """ + + def __init__( + self, + args, + rank: int, + base_gpu_id: int | None = None, + num_gpus_per_engine: int = 1, + node_rank: int = 0, + engine_group: int = 0, + ): + self.args = args + self.rank = rank + self.base_gpu_id = base_gpu_id + self.num_gpus_per_engine = num_gpus_per_engine + self.node_rank = node_rank + self._engine = None + self._mooncake_config = None + self._hidden_size = None + self.local_gpu_id = None + self.aux_hidden_state_layer_ids: list[int] = [] + self._store_last_hidden_states = True + + setup_file_logging("inference", self.rank, group=engine_group) + + def init( + self, + mooncake_config=None, + dist_init_addr: str | None = None, + pre_allocated_port: int | None = None, + ) -> None: + # TRT-LLM manages cross-worker init internally over MPI; dist_init_addr / + # pre_allocated_port are accepted for interface parity with the other + # engines but unused for single-node TP. + del dist_init_addr, pre_allocated_port + + nnodes = getattr(self.args, "trtllm_nnodes", 1) + if nnodes > 1: + raise NotImplementedError( + "TrtllmEngine currently supports single-node TP only " + f"(trtllm_nnodes={nnodes}). Multi-node TP is not yet wired up." + ) + pp_size = getattr(self.args, "trtllm_pp_size", 1) + assert pp_size == 1, f"trtllm_pp_size must be 1, got {pp_size}" + + if self.base_gpu_id is not None: + self.local_gpu_id = self.setup_gpu(self.base_gpu_id) + logger.info( + f"TrtllmEngine rank {self.rank}: base_gpu_id={self.base_gpu_id}, " + f"using local GPU {self.local_gpu_id}" + ) + + self._store_last_hidden_states = getattr(self.args, "store_last_hidden_states", True) + self._mooncake_config = mooncake_config + self._setup_mooncake_env(mooncake_config) + + self._hidden_size = self._get_hidden_size_from_engine() + self.aux_hidden_state_layer_ids = self._resolve_aux_layer_ids() + + tp_size = self.num_gpus_per_engine + mem_fraction = getattr(self.args, "trtllm_mem_fraction_static", 0.8) + + logger.info( + f"TrtllmEngine rank {self.rank}: BEFORE init - " + f"base_gpu_id={self.base_gpu_id}, tp_size={tp_size}, " + f"aux_hidden_state_layer_ids={self.aux_hidden_state_layer_ids}, " + f"hidden_size={self._hidden_size}" + ) + + self._init_engine(tp_size, mem_fraction) + + logger.info( + f"TrtllmEngine rank {self.rank}: initialized from {self.args.target_model_path} " + f"(tp_size={tp_size}, aux_layers={self.aux_hidden_state_layer_ids}, " + f"hidden_size={self._hidden_size})" + ) + + def _setup_mooncake_env(self, mooncake_config) -> None: + """Export Mooncake env + the redirect flag so MPI workers inherit them. + + TRT-LLM spawns its workers from this process, so any env set here before + constructing ``LLM`` is visible to the patched resource manager running + inside those workers. + """ + # Always set the flag; the patch only stores when a Mooncake store is + # actually reachable, so a missing master simply logs and skips. + os.environ[_TORCHSPEC_MOONCAKE_ENV] = "1" + + if mooncake_config is None: + logger.warning( + f"TrtllmEngine rank {self.rank}: no mooncake_config provided; " + "hidden states will NOT be stored." + ) + return + + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + local_ip = s.getsockname()[0] + s.close() + except Exception: + local_ip = "localhost" + logger.warning( + f"TrtllmEngine rank {self.rank}: failed to get local IP, using localhost" + ) + + mooncake_config.local_hostname = local_ip + mooncake_config.export_env() + + from torchspec.transfer.mooncake.utils import check_mooncake_master_available + + check_mooncake_master_available( + mooncake_config.master_server_address, + mooncake_config.metadata_server, + ) + + def _resolve_aux_layer_ids(self) -> list[int]: + """Aux capture layers (post-layer indices, no +1 shift unlike vLLM). + + TRT-LLM's capture hook fires with ``self.layer_idx`` *after* each + decoder layer runs, so the layer ids map directly (same convention as + sglang). The final post-norm state is requested separately via the + ``-1`` entry added in ``_init_engine``. + """ + from transformers import AutoConfig + + cfg = AutoConfig.from_pretrained( + self.args.target_model_path, + trust_remote_code=getattr(self.args, "trust_remote_code", True), + ) + cfg = getattr(cfg, "text_config", cfg) + num_layers = cfg.num_hidden_layers + + if self.args.aux_hidden_states_layers is not None: + aux_ids = list(self.args.aux_hidden_states_layers) + else: + aux_ids = get_default_eagle3_aux_layer_ids(self.args.target_model_path) + if self.rank == 0: + logger.info(f"Using default aux hidden state layer ids: {aux_ids}") + + aux_ids = [lid for lid in aux_ids if 0 <= lid < num_layers] + return aux_ids + + def _init_engine(self, tp_size: int, mem_fraction: float | None) -> None: + """Construct the TRT-LLM PyTorch ``LLM`` in SaveHiddenStates mode.""" + # Pin TRT-LLM's MPI workers to the assigned physical GPUs. Workers map + # their local rank onto the visible devices, so without this they would + # collide on devices 0..tp_size-1. + if self.base_gpu_id is not None: + gpu_ids = [str(self.base_gpu_id + i) for i in range(self.num_gpus_per_engine)] + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids) + logger.info( + f"TrtllmEngine rank {self.rank}: set " + f"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}" + ) + + # eagle3_layers_to_capture: aux layers + the final post-norm state (-1). + # The resource manager orders -1 last in the capture buffer, which is the + # split point the patch relies on (aux = [:, :-H], last = [:, -H:]). + layers_to_capture = set(self.aux_hidden_state_layer_ids) | {-1} + + # output_directory is a required field on the config even though the + # Mooncake redirect never writes to disk; point it at a throwaway path. + spec_config = SaveHiddenStatesDecodingConfig( + output_directory=os.path.join(tempfile.gettempdir(), "torchspec_trtllm_unused"), + eagle3_layers_to_capture=layers_to_capture, + ) + + engine_kwargs: dict[str, Any] = {} + + extra_args = getattr(self.args, "trtllm_extra_args", None) + if extra_args: + if isinstance(extra_args, DictConfig): + extra = OmegaConf.to_container(extra_args, resolve=True) + else: + extra = dict(extra_args) if not isinstance(extra_args, dict) else extra_args + blocked = extra.keys() & _PROTECTED_ENGINE_KEYS + if blocked: + logger.warning( + f"trtllm extra_args contains protected keys that will be ignored: " + f"{sorted(blocked)}. These are managed internally by TorchSpec." + ) + extra = {k: v for k, v in extra.items() if k not in _PROTECTED_ENGINE_KEYS} + engine_kwargs.update(extra) + + if mem_fraction is not None: + engine_kwargs["kv_cache_config"] = KvCacheConfig( + free_gpu_memory_fraction=mem_fraction + ) + + max_seq_length = getattr(self.args, "max_seq_length", None) + if max_seq_length: + engine_kwargs.setdefault("max_seq_len", max_seq_length) + + inference_batch_size = getattr(self.args, "inference_batch_size", None) + if inference_batch_size is not None: + engine_kwargs.setdefault("max_batch_size", inference_batch_size) + + # Protected, set last so extra_args cannot override: + # - overlap scheduler off: the capture buffer is reused every forward; + # the overlap pipeline could launch the next forward before + # process_and_save reads it. + # - chunked prefill off: we need each request's full prefill in one + # forward so the patch's per-request token offsets stay contiguous. + engine_kwargs["disable_overlap_scheduler"] = True + engine_kwargs["enable_chunked_prefill"] = False + + self._engine = LLM( + model=self.args.target_model_path, + backend="pytorch", + tensor_parallel_size=tp_size, + trust_remote_code=getattr(self.args, "trust_remote_code", True), + speculative_config=spec_config, + **engine_kwargs, + ) + logger.info( + f"TrtllmEngine rank {self.rank}: LLM constructed with " + f"layers_to_capture={sorted(layers_to_capture)}" + ) + + def _normalize_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + if input_ids.dim() == 2 and input_ids.shape[0] == 1: + return input_ids.squeeze(0) + if input_ids.dim() == 1: + return input_ids + raise ValueError(f"Unexpected input_ids shape: {input_ids.shape}") + + def generate( + self, + data_id: str | list[str], + input_ids_ref: ray.ObjectRef | list[torch.Tensor] | None = None, + packed_loss_mask_list: list[str | None] | None = None, + formatted_prompts: list[str] | None = None, + return_last_hidden_states: bool = False, + return_logits: bool = True, + multimodal_inputs: list[dict] | None = None, + ) -> list[dict[str, Any]]: + """Run prefill and return Mooncake keys + tensor metadata. + + Hidden states are stored to Mooncake by the patched resource manager; + here we reconstruct each key from ``RequestOutput.request_id`` (which + equals the internal ``py_request_id`` the patch keyed on). + """ + if self._engine is None: + raise RuntimeError("TrtllmEngine not initialized. Call init() first.") + + if (input_ids_ref is None) == (formatted_prompts is None): + raise ValueError("Exactly one of input_ids_ref or formatted_prompts must be set") + + if multimodal_inputs is not None and any(m for m in multimodal_inputs): + raise NotImplementedError( + "TrtllmEngine does not support multimodal inputs yet." + ) + + use_prompts = formatted_prompts is not None + if use_prompts: + batch_size = len(formatted_prompts) + inputs: list = list(formatted_prompts) + else: + if isinstance(input_ids_ref, ray.ObjectRef): + input_ids_list = ray.get(input_ids_ref) + else: + input_ids_list = input_ids_ref + if input_ids_list is None: + raise ValueError("input_ids_ref resolved to None") + batch_size = len(input_ids_list) + inputs = [ + {"prompt_token_ids": self._normalize_input_ids(ids).tolist()} + for ids in input_ids_list + ] + + if isinstance(data_id, str): + data_ids = [f"{data_id}_{i}" for i in range(batch_size)] + elif len(data_id) == batch_size: + data_ids = data_id + else: + raise ValueError( + f"data_id length {len(data_id)} does not match batch size {batch_size}" + ) + + packed_loss_mask_map: dict[str, str | None] = {} + if packed_loss_mask_list is not None: + for i, did in enumerate(data_ids): + if i < len(packed_loss_mask_list): + packed_loss_mask_map[did] = packed_loss_mask_list[i] + + # Prefill-only: SaveHiddenStates forces max_new_tokens=1 internally, but + # we set it here too to avoid allocating decode resources. + sampling_params = SamplingParams(max_tokens=1) + + outputs = self._engine.generate(inputs, sampling_params, use_tqdm=False) + + results: list[dict[str, Any]] = [] + for i, output in enumerate(outputs): + did = data_ids[i] + seq_len = len(output.prompt_token_ids) + mooncake_key = _sanitize_mooncake_key(str(output.request_id)) + + result: dict[str, Any] = { + "mooncake_key": mooncake_key, + "tensor_shapes": self._get_tensor_shapes(seq_len), + "tensor_dtypes": self._get_tensor_dtypes(), + "data_id": did, + "seq_len": seq_len, + "input_ids_list": list(output.prompt_token_ids), + } + packed_loss_mask = packed_loss_mask_map.get(did) + if packed_loss_mask is not None: + result["packed_loss_mask"] = packed_loss_mask + results.append(result) + + logger.debug( + f"TrtllmEngine rank {self.rank}: generated {len(results)} mooncake results " + f"for data_ids={data_ids}" + ) + return results + + def health_check(self, timeout: float = 5.0) -> bool: + return self._engine is not None + + def shutdown(self) -> None: + if self._engine is not None: + try: + self._engine.shutdown() + except Exception as e: + logger.warning(f"TrtllmEngine rank {self.rank}: error during shutdown: {e}") + finally: + self._engine = None + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logger.info(f"TrtllmEngine rank {self.rank}: shutdown complete") + + def get_status(self) -> dict: + return { + "rank": self.rank, + "initialized": self._engine is not None, + "base_gpu_id": self.base_gpu_id, + "hidden_size": self._hidden_size, + } + + def _get_hidden_size_from_engine(self) -> int: + from transformers import AutoConfig + + config = AutoConfig.from_pretrained( + self.args.target_model_path, + trust_remote_code=getattr(self.args, "trust_remote_code", True), + ) + hidden_size = getattr(config, "hidden_size", None) + if hidden_size is None: + text_config = getattr(config, "text_config", None) + if text_config is not None: + hidden_size = getattr(text_config, "hidden_size", None) + if hidden_size is None: + raise ValueError( + f"Could not determine hidden_size from model config: {self.args.target_model_path}" + ) + return hidden_size + + def _get_tensor_shapes(self, seq_len: int) -> dict: + """Shapes of the tensors the patch stored to Mooncake (no batch dim). + + The patch writes aux layers concatenated (``num_aux_layers * H``) as + ``hidden_states`` and the final post-norm state (``H``) as + ``last_hidden_states`` — matching the sglang engine's layout. + """ + if self._hidden_size is None: + raise ValueError( + f"TrtllmEngine rank {self.rank}: hidden_size not initialized. Call init() first." + ) + hidden_size = self._hidden_size + num_aux_layers = len(self.aux_hidden_state_layer_ids) + concat_hidden_size = num_aux_layers * hidden_size + + shapes = { + "hidden_states": (seq_len, concat_hidden_size), + "input_ids": (seq_len,), + } + if self._store_last_hidden_states: + shapes["last_hidden_states"] = (seq_len, hidden_size) + return shapes + + def _get_tensor_dtypes(self) -> dict: + dtypes = { + "hidden_states": HIDDEN_STATES_STORAGE_DTYPE, + "input_ids": torch.long, + } + if self._store_last_hidden_states: + dtypes["last_hidden_states"] = HIDDEN_STATES_STORAGE_DTYPE + return dtypes From 6b020185d634aca39947b5ce88484cffc9229ab1 Mon Sep 17 00:00:00 2001 From: chungen04 Date: Tue, 16 Jun 2026 00:01:57 +0000 Subject: [PATCH 03/28] Extend factory and config class for trtllm Signed-off-by: chungen04 --- torchspec/config/inference_config.py | 30 ++++++++++ torchspec/config/train_config.py | 1 + torchspec/inference/factory.py | 86 +++++++++++++++++++++++++++- 3 files changed, 114 insertions(+), 3 deletions(-) diff --git a/torchspec/config/inference_config.py b/torchspec/config/inference_config.py index f11f2f7..63678f6 100644 --- a/torchspec/config/inference_config.py +++ b/torchspec/config/inference_config.py @@ -105,6 +105,35 @@ class VllmConfig: extra_args: Dict[str, Any] = field(default_factory=dict) +@dataclass +class TrtllmConfig: + """Essential TensorRT-LLM engine configuration. + + Wraps TRT-LLM's PyTorch-backend ``LLM`` in ``SaveHiddenStates`` mode; the + TorchSpec patch redirects captured hidden states to Mooncake. Only fields + TorchSpec explicitly uses are listed; any other ``LLM`` kwarg can be passed + via ``extra_args``. + + Single-node tensor parallelism only (nnodes must be 1); multi-node TP is + not yet wired up. + """ + + # Parallelism (TP degree is derived from inference_num_gpus_per_engine). + tp_size: int = 8 + pp_size: int = 1 + nnodes: int = 1 + + # KV-cache memory fraction (TRT-LLM's KvCacheConfig.free_gpu_memory_fraction). + mem_fraction_static: Optional[float] = 0.8 + + # TRT-LLM model build + load can be slow; give init a generous timeout. + init_timeout: int = 600 + + # Passthrough: forwarded as-is to the TRT-LLM LLM constructor + # (e.g. attn_backend, max_num_tokens, dtype, ...). + extra_args: Dict[str, Any] = field(default_factory=dict) + + @dataclass class InferenceConfig: aux_hidden_states_layers: Optional[list] = None @@ -122,6 +151,7 @@ class InferenceConfig: store_last_hidden_states: bool = True sglang: SGLangConfig = field(default_factory=SGLangConfig) vllm: VllmConfig = field(default_factory=VllmConfig) + trtllm: TrtllmConfig = field(default_factory=TrtllmConfig) def resolve_last_hidden_states_prenorm(self) -> bool: """Whether last_hidden_states from the engine are pre-norm. diff --git a/torchspec/config/train_config.py b/torchspec/config/train_config.py index 06a5bd1..2bcfc2f 100644 --- a/torchspec/config/train_config.py +++ b/torchspec/config/train_config.py @@ -298,6 +298,7 @@ def load_config( "mooncake": "mooncake_", "sglang": "sglang_", "vllm": "vllm_", + "trtllm": "trtllm_", } diff --git a/torchspec/inference/factory.py b/torchspec/inference/factory.py index 58955c1..8ff298b 100644 --- a/torchspec/inference/factory.py +++ b/torchspec/inference/factory.py @@ -42,7 +42,7 @@ def create_inference_engines(args, inference_pg, mooncake_config, engine_group: """ engine_type = getattr(args, "inference_engine_type", "hf") - if engine_type not in ("hf", "sgl", "vllm"): + if engine_type not in ("hf", "sgl", "vllm", "trtllm"): raise ValueError(f"Unknown inference_engine_type: {engine_type}") logger.info(f"Using {engine_type} engine for inference") @@ -74,7 +74,7 @@ def prepare_inference_engines(args, inference_pg, mooncake_config, engine_group: """ engine_type = getattr(args, "inference_engine_type", "hf") - if engine_type not in ("hf", "sgl", "vllm"): + if engine_type not in ("hf", "sgl", "vllm", "trtllm"): raise ValueError(f"Unknown inference_engine_type: {engine_type}") logger.info(f"Preparing {engine_type} inference engines...") @@ -83,6 +83,10 @@ def prepare_inference_engines(args, inference_pg, mooncake_config, engine_group: engines, init_refs = _prepare_hf_engines(args, inference_pg, mooncake_config, engine_group) elif engine_type == "sgl": engines, init_refs = _prepare_sgl_engines(args, inference_pg, mooncake_config, engine_group) + elif engine_type == "trtllm": + engines, init_refs = _prepare_trtllm_engines( + args, inference_pg, mooncake_config, engine_group + ) else: engines, init_refs = _prepare_vllm_engines( args, inference_pg, mooncake_config, engine_group @@ -97,7 +101,7 @@ def init_engines(args, pg, engine_type: str, mooncake_config=None, engine_group: Args: args: Configuration arguments. pg: Placement group tuple (pg, reordered_bundle_indices, reordered_gpu_ids). - engine_type: Engine type ("hf", "sgl", or "vllm"). + engine_type: Engine type ("hf", "sgl", "vllm", or "trtllm"). mooncake_config: MooncakeConfig object. Returns: @@ -109,6 +113,8 @@ def init_engines(args, pg, engine_type: str, mooncake_config=None, engine_group: return _init_sgl_engines(args, pg, mooncake_config, engine_group) elif engine_type == "vllm": return _init_vllm_engines(args, pg, mooncake_config, engine_group) + elif engine_type == "trtllm": + return _init_trtllm_engines(args, pg, mooncake_config, engine_group) else: raise ValueError(f"Unknown engine_type: {engine_type}") @@ -429,6 +435,80 @@ def _init_vllm_engines(args, pg, mooncake_config=None, engine_group: int = 0) -> return head_engines +def _prepare_trtllm_engines( + args, pg, mooncake_config=None, engine_group: int = 0 +) -> tuple[list, list]: + """Create TensorRT-LLM engine actors and fire init calls without waiting. + + Single-node only: each engine owns a contiguous block of + ``inference_num_gpus_per_engine`` GPUs and runs TRT-LLM's own MPI workers + across them (TP degree = num_gpus_per_engine). Unlike sgl/vllm there is no + cross-node dist-init negotiation or port pre-allocation — TRT-LLM manages + worker bring-up internally. + + Returns: + Tuple of (engines, init_handles). + """ + nnodes = getattr(args, "trtllm_nnodes", 1) + if nnodes > 1: + raise NotImplementedError( + f"trtllm backend supports single-node TP only (trtllm_nnodes={nnodes})." + ) + + num_gpus_total = getattr(args, "inference_num_gpus", 1) + gpus_per_engine = getattr(args, "inference_num_gpus_per_engine", 1) + num_engines = num_gpus_total // gpus_per_engine + + logger.info( + f"Initializing {num_engines} TensorRT-LLM engines " + f"({gpus_per_engine} GPU(s) each, tp_size={gpus_per_engine})" + ) + + from torchspec.inference.engine.trtllm_engine import TrtllmEngine + + pg_obj, reordered_bundle_indices, reordered_gpu_ids = pg + + TrtllmRayActor = ray.remote(TrtllmEngine) + env_vars = get_torchspec_env_vars() + + engines = [] + for i in range(num_engines): + bundle_offset = i * gpus_per_engine + base_gpu_id = int(reordered_gpu_ids[bundle_offset]) + + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=pg_obj, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=reordered_bundle_indices[bundle_offset], + ) + + engine = TrtllmRayActor.options( + num_cpus=0.2, + num_gpus=0.2, + scheduling_strategy=scheduling_strategy, + runtime_env={"env_vars": env_vars}, + ).remote( + args=args, + rank=i, + base_gpu_id=base_gpu_id, + num_gpus_per_engine=gpus_per_engine, + node_rank=0, + engine_group=engine_group, + ) + engines.append(engine) + + init_handles = [engine.init.remote(mooncake_config=mooncake_config) for engine in engines] + return engines, init_handles + + +def _init_trtllm_engines(args, pg, mooncake_config=None, engine_group: int = 0) -> list: + """Initialize TensorRT-LLM engines with Ray placement groups (blocking).""" + engines, init_handles = _prepare_trtllm_engines(args, pg, mooncake_config, engine_group) + init_timeout = getattr(args, "trtllm_init_timeout", 600) + _wait_for_init(init_handles, "Trtllm", timeout=init_timeout) + return engines + + def _create_and_init_actors( args, pg, From 7c66102928f3c4725439c8c22f507f2a03ed37dc Mon Sep 17 00:00:00 2001 From: chungen04 Date: Tue, 16 Jun 2026 00:08:41 +0000 Subject: [PATCH 04/28] dockerfile for trtllm Signed-off-by: chungen04 --- docker/justfile | 3 ++- docker/trtllm/v1.3.0rc18/Dockerfile | 40 +++++++++++++++++++++++++++++ pyproject.toml | 4 +++ 3 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 docker/trtllm/v1.3.0rc18/Dockerfile diff --git a/docker/justfile b/docker/justfile index 9d65458..82cbbe2 100755 --- a/docker/justfile +++ b/docker/justfile @@ -1,10 +1,11 @@ BACKEND := env("BACKEND", "sglang") SGLANG_VERSION := env("SGLANG_VERSION", "v0.5.8.post1") VLLM_VERSION := env("VLLM_VERSION", "v0.22.1") +TRTLLM_VERSION := env("TRTLLM_VERSION", "v1.3.0rc18") IMAGE_REPO := env("IMAGE_REPO", "ghcr.io/torchspec-project/torchspec") IMAGE_TAG := env("IMAGE_TAG", "") -_dockerfile := if BACKEND == "vllm" { "docker/vllm/" + VLLM_VERSION + "/Dockerfile" } else { "docker/sglang/" + SGLANG_VERSION + "/Dockerfile" } +_dockerfile := if BACKEND == "vllm" { "docker/vllm/" + VLLM_VERSION + "/Dockerfile" } else if BACKEND == "trtllm" { "docker/trtllm/" + TRTLLM_VERSION + "/Dockerfile" } else { "docker/sglang/" + SGLANG_VERSION + "/Dockerfile" } build: ARG_TAG_POSTFIX="${ARG_TAG_POSTFIX:-""}" ARG_BUILD_EXTRA_ARGS="" just _build-only diff --git a/docker/trtllm/v1.3.0rc18/Dockerfile b/docker/trtllm/v1.3.0rc18/Dockerfile new file mode 100644 index 0000000..668eb4b --- /dev/null +++ b/docker/trtllm/v1.3.0rc18/Dockerfile @@ -0,0 +1,40 @@ +ARG TRTLLM_IMAGE=nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc18 +FROM ${TRTLLM_IMAGE} AS trtllm + +WORKDIR /root/ + +RUN apt-get update && \ + apt-get install -y --no-install-recommends nvtop rsync dnsutils && \ + rm -rf /var/lib/apt/lists/* + +# Patch the installed tensorrt_llm package so SaveHiddenStates mode streams +# captured hidden states to Mooncake instead of writing .pt files to disk. +# Resolve the site-packages dir dynamically (so this is independent of the +# image's Python version) and apply with -p1 to strip the diff's a/ prefix. +COPY patches/trtllm/v1.3.0rc18/*.patch /tmp/patches/ +RUN TRTLLM_SITE_PACKAGES=$(python3 -c "import os, tensorrt_llm; print(os.path.dirname(os.path.dirname(tensorrt_llm.__file__)))") && \ + cd "$TRTLLM_SITE_PACKAGES" && \ + for p in /tmp/patches/*.patch; do patch -p1 < "$p"; done && \ + rm -rf /tmp/patches + +# Fail the build early if the patch did not land where we expect. +RUN python3 -c "import inspect, tensorrt_llm._torch.speculative.save_hidden_state as m; \ + src = inspect.getsource(m); \ + assert 'TORCHSPEC_TRTLLM_MOONCAKE' in src and '_process_and_store_mooncake' in src, \ + 'TorchSpec Mooncake patch was not applied to save_hidden_state.py'" + +COPY . /root/torchspec +RUN cd /root/torchspec && pip install --no-cache-dir -e ".[fa]" + +# TorchSpec's dependency pulls the generic CUDA 12 Mooncake wheel. Replace it +# with the CUDA 13 wheel that matches this image family. +RUN pip uninstall -y mooncake-transfer-engine || true && \ + pip install --no-cache-dir --no-deps --force-reinstall \ + mooncake-transfer-engine-cuda13==0.3.11.post1 + +RUN chmod 755 /usr/local/lib/python3.12/dist-packages/mooncake/mooncake_master || true +RUN if [ -f /usr/local/lib/python3.12/dist-packages/mooncake/cli.py ]; then \ + sed -i 's/os.chmod(bin_path, 0o755)/pass/' /usr/local/lib/python3.12/dist-packages/mooncake/cli.py; \ + fi + +WORKDIR /root/torchspec diff --git a/pyproject.toml b/pyproject.toml index d97974d..1906d42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,10 @@ vllm = [ "vllm>=0.18.0", ] +trtllm = [ + "tensorrt-llm>=1.3.0rc18", +] + fa = [ "flash-attn-4>=4.0.0b7", "nvidia-cutlass-dsl>=4.4.2", From 255a9f9ec9e449cc03373e87d980dd329928fa53 Mon Sep 17 00:00:00 2001 From: chungen04 Date: Tue, 16 Jun 2026 00:15:14 +0000 Subject: [PATCH 05/28] Add qwen3 8b example Signed-off-by: chungen04 --- configs/trtllm_qwen3_8b.yaml | 79 ++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 configs/trtllm_qwen3_8b.yaml diff --git a/configs/trtllm_qwen3_8b.yaml b/configs/trtllm_qwen3_8b.yaml new file mode 100644 index 0000000..f9a6c7d --- /dev/null +++ b/configs/trtllm_qwen3_8b.yaml @@ -0,0 +1,79 @@ +# Configuration for train_entry.py with TensorRT-LLM Engine inference (nested config format) +# +# GPU allocation: +# - 2 GPUs for inference (duplicate mode: each engine has full model copy) +# - 2 GPUs for training (DP/FSDP: model sharded across 2 GPUs) +# - Total: 4 GPUs +# +# Installation: +# Use the TensorRT-LLM docker image (docker/trtllm/v1.3.0rc18/Dockerfile), +# which ships tensorrt_llm patched for Mooncake hidden-state capture. +# For a local install: pip install -e ".[trtllm]" +# +# Usage: +# python -m torchspec.train_entry --config configs/trtllm_qwen3_8b.yaml +# +# Note: Uses TensorRT-LLM's SaveHiddenStates speculative mode; the TorchSpec +# patch redirects captured aux + final hidden states to Mooncake. +# Single-node tensor parallelism only. + +model: + target_model_path: Qwen/Qwen3-8B + trust_remote_code: true + +dataset: + train_data_path: ../examples/data/sample_conversations.jsonl + eval_data_path: ../examples/data/eval_conversations.jsonl + eval_interval: 100 + chat_template: qwen + prompt_key: conversations + +training: + attention_backend: flex_attention + micro_batch_size: 1 + draft_accumulation_steps: 1 + learning_rate: 1e-4 + max_concurrent_batches: 1 + max_grad_norm: 0.5 + max_seq_length: 16384 + num_epochs: 1 + seed: 42 + training_num_gpus_per_node: 2 + training_num_nodes: 1 + ttt_length: 7 + save_per_epoch: true + warmup_ratio: 0.015 + +inference: + inference_engine_type: trtllm + inference_num_gpus: 2 + inference_num_gpus_per_engine: 2 + inference_num_gpus_per_node: 4 + max_sample_pool_size: 64 + inference_buffer_threshold: 32 + inference_batch_size: 8 + trtllm: + tp_size: 2 + # KV-cache memory fraction. Kept below TRT-LLM's 0.9 default to leave room + # for the SaveHiddenStates capture buffer, which the KV profiler does not + # account for. + mem_fraction_static: 0.7 + extra_args: + # Any extra TensorRT-LLM LLM kwarg; e.g. cap the per-iteration token budget. + max_num_tokens: 8192 + +mooncake: + master_server_address: null + metadata_server: null + protocol: tcp + global_segment_size: 32GB + local_buffer_size: 4GB + +output_dir: ./outputs/qwen3-8b-single-node +cache_dir: ./cache +model_download_dir: null + +debug: + save_debug_train_data: null + debug_train_only: false + debug_inference_only: false From 70e54192460128cf614a7eb760456d89b31b787e Mon Sep 17 00:00:00 2001 From: chungen04 Date: Tue, 16 Jun 2026 00:30:33 +0000 Subject: [PATCH 06/28] Modified qwen3 8b example Signed-off-by: chungen04 --- configs/trtllm_qwen3_8b.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/configs/trtllm_qwen3_8b.yaml b/configs/trtllm_qwen3_8b.yaml index f9a6c7d..60883f9 100644 --- a/configs/trtllm_qwen3_8b.yaml +++ b/configs/trtllm_qwen3_8b.yaml @@ -15,7 +15,6 @@ # # Note: Uses TensorRT-LLM's SaveHiddenStates speculative mode; the TorchSpec # patch redirects captured aux + final hidden states to Mooncake. -# Single-node tensor parallelism only. model: target_model_path: Qwen/Qwen3-8B From 8d266f28821e0ddd492e2fc33c29cdcf1b55132b Mon Sep 17 00:00:00 2001 From: chungen04 Date: Tue, 16 Jun 2026 00:31:09 +0000 Subject: [PATCH 07/28] marked trtllm support in README.md Signed-off-by: chungen04 --- README.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index cfa1a6e..7e6531a 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ TorchSpec streams hidden states from inference engines into training workers. |---------|--------------|--------| | [vLLM](https://github.com/vllm-project/vllm) | First-class | Available | | [TokenSpeed](https://github.com/lightseekorg/tokenspeed) | First-class | In progress | +| [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) | First-class | Available | | [SGLang](https://github.com/sgl-project/sglang) | Best community effort | Available | | [HuggingFace Transformers](https://github.com/huggingface/transformers) | Best community effort | Available | @@ -121,7 +122,17 @@ pip install -e ".[fa]" ./examples/qwen3-8b-single-node/run.sh ``` -TorchSpec uses vLLM's **Worker Extension** mechanism to hook into the model forward pass and capture hidden states directly inside worker processes, which avoids RPC serialization overhead during extraction. For SGLang, TorchSpec applies a patch to the existing codebase to enable hidden-state extraction. +**TensorRT-LLM** + +Run inside the TensorRT-LLM image (`docker/trtllm/v1.3.0rc18/Dockerfile`), which ships `tensorrt_llm` pre-patched for Mooncake hidden-state capture: + +```bash +./examples/qwen3-8b-single-node/run.sh --config configs/trtllm_qwen3_8b.yaml +``` + +Single-node tensor parallelism only for now (multi-node TP is not yet wired up). + +TorchSpec uses vLLM's **Worker Extension** mechanism to hook into the model forward pass and capture hidden states directly inside worker processes, which avoids RPC serialization overhead during extraction. For SGLang, TorchSpec applies a patch to the existing codebase to enable hidden-state extraction. For TensorRT-LLM, TorchSpec builds on its native **SaveHiddenStates** speculative mode and applies a small patch that redirects the captured aux + final hidden states to Mooncake instead of writing them to disk. ## Examples From 2a377d0df43f6582852ed9a7a44233021824c5d5 Mon Sep 17 00:00:00 2001 From: chungen04 Date: Wed, 17 Jun 2026 21:34:46 +0000 Subject: [PATCH 08/28] Apply trtllm patch import-free at build time tensorrt_llm's package init loads libcuda.so.1, which is absent during the image build. Apply the patch in the fixed python3.12 dist-packages path and verify it landed by grepping the patched file instead of importing the module. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- docker/trtllm/v1.3.0rc18/Dockerfile | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/docker/trtllm/v1.3.0rc18/Dockerfile b/docker/trtllm/v1.3.0rc18/Dockerfile index 668eb4b..a973fbf 100644 --- a/docker/trtllm/v1.3.0rc18/Dockerfile +++ b/docker/trtllm/v1.3.0rc18/Dockerfile @@ -9,19 +9,21 @@ RUN apt-get update && \ # Patch the installed tensorrt_llm package so SaveHiddenStates mode streams # captured hidden states to Mooncake instead of writing .pt files to disk. -# Resolve the site-packages dir dynamically (so this is independent of the -# image's Python version) and apply with -p1 to strip the diff's a/ prefix. +# tensorrt_llm lives in dist-packages (python 3.12 in this release image, same +# location the vLLM image uses). Do NOT import tensorrt_llm at build time — its +# package init loads libcuda.so.1, which is absent until container runtime. +# Apply with -p1 to strip the diff's a/ prefix. COPY patches/trtllm/v1.3.0rc18/*.patch /tmp/patches/ -RUN TRTLLM_SITE_PACKAGES=$(python3 -c "import os, tensorrt_llm; print(os.path.dirname(os.path.dirname(tensorrt_llm.__file__)))") && \ - cd "$TRTLLM_SITE_PACKAGES" && \ +RUN cd /usr/local/lib/python3.12/dist-packages && \ for p in /tmp/patches/*.patch; do patch -p1 < "$p"; done && \ rm -rf /tmp/patches -# Fail the build early if the patch did not land where we expect. -RUN python3 -c "import inspect, tensorrt_llm._torch.speculative.save_hidden_state as m; \ - src = inspect.getsource(m); \ - assert 'TORCHSPEC_TRTLLM_MOONCAKE' in src and '_process_and_store_mooncake' in src, \ - 'TorchSpec Mooncake patch was not applied to save_hidden_state.py'" +# Fail the build early if the patch did not land. Grep the file directly +# (import-free, since build-time has no CUDA driver). +RUN F=/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/speculative/save_hidden_state.py && \ + grep -q "TORCHSPEC_TRTLLM_MOONCAKE" "$F" && \ + grep -q "_process_and_store_mooncake" "$F" || \ + { echo "ERROR: TorchSpec Mooncake patch was not applied to $F"; exit 1; } COPY . /root/torchspec RUN cd /root/torchspec && pip install --no-cache-dir -e ".[fa]" From b6c4b5755d551002e2c6110ef8094459d9234d20 Mon Sep 17 00:00:00 2001 From: chungen04 Date: Fri, 19 Jun 2026 00:47:03 +0000 Subject: [PATCH 09/28] docker(trtllm): forward Mooncake/TorchSpec env to MPI workers TRT-LLM's MpiPoolSession._start_mpi_pool only forwards env vars prefixed TRTLLM/TLLM to its spawned MPI workers. The SaveHiddenStates Mooncake redirect gates on TORCHSPEC_TRTLLM_MOONCAKE and reads the MOONCAKE_* connection params, none of which reached the workers, so hidden-state capture silently fell back to disk mode and the trainer timed out waiting for keys ("batch_get_buffer missing keys"). Add mpi_session.patch forwarding TORCHSPEC*/MOONCAKE*/MC_* prefixes (applied by the existing patch loop) plus a build-time validation grep. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- docker/trtllm/v1.3.0rc18/Dockerfile | 9 ++++++++- patches/trtllm/v1.3.0rc18/mpi_session.patch | 13 +++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 patches/trtllm/v1.3.0rc18/mpi_session.patch diff --git a/docker/trtllm/v1.3.0rc18/Dockerfile b/docker/trtllm/v1.3.0rc18/Dockerfile index a973fbf..0991edc 100644 --- a/docker/trtllm/v1.3.0rc18/Dockerfile +++ b/docker/trtllm/v1.3.0rc18/Dockerfile @@ -18,13 +18,20 @@ RUN cd /usr/local/lib/python3.12/dist-packages && \ for p in /tmp/patches/*.patch; do patch -p1 < "$p"; done && \ rm -rf /tmp/patches -# Fail the build early if the patch did not land. Grep the file directly +# Fail the build early if the patches did not land. Grep the files directly # (import-free, since build-time has no CUDA driver). RUN F=/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/speculative/save_hidden_state.py && \ grep -q "TORCHSPEC_TRTLLM_MOONCAKE" "$F" && \ grep -q "_process_and_store_mooncake" "$F" || \ { echo "ERROR: TorchSpec Mooncake patch was not applied to $F"; exit 1; } +# The MPI worker pool only forwards TRTLLM*/TLLM* env vars by default; the patch +# below also forwards TORCHSPEC*/MOONCAKE*/MC_* so the Mooncake redirect flag and +# connection params reach the workers running SaveHiddenStates. +RUN F=/usr/local/lib/python3.12/dist-packages/tensorrt_llm/llmapi/mpi_session.py && \ + grep -q "TORCHSPEC" "$F" && grep -q "MOONCAKE" "$F" || \ + { echo "ERROR: TorchSpec MPI env-forwarding patch was not applied to $F"; exit 1; } + COPY . /root/torchspec RUN cd /root/torchspec && pip install --no-cache-dir -e ".[fa]" diff --git a/patches/trtllm/v1.3.0rc18/mpi_session.patch b/patches/trtllm/v1.3.0rc18/mpi_session.patch new file mode 100644 index 0000000..8c0bfe2 --- /dev/null +++ b/patches/trtllm/v1.3.0rc18/mpi_session.patch @@ -0,0 +1,13 @@ +diff --git a/tensorrt_llm/llmapi/mpi_session.py b/tensorrt_llm/llmapi/mpi_session.py +index 2e1eeb2..7d1bca8 100644 +--- a/tensorrt_llm/llmapi/mpi_session.py ++++ b/tensorrt_llm/llmapi/mpi_session.py +@@ -174,6 +174,8 @@ class MpiPoolSession(MpiSession): + key: value + for key, value in os.environ.items() + if key.startswith("TRTLLM") or key.startswith("TLLM") ++ or key.startswith("TORCHSPEC") or key.startswith("MOONCAKE") ++ or key.startswith("MC_") + } + self.mpi_pool = MPIPoolExecutor(max_workers=self.n_workers, + path=sys.path, From 065c18c20eeecc1ff57d09560d9759724cb01656 Mon Sep 17 00:00:00 2001 From: chungen04 Date: Fri, 19 Jun 2026 00:47:21 +0000 Subject: [PATCH 10/28] docker(trtllm): key SaveHiddenStates Mooncake store on client_id The patch stored hidden states under request.py_request_id, the backend runtime request id which is bumped by JIT-warmup and internal requests. But the TorchSpec engine reconstructs Mooncake keys from RequestOutput.request_id == GenerationRequest.id, which TRT-LLM sets to the client_id assigned by the proxy (_get_next_client_id). The two id spaces are mapped via _client_id_to_request_id and differ, so the stored keys were systematically offset from the keys the trainer requested, surfacing as "Size mismatch for hidden_states" (data shifted between keys). Key on request.py_client_id, falling back to py_request_id if unset. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- patches/trtllm/v1.3.0rc18/trtllm.patch | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/patches/trtllm/v1.3.0rc18/trtllm.patch b/patches/trtllm/v1.3.0rc18/trtllm.patch index be5f682..8da10ab 100644 --- a/patches/trtllm/v1.3.0rc18/trtllm.patch +++ b/patches/trtllm/v1.3.0rc18/trtllm.patch @@ -1,5 +1,5 @@ diff --git a/tensorrt_llm/_torch/speculative/save_hidden_state.py b/tensorrt_llm/_torch/speculative/save_hidden_state.py -index 05cf2610..b8e8145a 100644 +index 05cf261..f20431a 100644 --- a/tensorrt_llm/_torch/speculative/save_hidden_state.py +++ b/tensorrt_llm/_torch/speculative/save_hidden_state.py @@ -1,10 +1,12 @@ @@ -66,7 +66,7 @@ index 05cf2610..b8e8145a 100644 for request in sorted( scheduled_requests.context_requests, key=lambda r: -@@ -83,6 +113,97 @@ class SaveHiddenStatesResourceManager(BaseResourceManager): +@@ -83,6 +113,107 @@ class SaveHiddenStatesResourceManager(BaseResourceManager): self._write_to_file() self._iter += 1 @@ -144,7 +144,17 @@ index 05cf2610..b8e8145a 100644 + dtype=torch.long, + device=buf.device) + -+ key = _sanitize_mooncake_key(str(request.py_request_id)) ++ # Key on the CLIENT id, not the backend py_request_id. The TorchSpec ++ # engine reconstructs keys from RequestOutput.request_id, which TRT-LLM ++ # sets to GenerationRequest.id == the client_id assigned by the proxy ++ # (_get_next_client_id). The backend py_request_id is a separate ++ # counter (see _client_id_to_request_id in base_worker) bumped by JIT ++ # warmup and internal requests, so keying on it produces a systematic ++ # offset and crossed keys. py_client_id mirrors the client id here. ++ client_key_id = getattr(request, "py_client_id", None) ++ if client_key_id is None: ++ client_key_id = request.py_request_id ++ key = _sanitize_mooncake_key(str(client_key_id)) + try: + store.put( + key=key, @@ -156,7 +166,7 @@ index 05cf2610..b8e8145a 100644 + except Exception: + logger.error( + f"SaveHiddenStates: failed to store request " -+ f"{request.py_request_id} to Mooncake", ++ f"{client_key_id} to Mooncake", + exc_info=True) + + store.flush() From fb619e84457dedacd54b1c09f5f2496b802c84cd Mon Sep 17 00:00:00 2001 From: chungen04 Date: Fri, 19 Jun 2026 00:47:28 +0000 Subject: [PATCH 11/28] trtllm: disable KV block reuse for hidden-state capture SaveHiddenStates only captures hidden states for tokens that actually run a forward pass. With enable_block_reuse on (the TRT-LLM default), prompts that share a prefix reuse cached KV blocks and skip the forward for the shared tokens, so fewer tokens are captured than the prompt length. Always construct KvCacheConfig with enable_block_reuse=False and enable_partial_reuse=False (previously the config was only built when a mem fraction was set, leaving reuse on otherwise). Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- torchspec/inference/engine/trtllm_engine.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/torchspec/inference/engine/trtllm_engine.py b/torchspec/inference/engine/trtllm_engine.py index 5d58dfd..db529f8 100644 --- a/torchspec/inference/engine/trtllm_engine.py +++ b/torchspec/inference/engine/trtllm_engine.py @@ -277,10 +277,19 @@ def _init_engine(self, tp_size: int, mem_fraction: float | None) -> None: extra = {k: v for k, v in extra.items() if k not in _PROTECTED_ENGINE_KEYS} engine_kwargs.update(extra) + # Block reuse must be OFF for hidden-state capture: when two prompts + # share a prefix, TRT-LLM reuses the cached KV blocks and only runs a + # forward pass over the new suffix tokens, so SaveHiddenStates captures + # hidden states for fewer tokens than the prompt length. The trainer + # then sees a shape mismatch (engine reports full seq_len, store holds + # only the recomputed tokens). Disabling reuse forces a full prefill. + kv_cache_kwargs: dict[str, Any] = { + "enable_block_reuse": False, + "enable_partial_reuse": False, + } if mem_fraction is not None: - engine_kwargs["kv_cache_config"] = KvCacheConfig( - free_gpu_memory_fraction=mem_fraction - ) + kv_cache_kwargs["free_gpu_memory_fraction"] = mem_fraction + engine_kwargs["kv_cache_config"] = KvCacheConfig(**kv_cache_kwargs) max_seq_length = getattr(self.args, "max_seq_length", None) if max_seq_length: From 31b8615db89a693a01e20398124869b2c5d99bc2 Mon Sep 17 00:00:00 2001 From: chungen04 Date: Fri, 19 Jun 2026 00:47:35 +0000 Subject: [PATCH 12/28] config(trtllm): raise max_num_tokens to match max_seq_length SaveHiddenStates runs each prompt as a single prefill (chunked prefill is disabled), so a prompt longer than max_num_tokens is rejected outright ("sum of prompt length ... should not exceed max_num_tokens") and the sample is dropped. With max_num_tokens=8192 but max_seq_length=16384, samples in that range were silently lost during training. Raise to 16384. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- configs/trtllm_qwen3_8b.yaml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/configs/trtllm_qwen3_8b.yaml b/configs/trtllm_qwen3_8b.yaml index 60883f9..62432c2 100644 --- a/configs/trtllm_qwen3_8b.yaml +++ b/configs/trtllm_qwen3_8b.yaml @@ -59,7 +59,11 @@ inference: mem_fraction_static: 0.7 extra_args: # Any extra TensorRT-LLM LLM kwarg; e.g. cap the per-iteration token budget. - max_num_tokens: 8192 + # Must be >= training.max_seq_length: SaveHiddenStates runs each prompt in a + # single prefill (chunked prefill is disabled), so a prompt longer than + # max_num_tokens is rejected outright ("sum of prompt length ... should not + # exceed max_num_tokens") and that sample is dropped. + max_num_tokens: 16384 mooncake: master_server_address: null From a01d04f3b32aae0551d0c50823f89756e768f7ea Mon Sep 17 00:00:00 2001 From: chungen04 Date: Fri, 19 Jun 2026 17:45:42 +0000 Subject: [PATCH 13/28] style(trtllm): apply ruff format to trtllm_engine Collapse a NotImplementedError call onto one line (fits in the 100-char limit) so `ruff format --check` passes. No behavior change. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- torchspec/inference/engine/trtllm_engine.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchspec/inference/engine/trtllm_engine.py b/torchspec/inference/engine/trtllm_engine.py index db529f8..c7927f4 100644 --- a/torchspec/inference/engine/trtllm_engine.py +++ b/torchspec/inference/engine/trtllm_engine.py @@ -351,9 +351,7 @@ def generate( raise ValueError("Exactly one of input_ids_ref or formatted_prompts must be set") if multimodal_inputs is not None and any(m for m in multimodal_inputs): - raise NotImplementedError( - "TrtllmEngine does not support multimodal inputs yet." - ) + raise NotImplementedError("TrtllmEngine does not support multimodal inputs yet.") use_prompts = formatted_prompts is not None if use_prompts: From 99a16020b1e78476d19909b53a3982a9204c8d3f Mon Sep 17 00:00:00 2001 From: chungen04 Date: Tue, 23 Jun 2026 22:19:06 +0000 Subject: [PATCH 14/28] trtllm: fold mpi_session patch into trtllm.patch and relax batch cap Merge the standalone mpi_session.patch (forwarding TORCHSPEC*/MOONCAKE*/MC_* env to MPI workers) into the main trtllm.patch and drop the separate file. Also relax the upstream max_batch_size=1 cap that TRT-LLM forces whenever SaveHiddenStatesDecodingConfig is active: when TORCHSPEC_TRTLLM_MOONCAKE is set, keep the configured max_batch_size so batched prefill capture works. The Mooncake resource manager de-interleaves the packed multi-request buffer, so the offset-0 single-request constraint no longer applies. Overlap scheduler and CUDA graphs stay disabled (shared capture buffer is overwritten every forward and must be read first). Signed-off-by: chungen04 --- patches/trtllm/v1.3.0rc18/mpi_session.patch | 13 ------ patches/trtllm/v1.3.0rc18/trtllm.patch | 44 +++++++++++++++++++++ 2 files changed, 44 insertions(+), 13 deletions(-) delete mode 100644 patches/trtllm/v1.3.0rc18/mpi_session.patch diff --git a/patches/trtllm/v1.3.0rc18/mpi_session.patch b/patches/trtllm/v1.3.0rc18/mpi_session.patch deleted file mode 100644 index 8c0bfe2..0000000 --- a/patches/trtllm/v1.3.0rc18/mpi_session.patch +++ /dev/null @@ -1,13 +0,0 @@ -diff --git a/tensorrt_llm/llmapi/mpi_session.py b/tensorrt_llm/llmapi/mpi_session.py -index 2e1eeb2..7d1bca8 100644 ---- a/tensorrt_llm/llmapi/mpi_session.py -+++ b/tensorrt_llm/llmapi/mpi_session.py -@@ -174,6 +174,8 @@ class MpiPoolSession(MpiSession): - key: value - for key, value in os.environ.items() - if key.startswith("TRTLLM") or key.startswith("TLLM") -+ or key.startswith("TORCHSPEC") or key.startswith("MOONCAKE") -+ or key.startswith("MC_") - } - self.mpi_pool = MPIPoolExecutor(max_workers=self.n_workers, - path=sys.path, diff --git a/patches/trtllm/v1.3.0rc18/trtllm.patch b/patches/trtllm/v1.3.0rc18/trtllm.patch index 8da10ab..ded48b4 100644 --- a/patches/trtllm/v1.3.0rc18/trtllm.patch +++ b/patches/trtllm/v1.3.0rc18/trtllm.patch @@ -174,3 +174,47 @@ index 05cf261..f20431a 100644 def _process_request(self, request: LlmRequest, spec_metadata: "SaveHiddenStatesSpecMetadata") -> None: if local_mpi_rank() != 0: +diff --git a/tensorrt_llm/llmapi/mpi_session.py b/tensorrt_llm/llmapi/mpi_session.py +index 2e1eeb2..7d1bca8 100644 +--- a/tensorrt_llm/llmapi/mpi_session.py ++++ b/tensorrt_llm/llmapi/mpi_session.py +@@ -174,6 +174,8 @@ class MpiPoolSession(MpiSession): + key: value + for key, value in os.environ.items() + if key.startswith("TRTLLM") or key.startswith("TLLM") ++ or key.startswith("TORCHSPEC") or key.startswith("MOONCAKE") ++ or key.startswith("MC_") + } + self.mpi_pool = MPIPoolExecutor(max_workers=self.n_workers, + path=sys.path, +diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py +--- a/tensorrt_llm/llmapi/llm_args.py ++++ b/tensorrt_llm/llmapi/llm_args.py +@@ -4525,10 +4525,23 @@ + + if isinstance(self.speculative_config, + SaveHiddenStatesDecodingConfig): +- logger.warning( +- "SaveHiddenStatesDecodingConfig is active, setting max_batch_size to 1, disabling overlap scheduler, and setting cuda_graph_config to None" +- ) +- self.max_batch_size = 1 ++ # TorchSpec: in Mooncake capture mode the SaveHiddenStates ++ # resource manager de-interleaves a packed multi-request buffer ++ # (save_hidden_state.py::_process_and_store_mooncake), so the ++ # upstream max_batch_size=1 cap -- which only protects the ++ # offset-0 single-request disk path -- is relaxed to allow ++ # batched prefill capture. Overlap scheduler and CUDA graphs ++ # stay disabled: the shared capture buffer is overwritten every ++ # forward and must be read before the next one runs. ++ if os.environ.get("TORCHSPEC_TRTLLM_MOONCAKE"): ++ logger.warning( ++ "SaveHiddenStatesDecodingConfig is active (TorchSpec Mooncake mode); keeping max_batch_size, disabling overlap scheduler, and setting cuda_graph_config to None" ++ ) ++ else: ++ logger.warning( ++ "SaveHiddenStatesDecodingConfig is active, setting max_batch_size to 1, disabling overlap scheduler, and setting cuda_graph_config to None" ++ ) ++ self.max_batch_size = 1 + self.disable_overlap_scheduler = True + self.cuda_graph_config = None + self.speculative_config.max_draft_len = 1 From 9050871e9f9f270c89fff8faa853554d95da8fc1 Mon Sep 17 00:00:00 2001 From: chungen04 Date: Tue, 23 Jun 2026 22:19:12 +0000 Subject: [PATCH 15/28] trtllm: unify tp_size override across backends in run.sh Derive the per-backend config block from inference_engine_type instead of hardcoding inference.sglang.tp_size. "sgl" maps to the "sglang" block; vllm/trtllm map 1:1. The same launcher now drives sglang, vllm, and trtllm with an identical inference layout for fair side-by-side runs. Signed-off-by: chungen04 --- examples/qwen3-8b-single-node/run.sh | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/qwen3-8b-single-node/run.sh b/examples/qwen3-8b-single-node/run.sh index 6649200..fed3188 100755 --- a/examples/qwen3-8b-single-node/run.sh +++ b/examples/qwen3-8b-single-node/run.sh @@ -37,6 +37,15 @@ else CONFIG_FILE="$ROOT_DIR/configs/sglang_qwen3_8b.yaml" fi +# Per-backend tp_size override key, derived from the config's engine type. +# engine_type "sgl" lives under the "sglang" config block; vllm/trtllm match 1:1. +# This lets run.sh launch any backend with the same 2-GPU/tp=2 inference layout. +ENGINE_TYPE=$(grep -oE "inference_engine_type:[[:space:]]*[a-zA-Z]+" "$CONFIG_FILE" | awk '{print $2}') +case "$ENGINE_TYPE" in + sgl) TP_BLOCK=sglang ;; + *) TP_BLOCK="${ENGINE_TYPE:-sglang}" ;; +esac + IFS=',' read -ra GPU_ARRAY <<< "$CUDA_VISIBLE_DEVICES" TOTAL_GPUS=${#GPU_ARRAY[@]} @@ -56,14 +65,13 @@ echo "Local IP: $LOCAL_IP" echo "Extra args: $*" echo "==============================================" -# TODO: unify tp_size config across sglang/vllm backends python3 -m torchspec.train_entry \ --config "$CONFIG_FILE" \ training.training_num_gpus_per_node="$TRAIN_GPUS" \ inference.inference_num_gpus="$INFERENCE_GPUS" \ inference.inference_num_gpus_per_engine=2 \ inference.inference_num_gpus_per_node="$TOTAL_GPUS" \ - inference.sglang.tp_size=2 \ + inference.${TP_BLOCK}.tp_size=2 \ "$@" echo "==============================================" From 183f5a6c6a32a96a121b92e0113d38a196c447da Mon Sep 17 00:00:00 2001 From: chungen04 Date: Tue, 23 Jun 2026 22:19:18 +0000 Subject: [PATCH 16/28] config(trtllm): tune qwen3-8b for single-node single-GPU Switch the inference layout to 1 GPU / tp=1, drop the eval dataset, shrink the Mooncake global_segment_size to 16GB, and namespace cache_dir per example. Point the usage comment at the unified run.sh launcher. Signed-off-by: chungen04 --- configs/trtllm_qwen3_8b.yaml | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/configs/trtllm_qwen3_8b.yaml b/configs/trtllm_qwen3_8b.yaml index 62432c2..3839265 100644 --- a/configs/trtllm_qwen3_8b.yaml +++ b/configs/trtllm_qwen3_8b.yaml @@ -10,8 +10,8 @@ # which ships tensorrt_llm patched for Mooncake hidden-state capture. # For a local install: pip install -e ".[trtllm]" # -# Usage: -# python -m torchspec.train_entry --config configs/trtllm_qwen3_8b.yaml +# Usage (same launcher as the sglang example, for a fair side-by-side): +# ./examples/qwen3-8b-single-node/run.sh configs/trtllm_qwen3_8b.yaml # # Note: Uses TensorRT-LLM's SaveHiddenStates speculative mode; the TorchSpec # patch redirects captured aux + final hidden states to Mooncake. @@ -22,8 +22,6 @@ model: dataset: train_data_path: ../examples/data/sample_conversations.jsonl - eval_data_path: ../examples/data/eval_conversations.jsonl - eval_interval: 100 chat_template: qwen prompt_key: conversations @@ -45,14 +43,14 @@ training: inference: inference_engine_type: trtllm - inference_num_gpus: 2 - inference_num_gpus_per_engine: 2 - inference_num_gpus_per_node: 4 - max_sample_pool_size: 64 - inference_buffer_threshold: 32 + inference_num_gpus: 1 + inference_num_gpus_per_engine: 1 + inference_num_gpus_per_node: 1 + max_sample_pool_size: 64 # Max samples in controller pool + inference_buffer_threshold: 32 # Fetch prompts when buffer < threshold inference_batch_size: 8 trtllm: - tp_size: 2 + tp_size: 1 # KV-cache memory fraction. Kept below TRT-LLM's 0.9 default to leave room # for the SaveHiddenStates capture buffer, which the KV profiler does not # account for. @@ -69,11 +67,11 @@ mooncake: master_server_address: null metadata_server: null protocol: tcp - global_segment_size: 32GB + global_segment_size: 16GB local_buffer_size: 4GB output_dir: ./outputs/qwen3-8b-single-node -cache_dir: ./cache +cache_dir: ./cache/qwen3-8b-single-node model_download_dir: null debug: From 811d0282bb4beb754403436a07f34eee9e557edb Mon Sep 17 00:00:00 2001 From: chungen04 Date: Tue, 23 Jun 2026 22:31:44 +0000 Subject: [PATCH 17/28] align(trtllm): match remote README launch form and Dockerfile patch step Adopt the positional run.sh launch form in the README (vLLM + TRT-LLM examples) and drop the patch-verify grep steps from the trtllm Dockerfile, matching origin/chungen/trtllm. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- README.md | 4 ++-- docker/trtllm/v1.3.0rc18/Dockerfile | 14 -------------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 7e6531a..1202251 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ pip install -e ".[fa]" **vLLM** ```bash -./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml +./examples/qwen3-8b-single-node/run.sh configs/vllm_qwen3_8b.yaml ``` **SGLang** @@ -127,7 +127,7 @@ pip install -e ".[fa]" Run inside the TensorRT-LLM image (`docker/trtllm/v1.3.0rc18/Dockerfile`), which ships `tensorrt_llm` pre-patched for Mooncake hidden-state capture: ```bash -./examples/qwen3-8b-single-node/run.sh --config configs/trtllm_qwen3_8b.yaml +./examples/qwen3-8b-single-node/run.sh configs/trtllm_qwen3_8b.yaml ``` Single-node tensor parallelism only for now (multi-node TP is not yet wired up). diff --git a/docker/trtllm/v1.3.0rc18/Dockerfile b/docker/trtllm/v1.3.0rc18/Dockerfile index 0991edc..330ee44 100644 --- a/docker/trtllm/v1.3.0rc18/Dockerfile +++ b/docker/trtllm/v1.3.0rc18/Dockerfile @@ -18,20 +18,6 @@ RUN cd /usr/local/lib/python3.12/dist-packages && \ for p in /tmp/patches/*.patch; do patch -p1 < "$p"; done && \ rm -rf /tmp/patches -# Fail the build early if the patches did not land. Grep the files directly -# (import-free, since build-time has no CUDA driver). -RUN F=/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/speculative/save_hidden_state.py && \ - grep -q "TORCHSPEC_TRTLLM_MOONCAKE" "$F" && \ - grep -q "_process_and_store_mooncake" "$F" || \ - { echo "ERROR: TorchSpec Mooncake patch was not applied to $F"; exit 1; } - -# The MPI worker pool only forwards TRTLLM*/TLLM* env vars by default; the patch -# below also forwards TORCHSPEC*/MOONCAKE*/MC_* so the Mooncake redirect flag and -# connection params reach the workers running SaveHiddenStates. -RUN F=/usr/local/lib/python3.12/dist-packages/tensorrt_llm/llmapi/mpi_session.py && \ - grep -q "TORCHSPEC" "$F" && grep -q "MOONCAKE" "$F" || \ - { echo "ERROR: TorchSpec MPI env-forwarding patch was not applied to $F"; exit 1; } - COPY . /root/torchspec RUN cd /root/torchspec && pip install --no-cache-dir -e ".[fa]" From 5ef2a1389b04be43463163144cedac7dfc4dae5c Mon Sep 17 00:00:00 2001 From: chungen04 Date: Thu, 25 Jun 2026 17:42:18 +0000 Subject: [PATCH 18/28] fix(trtllm): capture hidden states in native context order SaveHiddenStates' Mooncake path sliced per-request hidden states in py_batch_idx (= py_seq_slot, an arbitrary KV slot) order, but the forward packs tokens in scheduled_requests.context_requests native order. Under multi-request batching the orders differ, so per-request slices read the wrong hidden states and corrupt a subset of captured samples. Iterate context_requests in native order so the slices align. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- patches/trtllm/v1.3.0rc18/trtllm.patch | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/patches/trtllm/v1.3.0rc18/trtllm.patch b/patches/trtllm/v1.3.0rc18/trtllm.patch index ded48b4..4a18a77 100644 --- a/patches/trtllm/v1.3.0rc18/trtllm.patch +++ b/patches/trtllm/v1.3.0rc18/trtllm.patch @@ -66,7 +66,7 @@ index 05cf261..f20431a 100644 for request in sorted( scheduled_requests.context_requests, key=lambda r: -@@ -83,6 +113,107 @@ class SaveHiddenStatesResourceManager(BaseResourceManager): +@@ -83,6 +113,94 @@ class SaveHiddenStatesResourceManager(BaseResourceManager): self._write_to_file() self._iter += 1 @@ -106,17 +106,8 @@ index 05cf261..f20431a 100644 + def _process_and_store_mooncake( + self, scheduled_requests: ScheduledRequests, + spec_metadata: "SaveHiddenStatesSpecMetadata") -> None: -+ """Stream per-request hidden states to Mooncake instead of disk. -+ -+ The capture buffer holds every context token of this forward packed -+ from offset 0 in scheduled order, so we walk the requests in the same -+ order and slice ``[token_offset : token_offset + num_tokens]`` for -+ each. (Upstream's disk path assumes a single request and always reads -+ from offset 0 — incorrect for the batched prefill TorchSpec runs.) -+ -+ NOTE: assumes the forward packs context-request tokens in the same -+ order as this sort (by ``py_batch_idx``). Validate against real -+ batched runs during engine bring-up. ++ """Walk context_requests in native order -- how the forward packed the ++ buffer. Sorting by py_batch_idx (= py_seq_slot) would misalign slices. + """ + if local_mpi_rank() != 0: + return @@ -125,11 +116,7 @@ index 05cf261..f20431a 100644 + return + + token_offset = 0 -+ for request in sorted( -+ scheduled_requests.context_requests, -+ key=lambda r: -+ (r.py_batch_idx is None, r.py_batch_idx or r.request_id), -+ ): ++ for request in scheduled_requests.context_requests: + token_ids = list(request.get_tokens(0)) + num_tokens = len(token_ids) + buf = self.hidden_states[token_offset:token_offset + num_tokens] From 3ed4c003def4d661f598a825f209ff3237612b78 Mon Sep 17 00:00:00 2001 From: chungen04 Date: Thu, 25 Jun 2026 17:42:18 +0000 Subject: [PATCH 19/28] fix(trtllm): default placement_strategy=inference_first TRT-LLM's MPI tp-workers bind GPU ordinals 0..tp_size-1 regardless of the Ray placement group, so under the default training_first they collide with training on the low GPUs and OOM. inference_first assigns inference those low indices. SGLang honors base_gpu_id and is unaffected, so the override is TRT-config-only. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- configs/trtllm_qwen3_8b.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/configs/trtllm_qwen3_8b.yaml b/configs/trtllm_qwen3_8b.yaml index 3839265..f7e22d5 100644 --- a/configs/trtllm_qwen3_8b.yaml +++ b/configs/trtllm_qwen3_8b.yaml @@ -27,6 +27,8 @@ dataset: training: attention_backend: flex_attention + # TRT tp-workers bind GPU 0..tp-1; keep inference on those to avoid OOM with training. + placement_strategy: inference_first micro_batch_size: 1 draft_accumulation_steps: 1 learning_rate: 1e-4 From 893df69cabeb441e6bfc51e0e61f33e2380b384f Mon Sep 17 00:00:00 2001 From: chungen04 Date: Thu, 25 Jun 2026 22:46:09 +0000 Subject: [PATCH 20/28] fix(trtllm): protect trust_remote_code from extra_args override trust_remote_code is passed explicitly to LLM() but was missing from _PROTECTED_ENGINE_KEYS. A user setting it in inference.trtllm.extra_args (a standard LLM kwarg) would have it survive the filter and be passed twice, raising "TypeError: got multiple values for keyword argument 'trust_remote_code'" at engine init. Add it to the protected set so the explicit config-sourced value wins, matching model/backend/tp_size. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- torchspec/inference/engine/trtllm_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchspec/inference/engine/trtllm_engine.py b/torchspec/inference/engine/trtllm_engine.py index c7927f4..6a302d7 100644 --- a/torchspec/inference/engine/trtllm_engine.py +++ b/torchspec/inference/engine/trtllm_engine.py @@ -63,6 +63,7 @@ "backend", "tensor_parallel_size", "pipeline_parallel_size", + "trust_remote_code", "speculative_config", "kv_cache_config", "disable_overlap_scheduler", From 253df85aeb95a3ed7aee6d6290104af70b6434e1 Mon Sep 17 00:00:00 2001 From: chungen04 Date: Thu, 25 Jun 2026 22:46:18 +0000 Subject: [PATCH 21/28] style(trtllm): trim and align comments Tighten TRT-only comments to the project's lean style and align with the sglang reference where applicable. No functional change. - Dockerfile: drop the known-issue CUDA-13 wheel comment; condense the patch-step comment to the one non-obvious gotcha (don't import tensorrt_llm at build time -- libcuda absent until runtime). - run.sh: collapse the tp_size-block derivation comment to one line. - inference_config.py: drop the redundant init_timeout comment (sglang/ vllm declare the same field without one). - trtllm_qwen3_8b.yaml: add the commented eval_data_path/eval_interval placeholders to match the sglang config; consolidate the max_num_tokens note to one line. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- configs/trtllm_qwen3_8b.yaml | 7 +++---- docker/trtllm/v1.3.0rc18/Dockerfile | 11 +++-------- examples/qwen3-8b-single-node/run.sh | 4 +--- torchspec/config/inference_config.py | 1 - 4 files changed, 7 insertions(+), 16 deletions(-) diff --git a/configs/trtllm_qwen3_8b.yaml b/configs/trtllm_qwen3_8b.yaml index f7e22d5..1207602 100644 --- a/configs/trtllm_qwen3_8b.yaml +++ b/configs/trtllm_qwen3_8b.yaml @@ -22,6 +22,8 @@ model: dataset: train_data_path: ../examples/data/sample_conversations.jsonl + # eval_data_path: ../examples/data/eval_conversations.jsonl + # eval_interval: 100 chat_template: qwen prompt_key: conversations @@ -59,10 +61,7 @@ inference: mem_fraction_static: 0.7 extra_args: # Any extra TensorRT-LLM LLM kwarg; e.g. cap the per-iteration token budget. - # Must be >= training.max_seq_length: SaveHiddenStates runs each prompt in a - # single prefill (chunked prefill is disabled), so a prompt longer than - # max_num_tokens is rejected outright ("sum of prompt length ... should not - # exceed max_num_tokens") and that sample is dropped. + # Must be >= training.max_seq_length: chunked prefill is off, so prompts longer than max_num_tokens are dropped. max_num_tokens: 16384 mooncake: diff --git a/docker/trtllm/v1.3.0rc18/Dockerfile b/docker/trtllm/v1.3.0rc18/Dockerfile index 330ee44..8fb513c 100644 --- a/docker/trtllm/v1.3.0rc18/Dockerfile +++ b/docker/trtllm/v1.3.0rc18/Dockerfile @@ -7,12 +7,9 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends nvtop rsync dnsutils && \ rm -rf /var/lib/apt/lists/* -# Patch the installed tensorrt_llm package so SaveHiddenStates mode streams -# captured hidden states to Mooncake instead of writing .pt files to disk. -# tensorrt_llm lives in dist-packages (python 3.12 in this release image, same -# location the vLLM image uses). Do NOT import tensorrt_llm at build time — its -# package init loads libcuda.so.1, which is absent until container runtime. -# Apply with -p1 to strip the diff's a/ prefix. +# Patch the installed tensorrt_llm (dist-packages) for Mooncake hidden-state +# capture. Don't import tensorrt_llm at build time to locate it — its __init__ +# loads libcuda.so.1, absent until container runtime; patch the path directly. COPY patches/trtllm/v1.3.0rc18/*.patch /tmp/patches/ RUN cd /usr/local/lib/python3.12/dist-packages && \ for p in /tmp/patches/*.patch; do patch -p1 < "$p"; done && \ @@ -21,8 +18,6 @@ RUN cd /usr/local/lib/python3.12/dist-packages && \ COPY . /root/torchspec RUN cd /root/torchspec && pip install --no-cache-dir -e ".[fa]" -# TorchSpec's dependency pulls the generic CUDA 12 Mooncake wheel. Replace it -# with the CUDA 13 wheel that matches this image family. RUN pip uninstall -y mooncake-transfer-engine || true && \ pip install --no-cache-dir --no-deps --force-reinstall \ mooncake-transfer-engine-cuda13==0.3.11.post1 diff --git a/examples/qwen3-8b-single-node/run.sh b/examples/qwen3-8b-single-node/run.sh index fed3188..890b013 100755 --- a/examples/qwen3-8b-single-node/run.sh +++ b/examples/qwen3-8b-single-node/run.sh @@ -37,9 +37,7 @@ else CONFIG_FILE="$ROOT_DIR/configs/sglang_qwen3_8b.yaml" fi -# Per-backend tp_size override key, derived from the config's engine type. -# engine_type "sgl" lives under the "sglang" config block; vllm/trtllm match 1:1. -# This lets run.sh launch any backend with the same 2-GPU/tp=2 inference layout. +# Derive the tp_size override block from the config's engine type ("sgl" -> "sglang"). ENGINE_TYPE=$(grep -oE "inference_engine_type:[[:space:]]*[a-zA-Z]+" "$CONFIG_FILE" | awk '{print $2}') case "$ENGINE_TYPE" in sgl) TP_BLOCK=sglang ;; diff --git a/torchspec/config/inference_config.py b/torchspec/config/inference_config.py index 63678f6..8221839 100644 --- a/torchspec/config/inference_config.py +++ b/torchspec/config/inference_config.py @@ -126,7 +126,6 @@ class TrtllmConfig: # KV-cache memory fraction (TRT-LLM's KvCacheConfig.free_gpu_memory_fraction). mem_fraction_static: Optional[float] = 0.8 - # TRT-LLM model build + load can be slow; give init a generous timeout. init_timeout: int = 600 # Passthrough: forwarded as-is to the TRT-LLM LLM constructor From 0b4f7c48bf5ba39bb39888eeaf844929e106b30e Mon Sep 17 00:00:00 2001 From: chungen04 Date: Thu, 25 Jun 2026 23:24:46 +0000 Subject: [PATCH 22/28] fix(trtllm): honor store_last_hidden_states in mooncake capture When store_last_hidden_states=false (e.g. the DFlash configs), the engine omits last_hidden_states from the returned tensor metadata, so the fetcher never fetches or deletes the {key}_lhs object -- but the patched resource manager still always wrote it, orphaning one _lhs per sample in Mooncake and filling the segment over long runs. Gate storage on a new TORCHSPEC_TRTLLM_STORE_LAST_HIDDEN env (exported by the engine before LLM construction, propagated to MPI workers) so _lhs is written only when the flag is set. Default path (true) is unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- patches/trtllm/v1.3.0rc18/trtllm.patch | 5 +++-- torchspec/inference/engine/trtllm_engine.py | 5 +++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/patches/trtllm/v1.3.0rc18/trtllm.patch b/patches/trtllm/v1.3.0rc18/trtllm.patch index 4a18a77..3890632 100644 --- a/patches/trtllm/v1.3.0rc18/trtllm.patch +++ b/patches/trtllm/v1.3.0rc18/trtllm.patch @@ -66,7 +66,7 @@ index 05cf261..f20431a 100644 for request in sorted( scheduled_requests.context_requests, key=lambda r: -@@ -83,6 +113,94 @@ class SaveHiddenStatesResourceManager(BaseResourceManager): +@@ -83,6 +113,95 @@ class SaveHiddenStatesResourceManager(BaseResourceManager): self._write_to_file() self._iter += 1 @@ -115,6 +115,7 @@ index 05cf261..f20431a 100644 + if store is None: + return + ++ store_last = os.environ.get("TORCHSPEC_TRTLLM_STORE_LAST_HIDDEN", "1") != "0" + token_offset = 0 + for request in scheduled_requests.context_requests: + token_ids = list(request.get_tokens(0)) @@ -147,7 +148,7 @@ index 05cf261..f20431a 100644 + key=key, + hidden_states=aux_hidden_states, + input_ids=input_ids, -+ last_hidden_states=last_hidden_states, ++ last_hidden_states=last_hidden_states if store_last else None, + target=None, + ) + except Exception: diff --git a/torchspec/inference/engine/trtllm_engine.py b/torchspec/inference/engine/trtllm_engine.py index 6a302d7..474f262 100644 --- a/torchspec/inference/engine/trtllm_engine.py +++ b/torchspec/inference/engine/trtllm_engine.py @@ -146,6 +146,11 @@ def init( ) self._store_last_hidden_states = getattr(self.args, "store_last_hidden_states", True) + # Tell the patched resource manager whether to store last_hidden_states. + # Set before LLM construction so it propagates to the spawned MPI workers. + os.environ["TORCHSPEC_TRTLLM_STORE_LAST_HIDDEN"] = ( + "1" if self._store_last_hidden_states else "0" + ) self._mooncake_config = mooncake_config self._setup_mooncake_env(mooncake_config) From 5224fe0a4a3cee7209d3ab5d92a3d1d30c639395 Mon Sep 17 00:00:00 2001 From: chungen04 Date: Thu, 25 Jun 2026 23:24:56 +0000 Subject: [PATCH 23/28] fix(trtllm): defer tensorrt_llm imports until engine init tensorrt_llm's package __init__ loads libcuda.so.1 on import. With the LLM /SamplingParams/KvCacheConfig/SaveHiddenStatesDecodingConfig imports at module scope, importing torchspec.inference.engine on a host without a CUDA driver raised a non-ImportError loader failure; engine/__init__.py only catches ImportError, so the whole engine package (HF/vLLM/SGLang included) failed before TRT-LLM was ever selected. Move the imports into the methods that use them, matching vllm_engine.py's lazy `from vllm import LLM`. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- torchspec/inference/engine/trtllm_engine.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchspec/inference/engine/trtllm_engine.py b/torchspec/inference/engine/trtllm_engine.py index 474f262..2a6bcc2 100644 --- a/torchspec/inference/engine/trtllm_engine.py +++ b/torchspec/inference/engine/trtllm_engine.py @@ -47,8 +47,6 @@ import ray import torch from omegaconf import DictConfig, OmegaConf -from tensorrt_llm import LLM, SamplingParams -from tensorrt_llm.llmapi import KvCacheConfig, SaveHiddenStatesDecodingConfig from torchspec.inference.engine.base import InferenceEngine from torchspec.ray.ray_actor import RayActor @@ -243,6 +241,12 @@ def _resolve_aux_layer_ids(self) -> list[int]: def _init_engine(self, tp_size: int, mem_fraction: float | None) -> None: """Construct the TRT-LLM PyTorch ``LLM`` in SaveHiddenStates mode.""" + # Imported here (not at module scope) so that importing this module + # without a CUDA driver -- e.g. on an HF/vLLM/SGLang-only host -- does + # not trigger tensorrt_llm's libcuda load and break the engine package. + from tensorrt_llm import LLM + from tensorrt_llm.llmapi import KvCacheConfig, SaveHiddenStatesDecodingConfig + # Pin TRT-LLM's MPI workers to the assigned physical GPUs. Workers map # their local rank onto the visible devices, so without this they would # collide on devices 0..tp_size-1. @@ -393,6 +397,8 @@ def generate( # Prefill-only: SaveHiddenStates forces max_new_tokens=1 internally, but # we set it here too to avoid allocating decode resources. + from tensorrt_llm import SamplingParams + sampling_params = SamplingParams(max_tokens=1) outputs = self._engine.generate(inputs, sampling_params, use_tqdm=False) From 50c0f31007bd5f079d1f046953ab993cf01a0bdd Mon Sep 17 00:00:00 2001 From: chungen04 Date: Thu, 25 Jun 2026 23:25:05 +0000 Subject: [PATCH 24/28] fix(run.sh): make engine-type probe non-fatal under set -e The grep that derives the tp_size override block returns 1 when a config has no literal inference_engine_type line (e.g. set via CLI extra args); under set -euo pipefail that exited the launcher before the ${ENGINE_TYPE :-sglang} fallback could run, regressing the [CONFIG_FILE] [EXTRA_ARGS...] path. Append `|| true` so a no-match yields empty and falls back. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- examples/qwen3-8b-single-node/run.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/qwen3-8b-single-node/run.sh b/examples/qwen3-8b-single-node/run.sh index 890b013..be2c461 100755 --- a/examples/qwen3-8b-single-node/run.sh +++ b/examples/qwen3-8b-single-node/run.sh @@ -38,7 +38,8 @@ else fi # Derive the tp_size override block from the config's engine type ("sgl" -> "sglang"). -ENGINE_TYPE=$(grep -oE "inference_engine_type:[[:space:]]*[a-zA-Z]+" "$CONFIG_FILE" | awk '{print $2}') +# `|| true` so a config without a literal inference_engine_type line falls back below instead of tripping set -e. +ENGINE_TYPE=$(grep -oE "inference_engine_type:[[:space:]]*[a-zA-Z]+" "$CONFIG_FILE" | awk '{print $2}' || true) case "$ENGINE_TYPE" in sgl) TP_BLOCK=sglang ;; *) TP_BLOCK="${ENGINE_TYPE:-sglang}" ;; From 956eecfe5bd299b9eac00142e9610cda6caa8dba Mon Sep 17 00:00:00 2001 From: chungen04 Date: Fri, 26 Jun 2026 19:03:41 +0000 Subject: [PATCH 25/28] feat(trtllm): support data-parallel tp=1 inference (Ray-scoped placement + per-engine keys) Enables N independent single-GPU (tp=1) TRT-LLM inference engines on distinct GPUs, needed for data-parallel inference (e.g. DFlash's 4 tp=1 replicas). Two co-dependent pieces: - GPU placement: single-GPU engines drop the NOSET override so Ray scopes CUDA_VISIBLE_DEVICES per actor to its own GPU; the engine no longer manually pins CVD for tp=1 (TRT's MPI worker maps cuda:0 to the scoped GPU). CVD is set before any CUDA/tensorrt_llm init (also fixes premature-CUDA-init). TP>1 engines keep the prior NOSET + manual contiguous-block pin. - Mooncake keys: each engine has its own request_id counter, so keys collided in the shared store across data-parallel engines. Prefix keys per engine (e{rank}_) on both the write side (patch) and the read side (engine). Without these, 4 tp=1 engines all bound GPU 0 (OOM) and clobbered each other's Mooncake keys (shape mismatch). Validated: DFlash trtllm 10k steps clean. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- patches/trtllm/v1.3.0rc18/trtllm.patch | 7 +++-- torchspec/inference/engine/trtllm_engine.py | 32 ++++++++++++--------- torchspec/inference/factory.py | 13 ++++++++- 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/patches/trtllm/v1.3.0rc18/trtllm.patch b/patches/trtllm/v1.3.0rc18/trtllm.patch index 3890632..ae25bf1 100644 --- a/patches/trtllm/v1.3.0rc18/trtllm.patch +++ b/patches/trtllm/v1.3.0rc18/trtllm.patch @@ -66,7 +66,7 @@ index 05cf261..f20431a 100644 for request in sorted( scheduled_requests.context_requests, key=lambda r: -@@ -83,6 +113,95 @@ class SaveHiddenStatesResourceManager(BaseResourceManager): +@@ -83,6 +113,96 @@ class SaveHiddenStatesResourceManager(BaseResourceManager): self._write_to_file() self._iter += 1 @@ -116,6 +116,7 @@ index 05cf261..f20431a 100644 + return + + store_last = os.environ.get("TORCHSPEC_TRTLLM_STORE_LAST_HIDDEN", "1") != "0" ++ key_prefix = os.environ.get("TORCHSPEC_TRTLLM_KEY_PREFIX", "") + token_offset = 0 + for request in scheduled_requests.context_requests: + token_ids = list(request.get_tokens(0)) @@ -142,7 +143,7 @@ index 05cf261..f20431a 100644 + client_key_id = getattr(request, "py_client_id", None) + if client_key_id is None: + client_key_id = request.py_request_id -+ key = _sanitize_mooncake_key(str(client_key_id)) ++ key = _sanitize_mooncake_key(f"{key_prefix}{client_key_id}") + try: + store.put( + key=key, @@ -171,7 +172,7 @@ index 2e1eeb2..7d1bca8 100644 for key, value in os.environ.items() if key.startswith("TRTLLM") or key.startswith("TLLM") + or key.startswith("TORCHSPEC") or key.startswith("MOONCAKE") -+ or key.startswith("MC_") ++ or key.startswith("MC_") or key == "CUDA_VISIBLE_DEVICES" } self.mpi_pool = MPIPoolExecutor(max_workers=self.n_workers, path=sys.path, diff --git a/torchspec/inference/engine/trtllm_engine.py b/torchspec/inference/engine/trtllm_engine.py index 2a6bcc2..eb150bc 100644 --- a/torchspec/inference/engine/trtllm_engine.py +++ b/torchspec/inference/engine/trtllm_engine.py @@ -136,11 +136,20 @@ def init( pp_size = getattr(self.args, "trtllm_pp_size", 1) assert pp_size == 1, f"trtllm_pp_size must be 1, got {pp_size}" + # GPU pinning, before any CUDA/tensorrt_llm init (TRT reads CVD at import + # and picks the device by MPI rank). tp=1: the factory drops the NOSET + # override so Ray scopes CVD to this actor's single GPU -- don't override + # it. tp>1: keep all GPUs visible and pin the contiguous block ourselves. if self.base_gpu_id is not None: - self.local_gpu_id = self.setup_gpu(self.base_gpu_id) + if self.num_gpus_per_engine > 1: + gpu_ids = [str(self.base_gpu_id + i) for i in range(self.num_gpus_per_engine)] + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids) + self.local_gpu_id = 0 + torch.cuda.set_device(self.local_gpu_id) + os.environ["LOCAL_RANK"] = "0" logger.info( f"TrtllmEngine rank {self.rank}: base_gpu_id={self.base_gpu_id}, " - f"using local GPU {self.local_gpu_id}" + f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}" ) self._store_last_hidden_states = getattr(self.args, "store_last_hidden_states", True) @@ -149,6 +158,11 @@ def init( os.environ["TORCHSPEC_TRTLLM_STORE_LAST_HIDDEN"] = ( "1" if self._store_last_hidden_states else "0" ) + # Per-engine Mooncake key prefix: each data-parallel engine has its own + # LLM request_id counter, so without a prefix they collide in the shared + # store. Set before LLM build so the MPI worker (the patch) inherits it. + self._key_prefix = f"e{self.rank}_" + os.environ["TORCHSPEC_TRTLLM_KEY_PREFIX"] = self._key_prefix self._mooncake_config = mooncake_config self._setup_mooncake_env(mooncake_config) @@ -247,16 +261,8 @@ def _init_engine(self, tp_size: int, mem_fraction: float | None) -> None: from tensorrt_llm import LLM from tensorrt_llm.llmapi import KvCacheConfig, SaveHiddenStatesDecodingConfig - # Pin TRT-LLM's MPI workers to the assigned physical GPUs. Workers map - # their local rank onto the visible devices, so without this they would - # collide on devices 0..tp_size-1. - if self.base_gpu_id is not None: - gpu_ids = [str(self.base_gpu_id + i) for i in range(self.num_gpus_per_engine)] - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids) - logger.info( - f"TrtllmEngine rank {self.rank}: set " - f"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}" - ) + # CUDA_VISIBLE_DEVICES is already scoped to this engine's GPUs in init(), + # before any CUDA/TRT init -- see the comment there. # eagle3_layers_to_capture: aux layers + the final post-norm state (-1). # The resource manager orders -1 last in the capture buffer, which is the @@ -407,7 +413,7 @@ def generate( for i, output in enumerate(outputs): did = data_ids[i] seq_len = len(output.prompt_token_ids) - mooncake_key = _sanitize_mooncake_key(str(output.request_id)) + mooncake_key = _sanitize_mooncake_key(f"{self._key_prefix}{output.request_id}") result: dict[str, Any] = { "mooncake_key": mooncake_key, diff --git a/torchspec/inference/factory.py b/torchspec/inference/factory.py index 8ff298b..24abbf5 100644 --- a/torchspec/inference/factory.py +++ b/torchspec/inference/factory.py @@ -471,6 +471,17 @@ def _prepare_trtllm_engines( TrtllmRayActor = ray.remote(TrtllmEngine) env_vars = get_torchspec_env_vars() + # Single-GPU engines use Ray's per-actor CVD scoping (Approach 1): drop the + # NOSET override and reserve a whole GPU, so Ray exposes exactly that GPU as + # cuda:0 to the engine and its MPI worker -- N such engines each land on + # their own GPU. Multi-GPU (TP) engines keep NOSET + the engine's manual + # contiguous-block pin and only a placeholder reservation (workers grab the + # real GPUs, untracked by Ray). + single_gpu = gpus_per_engine == 1 + if single_gpu: + env_vars.pop("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", None) + actor_num_gpus = float(gpus_per_engine) if single_gpu else 0.2 + engines = [] for i in range(num_engines): bundle_offset = i * gpus_per_engine @@ -484,7 +495,7 @@ def _prepare_trtllm_engines( engine = TrtllmRayActor.options( num_cpus=0.2, - num_gpus=0.2, + num_gpus=actor_num_gpus, scheduling_strategy=scheduling_strategy, runtime_env={"env_vars": env_vars}, ).remote( From dde112864a0a497c89cba9a5d13cd1d3cf822f8a Mon Sep 17 00:00:00 2001 From: chungen04 Date: Fri, 26 Jun 2026 20:15:26 +0000 Subject: [PATCH 26/28] feat(trtllm): enable DFlash training with the trtllm backend Relax the DFlash engine guard in train_entry.py from sgl-only to {'sgl', 'trtllm'}, and add configs/trtllm_qwen3_8b_dflash.yaml (mirrors the sglang DFlash config; trtllm inference block + inference_first placement). DFlash's aux-layer wiring and the trainer are engine-agnostic, so no other changes are needed for the trtllm path. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- configs/trtllm_qwen3_8b_dflash.yaml | 96 +++++++++++++++++++++++++++++ torchspec/train_entry.py | 4 +- 2 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 configs/trtllm_qwen3_8b_dflash.yaml diff --git a/configs/trtllm_qwen3_8b_dflash.yaml b/configs/trtllm_qwen3_8b_dflash.yaml new file mode 100644 index 0000000..b488ab3 --- /dev/null +++ b/configs/trtllm_qwen3_8b_dflash.yaml @@ -0,0 +1,96 @@ +# DFlash training config for Qwen3-8B target model — TensorRT-LLM inference backend. +# +# Mirrors configs/sglang_qwen3_8b_dflash.yaml; only the inference backend differs +# (engine type, the trtllm block, and placement_strategy). +# +# NOTE: DFlash is currently gated to inference_engine_type='sgl' in +# train_entry.py; running this requires relaxing that guard to allow 'trtllm'. +# +# GPU allocation: mirrors the sglang DFlash layout (4 inference tp=1 + 4 training); +# adjust inference_num_gpus / training_num_gpus_per_node / tp_size to your budget. +# +# Usage: +# ./examples/qwen3-8b-single-node/run.sh configs/trtllm_qwen3_8b_dflash.yaml + +model: + target_model_path: Qwen/Qwen3-8B + trust_remote_code: true + draft_model_config: torchspec/config/dflash_draft_config.json + +dataset: + train_data_path: ../examples/data/sample_conversations.jsonl + eval_data_path: null + eval_interval: 100 + chat_template: qwen + prompt_key: conversations + min_loss_tokens: 32 + +training: + attention_backend: flex_attention + # TRT tp-workers bind GPU 0..tp-1; keep inference on those to avoid OOM with training. + placement_strategy: inference_first + micro_batch_size: 1 + draft_accumulation_steps: 2 # was 4 → 2x more optimizer steps + learning_rate: 6e-4 + min_lr: 6e-5 # 10% of peak — prevents LR death in later epochs + weight_decay: 0.01 # AdamW regularization for better generalization + max_concurrent_batches: 1 + max_grad_norm: 1.0 + max_seq_length: 2048 + num_epochs: 3 + seed: 42 + training_num_gpus_per_node: 4 + training_num_nodes: 1 + ttt_length: 7 + fsdp_strategy: FULL_SHARD + fsdp_reduce_dtype: bfloat16 + prefetch_depth: 8 + save_interval: 1000 + save_per_epoch: true + max_checkpoints: 2 + warmup_ratio: 0.04 + + # DFlash-specific parameters + dflash_block_size: 16 + dflash_num_anchors: 512 + dflash_loss_decay_gamma: 7.0 + dflash_num_target_layers: 5 + +inference: + inference_engine_type: trtllm + store_last_hidden_states: false + inference_num_gpus: 4 + inference_num_gpus_per_engine: 1 + inference_num_gpus_per_node: 4 + max_sample_pool_size: 64 + inference_buffer_threshold: 32 + inference_batch_size: 8 + trtllm: + tp_size: 1 + # KV-cache memory fraction. Kept below TRT-LLM's 0.9 default to leave room + # for the SaveHiddenStates capture buffer, which the KV profiler does not + # account for. + mem_fraction_static: 0.7 + extra_args: + # Must be >= training.max_seq_length: chunked prefill is off, so prompts longer than max_num_tokens are dropped. + max_num_tokens: 2048 + +mooncake: + master_server_address: null + metadata_server: null + protocol: tcp + global_segment_size: 16GB + local_buffer_size: 4GB + # Hard-pin: master-side TTL is disabled; we rely on our explicit + # batch_remove(force=True) (see mooncake/eagle_store.py). Requires + # mooncake-transfer-engine >= 0.3.10.post1. + enable_hard_pin: true + +output_dir: ./outputs/qwen3-8b-dflash +cache_dir: ./cache/qwen3-8b-dflash +model_download_dir: null + +debug: + save_debug_train_data: null + debug_train_only: false + debug_inference_only: false diff --git a/torchspec/train_entry.py b/torchspec/train_entry.py index 3b0189f..35a53ac 100644 --- a/torchspec/train_entry.py +++ b/torchspec/train_entry.py @@ -255,9 +255,9 @@ def _validate_and_configure_dflash(args, draft_model_config) -> None: return engine_type = getattr(args, "inference_engine_type", "hf") - if engine_type not in ("vllm", "sgl"): + if engine_type not in ("vllm", "sgl", "trtllm"): raise NotImplementedError( - f"DFlash supports inference_engine_type in ('vllm', 'sgl'), got '{engine_type}'." + f"DFlash supports inference_engine_type in ('vllm', 'sgl', 'trtllm'), got '{engine_type}'." ) if getattr(args, "defer_tokenization", False): raise NotImplementedError("DFlash does not support defer_tokenization=True.") From d29cd75573aa0d1baedf394928aeedac4f104bf3 Mon Sep 17 00:00:00 2001 From: chungen04 Date: Fri, 26 Jun 2026 23:41:15 +0000 Subject: [PATCH 27/28] docs(trtllm): make branch comments self-contained Drop the stale sgl-only guard NOTE from the trtllm dflash config (the guard now permits trtllm), remove a redundant CVD-scoping comment in trtllm_engine, and rewrite the factory GPU-assignment comment to explain the tp==1 vs tp>1 split without the dangling "Approach 1" reference that assumed prior context. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: chungen04 --- configs/trtllm_qwen3_8b_dflash.yaml | 3 --- torchspec/inference/engine/trtllm_engine.py | 3 --- torchspec/inference/factory.py | 9 +++------ 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/configs/trtllm_qwen3_8b_dflash.yaml b/configs/trtllm_qwen3_8b_dflash.yaml index b488ab3..5d829ab 100644 --- a/configs/trtllm_qwen3_8b_dflash.yaml +++ b/configs/trtllm_qwen3_8b_dflash.yaml @@ -3,9 +3,6 @@ # Mirrors configs/sglang_qwen3_8b_dflash.yaml; only the inference backend differs # (engine type, the trtllm block, and placement_strategy). # -# NOTE: DFlash is currently gated to inference_engine_type='sgl' in -# train_entry.py; running this requires relaxing that guard to allow 'trtllm'. -# # GPU allocation: mirrors the sglang DFlash layout (4 inference tp=1 + 4 training); # adjust inference_num_gpus / training_num_gpus_per_node / tp_size to your budget. # diff --git a/torchspec/inference/engine/trtllm_engine.py b/torchspec/inference/engine/trtllm_engine.py index eb150bc..2d3c79b 100644 --- a/torchspec/inference/engine/trtllm_engine.py +++ b/torchspec/inference/engine/trtllm_engine.py @@ -261,9 +261,6 @@ def _init_engine(self, tp_size: int, mem_fraction: float | None) -> None: from tensorrt_llm import LLM from tensorrt_llm.llmapi import KvCacheConfig, SaveHiddenStatesDecodingConfig - # CUDA_VISIBLE_DEVICES is already scoped to this engine's GPUs in init(), - # before any CUDA/TRT init -- see the comment there. - # eagle3_layers_to_capture: aux layers + the final post-norm state (-1). # The resource manager orders -1 last in the capture buffer, which is the # split point the patch relies on (aux = [:, :-H], last = [:, -H:]). diff --git a/torchspec/inference/factory.py b/torchspec/inference/factory.py index 24abbf5..1c11859 100644 --- a/torchspec/inference/factory.py +++ b/torchspec/inference/factory.py @@ -471,12 +471,9 @@ def _prepare_trtllm_engines( TrtllmRayActor = ray.remote(TrtllmEngine) env_vars = get_torchspec_env_vars() - # Single-GPU engines use Ray's per-actor CVD scoping (Approach 1): drop the - # NOSET override and reserve a whole GPU, so Ray exposes exactly that GPU as - # cuda:0 to the engine and its MPI worker -- N such engines each land on - # their own GPU. Multi-GPU (TP) engines keep NOSET + the engine's manual - # contiguous-block pin and only a placeholder reservation (workers grab the - # real GPUs, untracked by Ray). + # TRT-LLM's MPI workers bind GPUs themselves, so assignment differs by TP degree: + # tp==1 drops NOSET and reserves a whole GPU so Ray scopes it as cuda:0 per engine; + # tp>1 keeps NOSET, pins a contiguous block in init(), and reserves a 0.2 placeholder. single_gpu = gpus_per_engine == 1 if single_gpu: env_vars.pop("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", None) From 7d23aedd9bec22337907a0512bab9367cdf59758 Mon Sep 17 00:00:00 2001 From: chungen04 Date: Sat, 27 Jun 2026 00:03:48 +0000 Subject: [PATCH 28/28] Remove comment to instruct docker build Signed-off-by: chungen04 --- configs/trtllm_qwen3_8b.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/configs/trtllm_qwen3_8b.yaml b/configs/trtllm_qwen3_8b.yaml index 1207602..b505f76 100644 --- a/configs/trtllm_qwen3_8b.yaml +++ b/configs/trtllm_qwen3_8b.yaml @@ -8,7 +8,6 @@ # Installation: # Use the TensorRT-LLM docker image (docker/trtllm/v1.3.0rc18/Dockerfile), # which ships tensorrt_llm patched for Mooncake hidden-state capture. -# For a local install: pip install -e ".[trtllm]" # # Usage (same launcher as the sglang example, for a fair side-by-side): # ./examples/qwen3-8b-single-node/run.sh configs/trtllm_qwen3_8b.yaml