From 7b0ac34cd0fdeab9ef7ee14948d5bc3ce43fc814 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Tue, 18 Nov 2025 16:08:54 +0000 Subject: [PATCH] feat: support out-of-tree models in trtllm-serve Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- .../llm-api/out_of_tree_example/__init__.py | 1 + .../llm-api/out_of_tree_example/readme.md | 5 +- tensorrt_llm/commands/serve.py | 23 ++- tensorrt_llm/llmapi/llm.py | 5 +- .../modeling/test_modeling_out_of_tree.py | 145 +++++++++++++----- 5 files changed, 138 insertions(+), 41 deletions(-) create mode 100644 examples/llm-api/out_of_tree_example/__init__.py diff --git a/examples/llm-api/out_of_tree_example/__init__.py b/examples/llm-api/out_of_tree_example/__init__.py new file mode 100644 index 00000000000..55902216ba6 --- /dev/null +++ b/examples/llm-api/out_of_tree_example/__init__.py @@ -0,0 +1 @@ +from . import modeling_opt diff --git a/examples/llm-api/out_of_tree_example/readme.md b/examples/llm-api/out_of_tree_example/readme.md index 1b26ea3cd67..d93981bb41e 100644 --- a/examples/llm-api/out_of_tree_example/readme.md +++ b/examples/llm-api/out_of_tree_example/readme.md @@ -45,8 +45,11 @@ Prepare the dataset: python ./benchmarks/cpp/prepare_dataset.py --tokenizer ./model_ckpt --stdout dataset --dataset-name lmms-lab/MMMU --dataset-split test --dataset-image-key image --dataset-prompt-key "question" --num-requests 100 --output-len-dist 128,5 > mm_data.jsonl ``` - Run the benchmark: ``` trtllm-bench --model ./model_ckpt --model_path ./model_ckpt throughput --dataset mm_data.jsonl --backend pytorch --num_requests 100 --max_batch_size 4 --modality image --streaming --custom_module_dirs ../modeling_custom_phi ``` + +### Serving + +Similar to `trtllm-bench` above, `trtllm-serve` also supports the `--custom_module_dirs` option. diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index f43f5c4838d..5c6a07c9078 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -5,6 +5,7 @@ import signal # Added import import subprocess # nosec B404 import sys +from pathlib import Path from typing import Any, Dict, Mapping, Optional, Sequence import click @@ -33,6 +34,7 @@ from tensorrt_llm.logger import logger, severity_map from tensorrt_llm.serve import OpenAIDisaggServer, OpenAIServer from tensorrt_llm.serve.tool_parser import ToolParserFactory +from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir # Global variable to store the Popen object of the child process _child_p_global: Optional[subprocess.Popen] = None @@ -244,6 +246,16 @@ def convert(self, value: Any, param: Optional["click.Parameter"], {"trt": "tensorrt"}), default="pytorch", help="The backend to use to serve the model. 
Default is pytorch backend.") +@click.option( + "--custom_module_dirs", + type=click.Path(exists=True, + readable=True, + path_type=Path, + resolve_path=True), + default=None, + multiple=True, + help="Paths to custom module directories to import.", +) @click.option('--log_level', type=click.Choice(severity_map.keys()), default='info', @@ -366,13 +378,22 @@ def serve( server_role: Optional[str], fail_fast_on_attention_window_too_large: bool, otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool, - disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str]): + disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str], + custom_module_dirs: list[Path]): """Running an OpenAI API compatible server MODEL: model name | HF checkpoint path | TensorRT engine path """ logger.set_level(log_level) + for custom_module_dir in custom_module_dirs: + try: + import_custom_module_from_dir(custom_module_dir) + except Exception as e: + logger.error( + f"Failed to import custom module from {custom_module_dir}: {e}") + raise e + llm_args, _ = get_llm_args( model=model, tokenizer=tokenizer, diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 32c6a90e327..0ca965d0e08 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -776,7 +776,10 @@ def _shutdown_wrapper(self_ref): def __enter__(self): return self - def __exit__(self, exc_type, exc_value, traceback) -> bool: + def __exit__( + self, exc_type, exc_value, traceback + ) -> Literal[ + False]: # https://github.com/microsoft/pyright/issues/7009#issuecomment-1894135045 del exc_value, traceback self.shutdown() return False # propagate exceptions diff --git a/tests/unittest/_torch/modeling/test_modeling_out_of_tree.py b/tests/unittest/_torch/modeling/test_modeling_out_of_tree.py index ffbf2c94671..0916fbbe3e9 100644 --- a/tests/unittest/_torch/modeling/test_modeling_out_of_tree.py +++ b/tests/unittest/_torch/modeling/test_modeling_out_of_tree.py @@ -1,65 +1,134 @@ -import unittest +import re +from contextlib import nullcontext +from pathlib import Path +from typing import cast -from parameterized import parameterized +import pytest from tensorrt_llm import LLM from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.sampling_params import SamplingParams # isort: off -from utils.util import unittest_name_func, similar +from utils.util import similar from utils.llm_data import llm_models_root # isort: on +from llmapi.apps.openai_server import RemoteOpenAIServer -class TestOutOfTree(unittest.TestCase): - @parameterized.expand([False, True], name_func=unittest_name_func) - def test_llm_api(self, import_oot_code: bool): - if import_oot_code: - # Import out-of-tree modeling code for OPTForCausalLM - import os - import sys - sys.path.append( - os.path.join( - os.path.dirname(__file__), - '../../../../examples/llm-api/out_of_tree_example')) - import modeling_opt # noqa +class TestOutOfTree: - model_dir = str(llm_models_root() / "opt-125m") - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4) + @pytest.fixture + @staticmethod + def oot_path() -> Path: + return Path( + __file__ + ).parent / ".." / ".." / ".." / ".." / "examples" / "llm-api" / "out_of_tree_example" - if not import_oot_code: - with self.assertRaises(RuntimeError): - # estimate_max_kv_cache_tokens will create a request of max_num_tokens for forward. - # Default 8192 will exceed the max length of absolute positional embedding in OPT, leading to out of range indexing. 
- llm = LLM(model=model_dir, - kv_cache_config=kv_cache_config, - max_num_tokens=2048) - return - - llm = LLM(model=model_dir, - kv_cache_config=kv_cache_config, - max_num_tokens=2048, - disable_overlap_scheduler=True) - - prompts = [ + @pytest.fixture + @staticmethod + def model_dir() -> Path: + models_root = llm_models_root() + assert models_root is not None + return models_root / "opt-125m" + + @pytest.fixture + @staticmethod + def prompts() -> list[str]: + return [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] - references = [ + @pytest.fixture + @staticmethod + def references() -> list[str]: + return [ " J.C. and I am a student at", " not a racist. He is a racist.\n", " the capital of the French Republic.\n\nThe", " in the hands of the people.\n\nThe", ] - sampling_params = SamplingParams(max_tokens=10) - with llm: - outputs = llm.generate(prompts, sampling_params=sampling_params) + @pytest.fixture + @staticmethod + def sampling_params() -> SamplingParams: + return SamplingParams(max_tokens=10) + + @pytest.fixture + @staticmethod + def max_num_tokens() -> int: + # estimate_max_kv_cache_tokens will create a request of max_num_tokens for forward. + # Default 8192 will exceed the max length of absolute positional embedding in OPT, leading to out of range indexing. + return 2048 + + @pytest.mark.parametrize("import_oot_code", [False, True]) + def test_llm_api( + self, + import_oot_code: bool, + oot_path: Path, + model_dir: Path, + prompts: list[str], + references: list[str], + sampling_params: SamplingParams, + max_num_tokens: int, + monkeypatch: pytest.MonkeyPatch, + ): + if import_oot_code: + # Import out-of-tree modeling code for OPTForCausalLM + monkeypatch.syspath_prepend(oot_path) + import modeling_opt # noqa + + with (nullcontext() if import_oot_code else + pytest.raises(RuntimeError, + match=".*Executor worker returned error.*")) as ctx: + with LLM( + model=str(model_dir), + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.2), + max_num_tokens=max_num_tokens, + ) as llm: + outputs = llm.generate(prompts, sampling_params=sampling_params) + + for output, ref in zip(outputs, references): + assert similar(output.outputs[0].text, ref) + + if not import_oot_code: + exc_val = cast(pytest.ExceptionInfo, ctx).value + assert re.match( + ".*Unknown architecture for AutoModelForCausalLM: OPTForCausalLM.*", + str(exc_val.__cause__), + ) is not None + + @pytest.mark.parametrize("import_oot_code", [False, True]) + def test_serve( + self, + import_oot_code: bool, + oot_path: Path, + model_dir: Path, + prompts: list[str], + references: list[str], + sampling_params: SamplingParams, + max_num_tokens: int, + ): + with (nullcontext() + if import_oot_code else pytest.raises(RuntimeError)): + args = [] + args.extend(["--kv_cache_free_gpu_memory_fraction", + "0.2"]) # for co-existence with other servers + args.extend(["--max_num_tokens", str(max_num_tokens)]) + if import_oot_code: + args.extend(["--custom_module_dirs", str(oot_path)]) + with RemoteOpenAIServer(str(model_dir), args) as remote_server: + client = remote_server.get_client() + result = client.completions.create( + model="model_name", + prompt=prompts, + max_tokens=sampling_params.max_tokens, + temperature=0.0, + ) - for output, ref in zip(outputs, references): - assert similar(output.outputs[0].text, ref) + for choice, ref in zip(result.choices, references): + assert similar(choice.text, ref)
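
For reference, a usage sketch of the option introduced by this patch (not part of the diff itself): it reuses the `./model_ckpt` checkpoint and the `../modeling_custom_phi` custom module directory from the `trtllm-bench` example in the readme above, and assumes the server's default address of `localhost:8000`; the `model` field in the request body is a placeholder, and paths should be adjusted to the actual setup.

```
# Serve the out-of-tree model, importing the custom modeling code at startup.
# --custom_module_dirs may be repeated to import several directories.
trtllm-serve ./model_ckpt --backend pytorch --custom_module_dirs ../modeling_custom_phi

# Query the OpenAI-compatible completions endpoint once the server is up.
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "model_ckpt", "prompt": "Hello, my name is", "max_tokens": 10}'
```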