1 change: 1 addition & 0 deletions examples/llm-api/out_of_tree_example/__init__.py
@@ -0,0 +1 @@
from . import modeling_opt
5 changes: 4 additions & 1 deletion examples/llm-api/out_of_tree_example/readme.md
@@ -45,8 +45,11 @@ Prepare the dataset:
python ./benchmarks/cpp/prepare_dataset.py --tokenizer ./model_ckpt --stdout dataset --dataset-name lmms-lab/MMMU --dataset-split test --dataset-image-key image --dataset-prompt-key "question" --num-requests 100 --output-len-dist 128,5 > mm_data.jsonl
```


Run the benchmark:
```
trtllm-bench --model ./model_ckpt --model_path ./model_ckpt throughput --dataset mm_data.jsonl --backend pytorch --num_requests 100 --max_batch_size 4 --modality image --streaming --custom_module_dirs ../modeling_custom_phi
```

### Serving

Similar to `trtllm-bench` above, `trtllm-serve` also supports the `--custom_module_dirs` option.
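For readers who prefer the LLM API over the CLI, here is a minimal sketch of the equivalent flow, mirroring what the updated test does by hand: put the out-of-tree example directory on `sys.path` and import it before constructing the `LLM`. This is what the `--custom_module_dirs` option automates for `trtllm-bench` and `trtllm-serve`. The repo-relative example path and the `./model_ckpt` checkpoint directory below are placeholders.

```python
import sys
from pathlib import Path

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.sampling_params import SamplingParams

# Make the out-of-tree example importable, then import it so the custom
# OPTForCausalLM implementation gets registered before the LLM is built.
oot_dir = Path("examples/llm-api/out_of_tree_example")  # assumed repo-relative path
sys.path.insert(0, str(oot_dir))
import modeling_opt  # noqa: F401

with LLM(
        model="./model_ckpt",  # placeholder checkpoint directory
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
        max_num_tokens=2048,
) as llm:
    outputs = llm.generate(["Hello, my name is"],
                           sampling_params=SamplingParams(max_tokens=10))
    print(outputs[0].outputs[0].text)
```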
23 changes: 22 additions & 1 deletion tensorrt_llm/commands/serve.py
@@ -5,6 +5,7 @@
import signal # Added import
import subprocess # nosec B404
import sys
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Sequence

import click
@@ -33,6 +34,7 @@
from tensorrt_llm.logger import logger, severity_map
from tensorrt_llm.serve import OpenAIDisaggServer, OpenAIServer
from tensorrt_llm.serve.tool_parser import ToolParserFactory
from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir

# Global variable to store the Popen object of the child process
_child_p_global: Optional[subprocess.Popen] = None
@@ -244,6 +246,16 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
{"trt": "tensorrt"}),
default="pytorch",
help="The backend to use to serve the model. Default is pytorch backend.")
@click.option(
"--custom_module_dirs",
type=click.Path(exists=True,
readable=True,
path_type=Path,
resolve_path=True),
default=None,
multiple=True,
help="Paths to custom module directories to import.",
)
@click.option('--log_level',
type=click.Choice(severity_map.keys()),
default='info',
@@ -366,13 +378,22 @@ def serve(
server_role: Optional[str],
fail_fast_on_attention_window_too_large: bool,
otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str]):
disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str],
custom_module_dirs: list[Path]):
"""Running an OpenAI API compatible server

MODEL: model name | HF checkpoint path | TensorRT engine path
"""
logger.set_level(log_level)

for custom_module_dir in custom_module_dirs:
try:
import_custom_module_from_dir(custom_module_dir)
except Exception as e:
logger.error(
f"Failed to import custom module from {custom_module_dir}: {e}")
raise e

llm_args, _ = get_llm_args(
model=model,
tokenizer=tokenizer,
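As background on the import loop above: `import_custom_module_from_dir` comes from `tensorrt_llm.tools.importlib_utils`, and its implementation is not shown in this diff. The effect it needs to achieve can be sketched with the standard library alone; the function below is a rough approximation for illustration, not the actual helper.

```python
import importlib
import sys
from pathlib import Path


def import_custom_module_from_dir_sketch(module_dir: Path) -> None:
    """Rough approximation: make the directory importable as a package and
    import it, so its __init__.py (e.g. `from . import modeling_opt`) runs
    and the custom model classes register themselves."""
    parent = str(module_dir.parent)
    if parent not in sys.path:
        sys.path.insert(0, parent)
    importlib.import_module(module_dir.name)
```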
5 changes: 4 additions & 1 deletion tensorrt_llm/llmapi/llm.py
@@ -776,7 +776,10 @@ def _shutdown_wrapper(self_ref):
def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback) -> bool:
def __exit__(
self, exc_type, exc_value, traceback
) -> Literal[
False]: # https://github.com/microsoft/pyright/issues/7009#issuecomment-1894135045
del exc_value, traceback
self.shutdown()
return False # propagate exceptions
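A brief note on the `Literal[False]` annotation above: when `__exit__` is annotated as returning `bool`, type checkers such as pyright assume the context manager might suppress exceptions, which weakens reachability analysis after a `with` block; annotating the return as `Literal[False]` records that exceptions always propagate. A minimal, self-contained illustration of the pattern (the class name is illustrative):

```python
from types import TracebackType
from typing import Literal, Optional, Type


class ManagedResource:
    """Illustrative context manager following the same typing pattern."""

    def __enter__(self) -> "ManagedResource":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> Literal[False]:
        # Returning False, and annotating it as Literal[False], makes it
        # explicit to readers and to type checkers that exceptions raised
        # inside the `with` block always propagate.
        return False
```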
145 changes: 107 additions & 38 deletions tests/unittest/_torch/modeling/test_modeling_out_of_tree.py
@@ -1,65 +1,134 @@
import unittest
import re
from contextlib import nullcontext
from pathlib import Path
from typing import cast

from parameterized import parameterized
import pytest

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.sampling_params import SamplingParams

# isort: off
from utils.util import unittest_name_func, similar
from utils.util import similar
from utils.llm_data import llm_models_root
# isort: on

from llmapi.apps.openai_server import RemoteOpenAIServer

class TestOutOfTree(unittest.TestCase):

@parameterized.expand([False, True], name_func=unittest_name_func)
def test_llm_api(self, import_oot_code: bool):
if import_oot_code:
# Import out-of-tree modeling code for OPTForCausalLM
import os
import sys
sys.path.append(
os.path.join(
os.path.dirname(__file__),
'../../../../examples/llm-api/out_of_tree_example'))
import modeling_opt # noqa
class TestOutOfTree:

model_dir = str(llm_models_root() / "opt-125m")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
@pytest.fixture
@staticmethod
def oot_path() -> Path:
return Path(
__file__
).parent / ".." / ".." / ".." / ".." / "examples" / "llm-api" / "out_of_tree_example"

if not import_oot_code:
with self.assertRaises(RuntimeError):
# estimate_max_kv_cache_tokens will create a request of max_num_tokens for forward.
# Default 8192 will exceed the max length of absolute positional embedding in OPT, leading to out of range indexing.
llm = LLM(model=model_dir,
kv_cache_config=kv_cache_config,
max_num_tokens=2048)
return

llm = LLM(model=model_dir,
kv_cache_config=kv_cache_config,
max_num_tokens=2048,
disable_overlap_scheduler=True)

prompts = [
@pytest.fixture
@staticmethod
def model_dir() -> Path:
models_root = llm_models_root()
assert models_root is not None
return models_root / "opt-125m"

@pytest.fixture
@staticmethod
def prompts() -> list[str]:
return [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

references = [
@pytest.fixture
@staticmethod
def references() -> list[str]:
return [
" J.C. and I am a student at",
" not a racist. He is a racist.\n",
" the capital of the French Republic.\n\nThe",
" in the hands of the people.\n\nThe",
]

sampling_params = SamplingParams(max_tokens=10)
with llm:
outputs = llm.generate(prompts, sampling_params=sampling_params)
@pytest.fixture
@staticmethod
def sampling_params() -> SamplingParams:
return SamplingParams(max_tokens=10)

@pytest.fixture
@staticmethod
def max_num_tokens() -> int:
# estimate_max_kv_cache_tokens will create a request of max_num_tokens for forward.
# Default 8192 will exceed the max length of absolute positional embedding in OPT, leading to out of range indexing.
return 2048

@pytest.mark.parametrize("import_oot_code", [False, True])
def test_llm_api(
self,
import_oot_code: bool,
oot_path: Path,
model_dir: Path,
prompts: list[str],
references: list[str],
sampling_params: SamplingParams,
max_num_tokens: int,
monkeypatch: pytest.MonkeyPatch,
):
if import_oot_code:
# Import out-of-tree modeling code for OPTForCausalLM
monkeypatch.syspath_prepend(oot_path)
import modeling_opt # noqa

with (nullcontext() if import_oot_code else
pytest.raises(RuntimeError,
match=".*Executor worker returned error.*")) as ctx:
with LLM(
model=str(model_dir),
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.2),
max_num_tokens=max_num_tokens,
) as llm:
outputs = llm.generate(prompts, sampling_params=sampling_params)

for output, ref in zip(outputs, references):
assert similar(output.outputs[0].text, ref)

if not import_oot_code:
exc_val = cast(pytest.ExceptionInfo, ctx).value
assert re.match(
".*Unknown architecture for AutoModelForCausalLM: OPTForCausalLM.*",
str(exc_val.__cause__),
) is not None

@pytest.mark.parametrize("import_oot_code", [False, True])
def test_serve(
self,
import_oot_code: bool,
oot_path: Path,
model_dir: Path,
prompts: list[str],
references: list[str],
sampling_params: SamplingParams,
max_num_tokens: int,
):
with (nullcontext()
if import_oot_code else pytest.raises(RuntimeError)):
args = []
args.extend(["--kv_cache_free_gpu_memory_fraction",
"0.2"]) # for co-existence with other servers
args.extend(["--max_num_tokens", str(max_num_tokens)])
if import_oot_code:
args.extend(["--custom_module_dirs", str(oot_path)])
with RemoteOpenAIServer(str(model_dir), args) as remote_server:
client = remote_server.get_client()
result = client.completions.create(
model="model_name",
prompt=prompts,
max_tokens=sampling_params.max_tokens,
temperature=0.0,
)

for output, ref in zip(outputs, references):
assert similar(output.outputs[0].text, ref)
for choice, ref in zip(result.choices, references):
assert similar(choice.text, ref)