Skip to content

Commit 73ebda1

Browse files
committed
format code
1 parent 760c0db commit 73ebda1

5 files changed

Lines changed: 17 additions & 14 deletions

File tree

lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,9 @@ def get_max_tokens_across_dp():
227227
if dist_ctx.dist_config.dp > 1:
228228
total_token_current_rank = torch.sum(step_context.q_seqlens).to(step_context.q_seqlens.dtype)
229229
world_size = dist_ctx.dist_config.world_size
230-
total_token_buffer = torch.zeros(world_size, dtype=step_context.q_seqlens.dtype, device=torch.npu.current_device())
230+
total_token_buffer = torch.zeros(world_size,
231+
dtype=step_context.q_seqlens.dtype,
232+
device=torch.npu.current_device())
231233
dist.all_gather_into_tensor(total_token_buffer, total_token_current_rank, dist_ctx.ep_gpu_group)
232234
max_tokens_accros_dp = torch.max(total_token_buffer).item()
233235
else:

lmdeploy/pytorch/backends/dlinfer/moe.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
3-
from dataclasses import dataclass
42
from typing import Callable, List
53

64
import torch
@@ -13,16 +11,16 @@
1311

1412
def get_dist_ctx():
1513
dist_ctx = get_dist_manager().current_context()
16-
17-
return DlinferDistContext(dp_size = dist_ctx.dist_config.dp,
18-
tp_size = dist_ctx.dist_config.tp,
19-
ep_size = dist_ctx.dist_config.ep,
20-
dp_rank = dist_ctx.dp_rank,
21-
tp_rank = dist_ctx.attn_tp_group.rank,
22-
ep_rank = dist_ctx.ep_rank,
23-
max_tokens_accros_dp = 1,
24-
tp_group = dist_ctx.attn_tp_group.gpu_group,
25-
ep_group = dist_ctx.ep_gpu_group)
14+
15+
return DlinferDistContext(dp_size=dist_ctx.dist_config.dp,
16+
tp_size=dist_ctx.dist_config.tp,
17+
ep_size=dist_ctx.dist_config.ep,
18+
dp_rank=dist_ctx.dp_rank,
19+
tp_rank=dist_ctx.attn_tp_group.rank,
20+
ep_rank=dist_ctx.ep_rank,
21+
max_tokens_accros_dp=1,
22+
tp_group=dist_ctx.attn_tp_group.gpu_group,
23+
ep_group=dist_ctx.ep_gpu_group)
2624

2725

2826
class DlinferSoftmaxTopKImpl(SoftmaxTopKImpl):

lmdeploy/pytorch/kernels/dlinfer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from dlinfer.utils.type_annotation import DlinferDistContext
3+
34
from ..default import multinomial_sampling, per_channel_quant
45
from .apply_rotary_pos_emb import apply_rotary_pos_emb
56
from .awq_kernels import awq_linear
@@ -12,6 +13,7 @@
1213
from .rms_norm import rms_norm
1314

1415
__all__ = [
16+
'DlinferDistContext',
1517
'rms_norm',
1618
'apply_rotary_pos_emb',
1719
'awq_linear',

lmdeploy/pytorch/kernels/dlinfer/fused_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import dlinfer.ops as ext_ops
3-
import torch
43
from torch import Tensor
4+
55
from . import DlinferDistContext
66

77

lmdeploy/pytorch/kernels/dlinfer/moe_gating_topk_softmax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import dlinfer.ops as ext_ops
33
from torch import Tensor
4+
45
from . import DlinferDistContext
56

67

0 commit comments

Comments (0)