|
4 | 4 | import time |
5 | 5 | import threading |
6 | 6 | import torch.distributed as dist |
7 | | -from typing import List, Tuple, Callable, Optional |
| 7 | +from typing import Dict, List, Literal, Tuple, Callable, Optional |
8 | 8 | from transformers.configuration_utils import PretrainedConfig |
9 | 9 | from lightllm.utils.infer_utils import set_random_seed |
10 | 10 | from lightllm.utils.log_utils import init_logger |
|
39 | 39 | from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token |
40 | 40 | from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet |
41 | 41 | from .multi_level_kv_cache import MultiLevelKvCacheModule |
| 42 | +from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd |
42 | 43 |
|
43 | 44 |
|
44 | 45 | class ModeBackend: |
@@ -231,11 +232,19 @@ def init_model(self, kvargs): |
231 | 232 | if self.args.mtp_mode: |
232 | 233 | self.init_mtp_draft_model(kvargs) |
233 | 234 |
|
| 235 | + self.profiler: Optional[ProcessProfiler] = None |
| 236 | + if self.args.enable_profiling: |
| 237 | + self.profiler = ProcessProfiler( |
| 238 | + mode=self.args.enable_profiling, |
| 239 | + name=f"lightllm-model_backend-node{self.node_rank}_dev{get_current_device_id()}", |
| 240 | + ) |
| 241 | + self.profiling_active = False |
| 242 | + |
234 | 243 | # 启动infer_loop_thread, 启动两个线程进行推理,对于具备双batch推理折叠的场景 |
235 | 244 | # 可以降低 cpu overhead,大幅提升gpu的使用率。 |
236 | | - self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True) |
| 245 | + self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True, name="loop0") |
237 | 246 | self.infer_loop_thread.start() |
238 | | - self.infer_loop_thread1 = threading.Thread(target=self.infer_loop, daemon=True) |
| 247 | + self.infer_loop_thread1 = threading.Thread(target=self.infer_loop, daemon=True, name="loop1") |
239 | 248 | self.infer_loop_thread1.start() |
240 | 249 | return |
241 | 250 |
|
@@ -343,6 +352,14 @@ def _try_read_new_reqs(self): |
343 | 352 | self._try_read_new_reqs_multinode_tp() |
344 | 353 | else: |
345 | 354 | self._try_read_new_reqs_normal() |
| 355 | + |
| 356 | + # Runs on every infer-loop thread: reconcile the profiler's actual state |
| 356b| + # with the requested `profiling_active` flag set via ProfilerCmd. |
| 357 | + if self.profiler is not None: |
| 358 | + if self.profiler.is_active != self.profiling_active: |
| 359 | + if self.profiling_active: |
| 360 | + self.profiler.start() |
| 361 | + else: |
| 362 | + self.profiler.stop() |
346 | 363 | return |
347 | 364 |
|
348 | 365 | def _try_read_new_reqs_normal(self): |
@@ -408,6 +425,11 @@ def _read_reqs_buffer_and_init_reqs(self): |
408 | 425 | if obj.req_id in g_infer_context.requests_mapping: |
409 | 426 | req: InferReq = g_infer_context.requests_mapping[obj.req_id] |
410 | 427 | req.infer_aborted = True |
| 428 | + elif isinstance(obj, ProfilerCmd): |
| 429 | + if obj.cmd == "start": |
| 430 | + self.profiling_active = True |
| 431 | + elif obj.cmd == "stop": |
| 432 | + self.profiling_active = False |
411 | 433 | else: |
412 | 434 | assert False, f"error type {type(obj)}" |
413 | 435 | if init_reqs: |
|
0 commit comments