
Commit 03e82f2

SiYu Wu (WuSiYu) authored and committed
feat(misc): Profiler support
Use --enable_profiling=MODE to enable; currently supports the torch_profiler and nvtx (for use with NVIDIA Nsight Systems) modes.
1 parent aff4049 commit 03e82f2

8 files changed: +261 additions, -7 deletions


lightllm/server/api_cli.py

Lines changed: 15 additions & 0 deletions

@@ -572,4 +572,19 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=False,
         help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""",
     )
+    parser.add_argument(
+        "--enable_profiling",
+        type=str,
+        choices=["torch_profiler", "nvtx"],
+        default=None,
+        help="""Enable profiler support.
+        This exposes the '/profiler_start' and '/profiler_stop' APIs;
+        the profiling features below are only active within that range.
+        Options:
+        'torch_profiler': sets up torch.profiler.profile(); trace files are saved to './trace'
+        or to the directory given by the 'LIGHTLLM_TRACE_DIR' env var;
+        'nvtx': adds NVTX marks for an external profiler such as NVIDIA Nsight Systems
+        (which you must set up yourself).
+        An NVTX range named 'LIGHTLLM_PROFILE' is added within the profiling range.""",
+    )
     return parser
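A standalone sketch of the flag's parsing behavior; it mirrors only this argument, not the full lightllm parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_profiling",
    type=str,
    choices=["torch_profiler", "nvtx"],
    default=None,
)

print(parser.parse_args([]).enable_profiling)  # None: every call site treats this as disabled
print(parser.parse_args(["--enable_profiling", "nvtx"]).enable_profiling)  # "nvtx"
# Any other value aborts with "invalid choice: ... (choose from 'torch_profiler', 'nvtx')"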

lightllm/server/api_http.py

Lines changed: 18 additions & 0 deletions

@@ -335,6 +335,24 @@ async def kv_move_status(websocket: WebSocket):
     return


+@app.get("/profiler_start")
+async def profiler_start() -> Response:
+    if g_objs.args.enable_profiling:
+        await g_objs.httpserver_manager.profiler_cmd("start")
+        return JSONResponse({"status": "ok"})
+    else:
+        return JSONResponse({"message": "Profiling support not enabled"}, status_code=400)
+
+
+@app.get("/profiler_stop")
+async def profiler_stop() -> Response:
+    if g_objs.args.enable_profiling:
+        await g_objs.httpserver_manager.profiler_cmd("stop")
+        return JSONResponse({"status": "ok"})
+    else:
+        return JSONResponse({"message": "Profiling support not enabled"}, status_code=400)
+
+
 @app.on_event("shutdown")
 async def shutdown():
     logger.info("Received signal to shutdown. Performing graceful shutdown...")

lightllm/server/httpserver/manager.py

Lines changed: 12 additions & 1 deletion

@@ -13,7 +13,7 @@
 from frozendict import frozendict

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-from typing import Union, List, Tuple, Dict, Optional, AsyncGenerator
+from typing import Literal, Union, List, Tuple, Dict, Optional, AsyncGenerator
 from websockets import ClientConnection
 from fastapi import Request
 from ..tokenizer import get_tokenizer
@@ -35,6 +35,7 @@
 from lightllm.utils.config_utils import get_vocab_size
 from lightllm.utils.envs_utils import get_unique_server_name
 from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken
+from lightllm.utils.profiler import ProfilerCmd
 from rpyc.utils.classic import obtain

 logger = init_logger(__name__)
@@ -650,6 +651,16 @@ async def abort(self, group_req_id: int) -> bool:
         logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}")
         return True

+    async def profiler_cmd(self, cmd: Literal["start", "stop"]):
+        receivers = [self.send_to_router]
+        if self.pd_mode.is_P_or_NORMAL() and self.enable_multimodal:
+            receivers.append(self.send_to_visual)
+        for receiver in receivers:
+            receiver.send_pyobj(
+                ProfilerCmd(cmd),
+                protocol=pickle.HIGHEST_PROTOCOL,
+            )
+
     async def recycle_resource_loop(self):
         pre_time_mark = time.time()
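ProfilerCmd comes from lightllm/utils/profiler.py, a module added by this commit whose diff is not shown on this page. Judging from the call sites (ProfilerCmd(cmd), ProfilerCmd(cmd=cmd_obj.cmd), obj.cmd == "start"), it is a small picklable message; a plausible minimal definition, offered as an assumption rather than the committed source:

from dataclasses import dataclass
from typing import Literal


@dataclass
class ProfilerCmd:
    # "start" opens the profiling range, "stop" closes it. The object travels
    # over zmq and the shared-memory buffer via pickle, so it must stay
    # trivially serializable.
    cmd: Literal["start", "stop"]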

lightllm/server/router/manager.py

Lines changed: 18 additions & 1 deletion

@@ -26,6 +26,7 @@
 from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient
 from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
 from lightllm.utils.log_utils import init_logger, log_time_ready
+from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
 from lightllm.server.router.token_load import TokenLoad
 from lightllm.server.metrics.manager import MetricClient
 from lightllm.common.basemodel.infer_lock import g_router_lock
@@ -106,6 +107,10 @@ def __init__(self, args: StartArgs):
             if not self.args.enable_cpu_cache
             else CpuKvCacheClient(only_create_meta_data=True, init_shm_data=False)
         )
+
+        self.profiler = (
+            ProcessProfiler(mode=args.enable_profiling, name="lightllm-router") if args.enable_profiling else None
+        )
         return

     async def wait_to_model_ready(self):
@@ -504,16 +509,28 @@ def _multinode_tp_generate_new_batch(self):
             raise e
         return

+    async def _profiler_cmd(self, cmd_obj: ProfilerCmd):
+        self.profiler.cmd(cmd_obj)
+
+        cmd = ProfilerCmd(cmd=cmd_obj.cmd)
+        while not self.shm_reqs_io_buffer.is_empty():
+            await asyncio.sleep(0.02)
+
+        self.shm_reqs_io_buffer.write_obj([cmd])
+        self.shm_reqs_io_buffer.set_ready()
+
     async def _recv_new_reqs_and_schedule(self):
         if not hasattr(self, "recv_max_count"):
             self.recv_max_count = 64

         try:
             # Take at most recv_max_count requests from zmq per iteration, so an
             # overlong zmq queue cannot block the main loop.
             for _ in range(self.recv_max_count):
-                recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+                recv_req: Union[GroupReqIndexes, ProfilerCmd] = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
                 if isinstance(recv_req, GroupReqIndexes):
                     self._add_req(recv_req)
+                elif isinstance(recv_req, ProfilerCmd):
+                    await self._profiler_cmd(recv_req)
                 else:
                     assert False, f"Error Req Inf {recv_req}"
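Note the handshake in _profiler_cmd: the router first toggles its own profiler, then spins until shm_reqs_io_buffer is empty before writing a fresh ProfilerCmd, so the command is handed to the backend processes without interleaving with a pending batch of requests.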

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 25 additions & 3 deletions

@@ -4,7 +4,7 @@
 import time
 import threading
 import torch.distributed as dist
-from typing import List, Tuple, Callable, Optional
+from typing import Dict, List, Literal, Tuple, Callable, Optional
 from transformers.configuration_utils import PretrainedConfig
 from lightllm.utils.infer_utils import set_random_seed
 from lightllm.utils.log_utils import init_logger
@@ -39,6 +39,7 @@
 from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
 from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
 from .multi_level_kv_cache import MultiLevelKvCacheModule
+from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd


 class ModeBackend:
@@ -231,11 +232,19 @@ def init_model(self, kvargs):
         if self.args.mtp_mode:
             self.init_mtp_draft_model(kvargs)

+        self.profiler: Optional[ProcessProfiler] = None
+        if self.args.enable_profiling:
+            self.profiler = ProcessProfiler(
+                mode=self.args.enable_profiling,
+                name=f"lightllm-model_backend-node{self.node_rank}_dev{get_current_device_id()}",
+            )
+        self.profiling_active = False
+
         # Start two infer_loop threads; in scenarios where dual-batch inference
         # can be overlapped, this lowers CPU overhead and greatly raises GPU utilization.
-        self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True)
+        self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True, name="loop0")
         self.infer_loop_thread.start()
-        self.infer_loop_thread1 = threading.Thread(target=self.infer_loop, daemon=True)
+        self.infer_loop_thread1 = threading.Thread(target=self.infer_loop, daemon=True, name="loop1")
         self.infer_loop_thread1.start()
         return

@@ -343,6 +352,14 @@ def _try_read_new_reqs(self):
             self._try_read_new_reqs_multinode_tp()
         else:
             self._try_read_new_reqs_normal()
+
+        # on each loop thread
+        if self.profiler is not None:
+            if self.profiler.is_active != self.profiling_active:
+                if self.profiling_active:
+                    self.profiler.start()
+                else:
+                    self.profiler.stop()
         return

     def _try_read_new_reqs_normal(self):
@@ -408,6 +425,11 @@ def _read_reqs_buffer_and_init_reqs(self):
                 if obj.req_id in g_infer_context.requests_mapping:
                     req: InferReq = g_infer_context.requests_mapping[obj.req_id]
                     req.infer_aborted = True
+            elif isinstance(obj, ProfilerCmd):
+                if obj.cmd == "start":
+                    self.profiling_active = True
+                elif obj.cmd == "stop":
+                    self.profiling_active = False
             else:
                 assert False, f"error type {type(obj)}"
         if init_reqs:
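Two flags cooperate in the backend: profiling_active records the requested state and is flipped when a ProfilerCmd arrives through the shared-memory request buffer, while profiler.is_active reflects the profiler's actual state. Whichever infer loop thread first observes a mismatch in _try_read_new_reqs starts or stops the profiler, so profiling toggles at a safe point between scheduling steps rather than mid-batch.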

lightllm/server/visualserver/manager.py

Lines changed: 17 additions & 2 deletions

@@ -7,7 +7,7 @@
 import pickle
 import inspect
 import setproctitle
-from typing import List
+from typing import List, Union
 from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes
 from lightllm.server.core.objs import ShmReqManager, StartArgs

@@ -18,6 +18,7 @@
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.process_check import start_parent_check_thread
 from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
 from rpyc.utils.classic import obtain


@@ -59,6 +60,9 @@ def __init__(
         self.visual_model_rpc_ports = visual_model_rpc_ports
         self.send_batch_size = args.visual_send_batch_size
         self.shm_req_manager = ShmReqManager()
+        self.profiler: "ProcessProfiler|None" = (
+            ProcessProfiler(args.enable_profiling, name="lightllm-visual_server") if args.enable_profiling else None
+        )

     async def wait_to_model_ready(self):

@@ -91,6 +95,7 @@ async def wait_to_model_ready(self):
                     "quant_type": self.args.vit_quant_type,
                     "quant_cfg": self.args.vit_quant_cfg,
                     "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1),
+                    "profiler": self.args.enable_profiling,
                 }
                 init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs))
         await asyncio.gather(*init_model_ret)
@@ -185,9 +190,19 @@ async def loop_for_netio_req(self):
         while True:
             try:
                 for _ in range(self.visual_recv_max_count):
-                    recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+                    recv_req: Union[GroupReqIndexes, ProfilerCmd] = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
                     if isinstance(recv_req, GroupReqIndexes):
                         self.waiting_reqs.append(recv_req)
+                    elif isinstance(recv_req, ProfilerCmd):
+                        self.profiler.cmd(recv_req)
+                        tasks = []
+                        for vit_dp_rank in range(self.vit_dp):
+                            for vit_tp_rank in range(self.vit_tp):
+                                task = asyncio.create_task(
+                                    self.model_rpcs[vit_dp_rank][vit_tp_rank].profiler_cmd(recv_req)
+                                )
+                                tasks.append(task)
+                        await asyncio.gather(*tasks)
                     else:
                         assert False, f"Error Req Inf {recv_req}"
                 self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256)
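Since each VIT worker runs in its own process behind rpyc, the visual manager both applies the ProfilerCmd to its own ProcessProfiler and fans the same command out to every (dp, tp) rank, gathering the rpyc calls so all workers have applied it before the loop continues.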

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 22 additions & 0 deletions

@@ -24,6 +24,7 @@
 from lightllm.utils.dist_utils import init_vision_distributed_env
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.profiler import ProcessProfiler


 class VisualModelRpcServer(rpyc.Service):
@@ -42,6 +43,13 @@ def exposed_init_model(self, kvargs):
         self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
         self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         self.data_type = kvargs["data_type"]
+        self.profiler = (
+            ProcessProfiler(
+                mode=kvargs["profiler"], name=f"lightllm-visual-vit_dp{self.dp_rank_id}_tp{self.tp_rank_id}"
+            )
+            if kvargs["profiler"]
+            else None
+        )

         init_vision_distributed_env(kvargs)
         model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)
@@ -116,6 +124,10 @@ def exposed_encode(self, images: List[ImageItem]):
         self.cache_client.root.set_items_embed(ids_to_set)
         return

+    def exposed_profiler_cmd(self, cmd_obj):
+        cmd_obj = obtain(cmd_obj)
+        self.profiler.cmd(cmd_obj)
+

 class VisualModelRpcClient:
     def __init__(self, model_rpc, vit_tp, rpc_server_process=None):
@@ -138,9 +150,11 @@ async def _func(*args, **kwargs):

             self._init_model = async_wrap(self.model.init_model)
             self._encode = async_wrap(self.model.encode)
+            self._profiler_cmd = async_wrap(self.model.profiler_cmd)
         else:
             self._init_model = self.model.exposed_init_model
             self._encode = self.model.exposed_encode
+            self._profiler_cmd = self.model.exposed_profiler_cmd
         return

     async def init_model(self, kvargs):
@@ -158,6 +172,14 @@ async def encode(self, images: List[ImageItem]):
         else:
             return ans

+    async def profiler_cmd(self, cmd_obj):
+        ans: rpyc.AsyncResult = self._profiler_cmd(cmd_obj)
+        if self.use_rpc:
+            await ans
+            return
+        else:
+            return
+

 def _init_env(port, device_id):
     # Register graceful-shutdown handling
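ProcessProfiler itself lives in lightllm/utils/profiler.py, presumably the eighth changed file, whose diff is not shown on this page. A minimal sketch of what the call sites imply (a mode/name constructor, an is_active flag, start()/stop(), and a cmd() dispatcher); this is an illustrative assumption, not the committed source:

import os

import torch
from torch.profiler import ProfilerActivity, profile


class ProcessProfiler:
    def __init__(self, mode, name):
        self.mode = mode    # "torch_profiler" or "nvtx"
        self.name = name    # per-process label for trace files / ranges
        self.is_active = False
        self._prof = None

    def cmd(self, cmd_obj):
        # Dispatch a ProfilerCmd, as the router and visual server do.
        if cmd_obj.cmd == "start":
            self.start()
        elif cmd_obj.cmd == "stop":
            self.stop()

    def start(self):
        if self.is_active:
            return
        self.is_active = True
        if self.mode == "torch_profiler":
            self._prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA])
            self._prof.__enter__()
        elif self.mode == "nvtx":
            # The CLI help promises an NVTX range named 'LIGHTLLM_PROFILE'.
            torch.cuda.nvtx.range_push("LIGHTLLM_PROFILE")

    def stop(self):
        if not self.is_active:
            return
        self.is_active = False
        if self.mode == "torch_profiler":
            self._prof.__exit__(None, None, None)
            trace_dir = os.environ.get("LIGHTLLM_TRACE_DIR", "./trace")
            os.makedirs(trace_dir, exist_ok=True)
            self._prof.export_chrome_trace(os.path.join(trace_dir, f"{self.name}.json"))
            self._prof = None
        elif self.mode == "nvtx":
            torch.cuda.nvtx.range_pop()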
