Commit e477f07

feat: support out-of-tree models in trtllm-serve
Signed-off-by: ixlmar <[email protected]>
1 parent 07343bb

File tree

5 files changed: +138 / -41 lines

Lines changed: 1 addition & 0 deletions

```diff
@@ -0,0 +1 @@
+from . import modeling_opt
```

examples/llm-api/out_of_tree_example/readme.md

Lines changed: 4 additions & 1 deletion

````diff
@@ -45,8 +45,11 @@ Prepare the dataset:
 python ./benchmarks/cpp/prepare_dataset.py --tokenizer ./model_ckpt --stdout dataset --dataset-name lmms-lab/MMMU --dataset-split test --dataset-image-key image --dataset-prompt-key "question" --num-requests 100 --output-len-dist 128,5 > mm_data.jsonl
 ```
 
-
 Run the benchmark:
 ```
 trtllm-bench --model ./model_ckpt --model_path ./model_ckpt throughput --dataset mm_data.jsonl --backend pytorch --num_requests 100 --max_batch_size 4 --modality image --streaming --custom_module_dirs ../modeling_custom_phi
 ```
+
+### Serving
+
+Similar to `trtllm-bench` above, `trtllm-serve` also supports the `--custom_module_dirs` option.
````
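For reference, the new option slots in next to the existing serve flags and, since it is declared with `multiple=True`, can be passed more than once. A hypothetical invocation mirroring the benchmark command in the readme (paths are illustrative):

```
trtllm-serve ./model_ckpt --backend pytorch --custom_module_dirs ../modeling_custom_phi
```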

tensorrt_llm/commands/serve.py

Lines changed: 22 additions & 1 deletion

```diff
@@ -5,6 +5,7 @@
 import signal  # Added import
 import subprocess  # nosec B404
 import sys
+from pathlib import Path
 from typing import Any, Dict, Mapping, Optional, Sequence
 
 import click
@@ -33,6 +34,7 @@
 from tensorrt_llm.logger import logger, severity_map
 from tensorrt_llm.serve import OpenAIDisaggServer, OpenAIServer
 from tensorrt_llm.serve.tool_parser import ToolParserFactory
+from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir
 
 # Global variable to store the Popen object of the child process
 _child_p_global: Optional[subprocess.Popen] = None
@@ -244,6 +246,16 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
                   {"trt": "tensorrt"}),
               default="pytorch",
               help="The backend to use to serve the model. Default is pytorch backend.")
+@click.option(
+    "--custom_module_dirs",
+    type=click.Path(exists=True,
+                    readable=True,
+                    path_type=Path,
+                    resolve_path=True),
+    default=None,
+    multiple=True,
+    help="Paths to custom module directories to import.",
+)
 @click.option('--log_level',
               type=click.Choice(severity_map.keys()),
               default='info',
@@ -366,13 +378,22 @@ def serve(
           server_role: Optional[str],
           fail_fast_on_attention_window_too_large: bool,
           otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
-          disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str]):
+          disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str],
+          custom_module_dirs: list[Path]):
     """Running an OpenAI API compatible server
 
     MODEL: model name | HF checkpoint path | TensorRT engine path
     """
     logger.set_level(log_level)
 
+    for custom_module_dir in custom_module_dirs:
+        try:
+            import_custom_module_from_dir(custom_module_dir)
+        except Exception as e:
+            logger.error(
+                f"Failed to import custom module from {custom_module_dir}: {e}")
+            raise e
+
     llm_args, _ = get_llm_args(
         model=model,
         tokenizer=tokenizer,
```
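The helper `import_custom_module_from_dir` is imported from `tensorrt_llm.tools.importlib_utils` but its implementation is not part of this diff. As a rough mental model only (assumed behavior, not the library's actual code), such a helper would make the directory importable as a package so that its `__init__.py` (here, `from . import modeling_opt`) runs and registers the out-of-tree model classes:

```python
# Hypothetical sketch, not the actual tensorrt_llm.tools.importlib_utils code.
import importlib
import sys
from pathlib import Path


def import_custom_module_from_dir(module_dir: Path) -> None:
    module_dir = module_dir.resolve()
    # Make the directory importable under its own name
    # (a package if it ships an __init__.py).
    parent = str(module_dir.parent)
    if parent not in sys.path:
        sys.path.append(parent)
    importlib.import_module(module_dir.name)
```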

tensorrt_llm/llmapi/llm.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -776,7 +776,10 @@ def _shutdown_wrapper(self_ref):
     def __enter__(self):
         return self
 
-    def __exit__(self, exc_type, exc_value, traceback) -> bool:
+    def __exit__(
+        self, exc_type, exc_value, traceback
+    ) -> Literal[
+            False]:  # https://github.com/microsoft/pyright/issues/7009#issuecomment-1894135045
         del exc_value, traceback
         self.shutdown()
         return False  # propagate exceptions
```
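Per the linked pyright issue, annotating `__exit__` as returning plain `bool` makes type checkers assume the context manager may suppress exceptions, which weakens reachability analysis in callers; `Literal[False]` states that exceptions always propagate. A minimal standalone illustration of the pattern (not taken from this PR):

```python
# Illustration: __exit__ returning Literal[False] tells static type checkers
# that this context manager never swallows exceptions.
from types import TracebackType
from typing import Literal, Optional


class Resource:

    def __enter__(self) -> "Resource":
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> Literal[False]:
        # Clean up here; returning False propagates any in-flight exception.
        return False
```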

Lines changed: 107 additions & 38 deletions

```diff
@@ -1,65 +1,134 @@
-import unittest
+import re
+import sys
+from contextlib import nullcontext
+from pathlib import Path
+from typing import cast
 
-from parameterized import parameterized
+import pytest
 
 from tensorrt_llm import LLM
 from tensorrt_llm.llmapi import KvCacheConfig
 from tensorrt_llm.sampling_params import SamplingParams
 
 # isort: off
-from utils.util import unittest_name_func, similar
+from utils.util import similar
 from utils.llm_data import llm_models_root
 # isort: on
 
+from llmapi.apps.openai_server import RemoteOpenAIServer
 
-class TestOutOfTree(unittest.TestCase):
 
-    @parameterized.expand([False, True], name_func=unittest_name_func)
-    def test_llm_api(self, import_oot_code: bool):
-        if import_oot_code:
-            # Import out-of-tree modeling code for OPTForCausalLM
-            import os
-            import sys
-            sys.path.append(
-                os.path.join(
-                    os.path.dirname(__file__),
-                    '../../../../examples/llm-api/out_of_tree_example'))
-            import modeling_opt  # noqa
+class TestOutOfTree:
 
-        model_dir = str(llm_models_root() / "opt-125m")
-        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
+    @pytest.fixture
+    @staticmethod
+    def oot_path() -> Path:
+        return Path(
+            __file__
+        ).parent / ".." / ".." / ".." / ".." / "examples" / "llm-api" / "out_of_tree_example"
 
-        if not import_oot_code:
-            with self.assertRaises(RuntimeError):
-                # estimate_max_kv_cache_tokens will create a request of max_num_tokens for forward.
-                # Default 8192 will exceed the max length of absolute positional embedding in OPT, leading to out of range indexing.
-                llm = LLM(model=model_dir,
-                          kv_cache_config=kv_cache_config,
-                          max_num_tokens=2048)
-            return
-
-        llm = LLM(model=model_dir,
-                  kv_cache_config=kv_cache_config,
-                  max_num_tokens=2048,
-                  disable_overlap_scheduler=True)
-
-        prompts = [
+    @pytest.fixture
+    @staticmethod
+    def model_dir() -> Path:
+        models_root = llm_models_root()
+        assert models_root is not None
+        return models_root / "opt-125m"
+
+    @pytest.fixture
+    @staticmethod
+    def prompts() -> list[str]:
+        return [
             "Hello, my name is",
             "The president of the United States is",
             "The capital of France is",
             "The future of AI is",
         ]
 
-        references = [
+    @pytest.fixture
+    @staticmethod
+    def references() -> list[str]:
+        return [
             " J.C. and I am a student at",
             " not a racist. He is a racist.\n",
             " the capital of the French Republic.\n\nThe",
             " in the hands of the people.\n\nThe",
         ]
 
-        sampling_params = SamplingParams(max_tokens=10)
-        with llm:
-            outputs = llm.generate(prompts, sampling_params=sampling_params)
+    @pytest.fixture
+    @staticmethod
+    def sampling_params() -> SamplingParams:
+        return SamplingParams(max_tokens=10)
+
+    @pytest.fixture
+    @staticmethod
+    def max_num_tokens() -> int:
+        # estimate_max_kv_cache_tokens will create a request of max_num_tokens for forward.
+        # Default 8192 will exceed the max length of absolute positional embedding in OPT, leading to out of range indexing.
+        return 2048
+
+    @pytest.mark.parametrize("import_oot_code", [False, True])
+    def test_llm_api(
+        self,
+        import_oot_code: bool,
+        oot_path: Path,
+        model_dir: Path,
+        prompts: list[str],
+        references: list[str],
+        sampling_params: SamplingParams,
+        max_num_tokens: int,
+    ):
+        if import_oot_code:
+            # Import out-of-tree modeling code for OPTForCausalLM
+            sys.path.append(str(oot_path))
+            import modeling_opt  # noqa
+
+        with (nullcontext() if import_oot_code else
+              pytest.raises(RuntimeError,
+                            match=".*Executor worker returned error.*")) as ctx:
+            with LLM(
+                    model=str(model_dir),
+                    kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
+                    max_num_tokens=max_num_tokens,
+            ) as llm:
+                outputs = llm.generate(prompts, sampling_params=sampling_params)
+
+                for output, ref in zip(outputs, references):
+                    assert similar(output.outputs[0].text, ref)
+
+        if not import_oot_code:
+            exc_val = cast(pytest.ExceptionInfo, ctx).value
+            assert re.match(
+                ".*Unknown architecture for AutoModelForCausalLM: OPTForCausalLM.*",
+                str(exc_val.__cause__),
+            ) is not None
+
+    @pytest.mark.parametrize("import_oot_code", [False, True])
+    def test_serve(
+        self,
+        import_oot_code: bool,
+        oot_path: Path,
+        model_dir: Path,
+        prompts: list[str],
+        references: list[str],
+        sampling_params: SamplingParams,
+        max_num_tokens: int,
+    ):
+        with (nullcontext()
+              if import_oot_code else pytest.raises(RuntimeError)):
+            args = []
+            args.extend(["--kv_cache_free_gpu_memory_fraction",
+                         "0.2"])  # for co-existence with other servers
+            args.extend(["--max_num_tokens", str(max_num_tokens)])
+            if import_oot_code:
+                args.extend(["--custom_module_dirs", str(oot_path)])
+            with RemoteOpenAIServer(str(model_dir), args) as remote_server:
+                client = remote_server.get_client()
+                result = client.completions.create(
+                    model="model_name",
+                    prompt=prompts,
+                    max_tokens=sampling_params.max_tokens,
+                    temperature=0.0,
+                )
 
-        for output, ref in zip(outputs, references):
-            assert similar(output.outputs[0].text, ref)
+                for choice, ref in zip(result.choices, references):
+                    assert similar(choice.text, ref)
```
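What `test_serve` exercises can also be reproduced by hand against a running `trtllm-serve` instance using the standard `openai` client. A sketch, assuming the server listens on the default local endpoint (adjust host and port to your launch flags; the API key value is not checked):

```python
# Query a trtllm-serve instance (started with --custom_module_dirs pointing at
# the out-of-tree package) through its OpenAI-compatible completions endpoint.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")

result = client.completions.create(
    model="model_name",
    prompt=["Hello, my name is"],
    max_tokens=10,
    temperature=0.0,
)
print(result.choices[0].text)
```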
