|
1 | 1 | # Copyright (c) OpenMMLab. All rights reserved. |
2 | | - |
3 | | -from dataclasses import dataclass |
4 | 2 | from typing import Callable, List |
5 | 3 |
|
6 | 4 | import torch |
|
13 | 11 |
|
def get_dist_ctx():
    """Snapshot the current distributed state as a ``DlinferDistContext``.

    Pulls the process-wide context from the distributed manager and packs
    the parallelism world sizes (dp/tp/ep), this process's ranks, and the
    attention-TP / EP process groups into a ``DlinferDistContext``.

    Returns:
        DlinferDistContext: the populated context object.
        ``max_tokens_accros_dp`` is fixed to 1 here (field name spelling
        follows the ``DlinferDistContext`` definition).
    """
    ctx = get_dist_manager().current_context()
    cfg = ctx.dist_config
    attn_tp = ctx.attn_tp_group
    return DlinferDistContext(
        dp_size=cfg.dp,
        tp_size=cfg.tp,
        ep_size=cfg.ep,
        dp_rank=ctx.dp_rank,
        tp_rank=attn_tp.rank,
        ep_rank=ctx.ep_rank,
        max_tokens_accros_dp=1,
        tp_group=attn_tp.gpu_group,
        ep_group=ctx.ep_gpu_group,
    )
26 | 24 |
|
27 | 25 |
|
28 | 26 | class DlinferSoftmaxTopKImpl(SoftmaxTopKImpl): |
|
0 commit comments