18 changes: 18 additions & 0 deletions vllm_ascend/ascend_config.py
@@ -100,6 +100,24 @@ def __init__(self, vllm_config):
raise AssertionError(
"oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
)
self.denseffn_tensor_parallel_size = additional_config.get(
"denseffn_tensor_parallel_size", None)
if self.denseffn_tensor_parallel_size is not None:
logger.info(
f"Enable denseffn_tensor_parallel_size={self.denseffn_tensor_parallel_size} for DenseFFN layers on decode nodes."
)
if vllm_config.parallel_config.tensor_parallel_size != 1:
raise AssertionError(
"denseffn_tensor_parallel_size is only supported when engine's tensor_parallel_size is 1."
)
if not self.torchair_graph_config.enabled:
raise AssertionError(
"denseffn_tensor_parallel_size is only supported in graph mode."
)
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
raise AssertionError(
"denseffn_tensor_parallel_size is only supported on decode nodes in a pipeline parallel setup."
)
self.enable_cpu_binding = additional_config.get(
"enable_cpu_binding", False)
self.pd_tp_ratio = 1
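A minimal usage sketch for the new option (model, sizes, and connector name are illustrative assumptions; the surrounding settings follow the constraints enforced above: engine TP of 1, torchair graph mode, and a KV-consumer kv_transfer_config):

from vllm import LLM
from vllm.config import KVTransferConfig

# Hypothetical decode-node (D node) setup.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3",
    tensor_parallel_size=1,  # enforced: engine TP must be 1
    kv_transfer_config=KVTransferConfig(
        kv_connector="LLMDataDistCMgrConnector",  # assumption: deployment-specific
        kv_role="kv_consumer",  # enforced: must be a KV consumer (D node)
    ),
    additional_config={
        "torchair_graph_config": {"enabled": True},  # enforced: graph mode only
        "denseffn_tensor_parallel_size": 4,  # the new option
    },
)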
29 changes: 29 additions & 0 deletions vllm_ascend/distributed/parallel_state.py
@@ -17,6 +17,7 @@
_OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
_P_TP: Optional[GroupCoordinator] = None
_DFTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None

@@ -38,6 +39,12 @@ def get_lmhead_tp_group() -> GroupCoordinator:
return _LMTP


def get_dftp_group() -> GroupCoordinator:
assert _DFTP is not None, (
"denseffn tensor parallel group is not initialized")
return _DFTP


def get_flashcomm2_otp_group() -> GroupCoordinator:
return _FLASHCOMM2_OTP

@@ -179,6 +186,23 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="lmheadtp")

denseffn_tensor_parallel_size = get_ascend_config(
).denseffn_tensor_parallel_size
if denseffn_tensor_parallel_size is not None:
group_ranks = []
global _DFTP
num_denseffn_tensor_parallel_groups: int = (
world_size // denseffn_tensor_parallel_size)
for i in range(num_denseffn_tensor_parallel_groups):
ranks = list(
range(i * denseffn_tensor_parallel_size,
(i + 1) * denseffn_tensor_parallel_size))
group_ranks.append(ranks)
_DFTP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="denseffntp")

# TODO: Extract and unify the logic across different communication group.
if flashcomm2_enable():
flashcomm2_otp_size = get_ascend_config(
@@ -258,6 +282,11 @@ def destroy_ascend_model_parallel():
_P_TP.destroy()
_P_TP = None

global _DFTP
if _DFTP:
_DFTP.destroy()
_DFTP = None

global _FLASHCOMM2_OTP
if _FLASHCOMM2_OTP and get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size != 1:
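The group construction mirrors the LM-head TP logic above: ranks are partitioned into contiguous blocks. A standalone sketch of the resulting layout (illustrative sizes):

# With world_size = 8 and denseffn_tensor_parallel_size = 4, the loop above
# yields two groups of four consecutive ranks each.
world_size = 8
denseffn_tensor_parallel_size = 4

group_ranks = [
    list(range(i * denseffn_tensor_parallel_size,
               (i + 1) * denseffn_tensor_parallel_size))
    for i in range(world_size // denseffn_tensor_parallel_size)
]
print(group_ranks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]

Note that world_size is assumed divisible by denseffn_tensor_parallel_size; with a non-divisible value the integer division would silently leave trailing ranks ungrouped.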
4 changes: 3 additions & 1 deletion vllm_ascend/ops/linear.py
@@ -37,7 +37,8 @@
from vllm.model_executor.utils import set_weight_attrs

from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_enable_nz,
is_first_k_dense)


class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
@@ -85,6 +86,7 @@ def __init__(
prefix=prefix)
self.return_bias = return_bias
self.disable_tp = disable_tp
self.is_first_k_dense = is_first_k_dense(prefix)


class AscendQKVParallelLinear(QKVParallelLinear):
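The flag is computed once at construction so the quantized apply() path further down can read layer.is_first_k_dense without re-parsing the prefix on every call. A trivial sketch of the pattern (hypothetical minimal class, not the real layer):

class LinearWithCachedFlag:
    def __init__(self, prefix: str):
        # Evaluated once, at init time.
        self.is_first_k_dense = is_first_k_dense(prefix)

    def apply(self, x):
        if self.is_first_k_dense:  # no regex/config lookup on the hot path
            ...  # dense-FFN-specific branch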
30 changes: 21 additions & 9 deletions vllm_ascend/ops/linear_op.py
@@ -52,15 +52,17 @@
from vllm.forward_context import get_forward_context

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
from vllm_ascend.distributed.parallel_state import (get_dftp_group,
get_flashcomm2_odp_group,
get_flashcomm2_otp_group,
get_mlp_tp_group,
get_otp_group)
from vllm_ascend.utils import (dense_optim_enable, enable_sp,
flashcomm2_enable,
from vllm_ascend.utils import (dense_optim_enable, denseffn_tp_enable,
enable_sp, flashcomm2_enable,
get_flashcomm2_reorgnized_batch_ids,
matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable, shared_expert_dp_enabled)
is_first_k_dense, matmul_allreduce_enable,
mlp_tp_enable, oproj_tp_enable,
shared_expert_dp_enabled)


class CustomLinearOp:
@@ -159,7 +161,10 @@ def __init__(self, layer):

@property
def comm_group(self):
return get_mlp_tp_group()
if mlp_tp_enable():
return get_mlp_tp_group()
else:
return get_dftp_group()

def apply_impl(
self,
@@ -182,7 +187,10 @@ def __init__(self, layer):

@property
def comm_group(self):
return get_mlp_tp_group()
if mlp_tp_enable():
return get_mlp_tp_group()
else:
return get_dftp_group()

def apply_impl(
self, input_: torch.Tensor
@@ -605,7 +613,9 @@ def update_attrs(self):
def _get_column_parallel_op(
prefix, layer
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
if mlp_tp_enable() and "gate_up_proj" in prefix:
if (mlp_tp_enable() or
(denseffn_tp_enable()
and is_first_k_dense(prefix))) and "gate_up_proj" in prefix:
return MLPColumnParallelOp(layer)
if enable_sp():
if "shared_expert" in prefix:
@@ -625,7 +635,9 @@ def _get_row_parallel_op(
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
SequenceRowParallelOp]]:
if "down_proj" in prefix and mlp_tp_enable():
if "down_proj" in prefix and (mlp_tp_enable() or
(denseffn_tp_enable()
and is_first_k_dense(prefix))):
return MLPRowParallelOp(layer)
if "o_proj" in prefix and oproj_tp_enable():
return OProjRowParallelOp(layer)
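Both dispatchers now route a layer to the MLP op classes under the same compound condition, and comm_group then picks the matching communicator at call time. A condensed, standalone sketch of that gating (function and flag names are illustrative):

def routes_to_mlp_ops(prefix: str, proj: str, mlp_tp: bool,
                      denseffn_tp: bool, first_k_dense: bool) -> bool:
    # Mirrors _get_column_parallel_op / _get_row_parallel_op above.
    return proj in prefix and (mlp_tp or (denseffn_tp and first_k_dense))

# With MLP TP off and denseffn TP on, only first-k dense layers qualify:
assert routes_to_mlp_ops("model.layers.0.mlp.gate_up_proj", "gate_up_proj",
                         False, True, True)
assert not routes_to_mlp_ops("model.layers.9.mlp.gate_up_proj", "gate_up_proj",
                             False, True, False)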
17 changes: 12 additions & 5 deletions vllm_ascend/quantization/quant_config.py
@@ -36,13 +36,15 @@
from vllm.model_executor.utils import set_weight_attrs

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group,
from vllm_ascend.distributed.parallel_state import (get_dftp_group,
get_flashcomm2_otp_group,
get_mlp_tp_group,
get_otp_group)
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, flashcomm2_enable,
mlp_tp_enable, oproj_tp_enable)
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, denseffn_tp_enable,
flashcomm2_enable, mlp_tp_enable,
oproj_tp_enable)

from .utils import get_quant_method

@@ -348,8 +350,13 @@ def apply(
if isinstance(layer, RowParallelLinear):
if layer.prefix.find("o_proj") != -1 and oproj_tp_enable():
tp_rank = get_otp_group().rank_in_group
elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable():
tp_rank = get_mlp_tp_group().rank_in_group
elif layer.prefix.find("down_proj") != -1 and (
mlp_tp_enable() or
(denseffn_tp_enable() and layer.is_first_k_dense)):
if denseffn_tp_enable() and layer.is_first_k_dense:
tp_rank = get_dftp_group().rank_in_group
else:
tp_rank = get_mlp_tp_group().rank_in_group
elif (layer.prefix.find("o_proj") != -1 or
layer.prefix.find("out_proj") != -1) and flashcomm2_enable():
if get_ascend_config(
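When both features are enabled at once, the nested check gives the denseffn group precedence for the first-k dense layers; everything else falls back to the MLP group. A standalone sketch of that precedence (names are illustrative):

def pick_down_proj_rank_group(denseffn_tp: bool, mlp_tp: bool,
                              is_first_k_dense: bool) -> str:
    # Mirrors the down_proj branch in apply() above.
    if denseffn_tp and is_first_k_dense:
        return "dftp"     # get_dftp_group().rank_in_group
    if mlp_tp:
        return "mlp"      # get_mlp_tp_group().rank_in_group
    return "default"      # fall through to the regular TP rank

assert pick_down_proj_rank_group(True, True, True) == "dftp"  # denseffn wins
assert pick_down_proj_rank_group(True, True, False) == "mlp"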
28 changes: 28 additions & 0 deletions vllm_ascend/utils.py
@@ -21,6 +21,7 @@
import functools
import math
import os
import re
from contextlib import contextmanager, nullcontext
from enum import Enum
from threading import Lock
@@ -768,6 +769,10 @@ def oproj_tp_enable() -> bool:
return get_ascend_config().oproj_tensor_parallel_size is not None


def denseffn_tp_enable() -> bool:
return get_ascend_config().denseffn_tensor_parallel_size is not None


def mlp_tp_enable() -> bool:
return envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE

@@ -1014,3 +1019,26 @@ def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
reorgnized_batch_ids.append(ranks)

return reorgnized_batch_ids


def is_first_k_dense(prefix: str) -> bool:
from vllm.config import get_current_vllm_config
match = re.search(r'layers\.(\d+)\.', prefix)
if not match:
return False

layer_idx = int(match.group(1))

vllm_config = get_current_vllm_config()
if vllm_config is None:
raise ValueError(
"get_current_vllm_config() returned None. "
"Ensure this function is called within the model initialization context."
)
config = vllm_config.model_config.hf_config

# getattr guards dense models whose hf_config lacks the MoE fields.
is_moe_layer = (getattr(config, "n_routed_experts", None) is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0)

return not is_moe_layer
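Illustrative behaviour, assuming a DeepSeek-style config (n_routed_experts set, first_k_dense_replace = 3, moe_layer_freq = 1) and vLLM's usual module naming; the regex part can be checked standalone:

import re

# Only prefixes carrying a decoder-layer index are considered:
assert re.search(r'layers\.(\d+)\.', "model.layers.12.mlp.down_proj").group(1) == "12"
assert re.search(r'layers\.(\d+)\.', "lm_head") is None

# Under the assumed config, inside the model-init context:
#   is_first_k_dense("model.layers.0.mlp.down_proj")  -> True   (dense FFN)
#   is_first_k_dense("model.layers.5.mlp.down_proj")  -> False  (MoE layer)
#   is_first_k_dense("lm_head")                       -> False  (no layer index)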