Commit 10fe499

Merge branch 'develop' into add_clear_run_batch_ci
2 parents: 7b9fc83 + b467e9d

File tree

11 files changed: +406, -6 lines


fastdeploy/config.py

Lines changed: 5 additions & 0 deletions
@@ -540,6 +540,11 @@ def __init__(
         self.expert_parallel_size = 1  # EP degree
         self.data_parallel_size = 1  # DP degree
         self.enable_expert_parallel = False
+        self.enable_chunked_moe = False
+        self.chunked_moe_size = 256
+        self.max_moe_num_chunk = 1
+        self.moe_num_chunk = 1
         self.local_data_parallel_id = 0
         # Engine worker queue port
         self.engine_worker_queue_port: str = "9923"
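For orientation: chunked_moe_size is the user-facing knob, while moe_num_chunk and max_moe_num_chunk are derived each step, the former from the local token count and the latter as the maximum across expert-parallel ranks. A minimal sketch of the per-rank derivation, mirroring the ceil-division added to gpu_model_runner.py further down (derive_moe_chunks is an illustrative helper, not part of this commit):

    def derive_moe_chunks(token_num: int, chunk_size: int = 256) -> int:
        # Ceil-divide the local token count by the chunk size; a batch that
        # fits in one chunk (or is empty) still counts as a single chunk.
        if token_num > chunk_size:
            return (token_num + chunk_size - 1) // chunk_size
        return 1

    # e.g. derive_moe_chunks(1000) == 4, derive_moe_chunks(200) == 1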

fastdeploy/engine/args_utils.py

Lines changed: 22 additions & 0 deletions
@@ -286,6 +286,16 @@ class EngineArgs:
     Enable expert parallelism.
     """

+    enable_chunked_moe: bool = False
+    """
+    Whether use chunked moe.
+    """
+
+    chunked_moe_size: int = 256
+    """
+    Chunk size of moe input.
+    """
+
     cache_transfer_protocol: str = "ipc"
     """
     Protocol to use for cache transfer.
@@ -870,6 +880,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=EngineArgs.eplb_config,
             help="Config of eplb.",
         )
+        parallel_group.add_argument(
+            "--enable-chunked-moe",
+            action="store_true",
+            default=EngineArgs.enable_chunked_moe,
+            help="Use chunked moe.",
+        )
+        parallel_group.add_argument(
+            "--chunked-moe-size",
+            type=int,
+            default=EngineArgs.chunked_moe_size,
+            help="Chunked size of moe input.",
+        )

         # Load group
         load_group = parser.add_argument_group("Load Configuration")
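Usage note: the two new fields follow the existing EngineArgs dataclass pattern, so they can be set from the CLI flags above or when constructing EngineArgs programmatically. A hedged sketch, assuming a model path field exists on EngineArgs (only the chunked-MoE and expert-parallel fields are taken from this diff):

    from fastdeploy.engine.args_utils import EngineArgs

    # Chunked MoE is only consulted together with expert parallelism
    # (see the use_ep check added in moe.py below).
    args = EngineArgs(
        model="./my-moe-model",       # placeholder path, assumed field
        enable_expert_parallel=True,
        enable_chunked_moe=True,      # new flag from this commit
        chunked_moe_size=512,         # override the 256 default
    )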

fastdeploy/engine/async_llm.py

Lines changed: 1 addition & 0 deletions
@@ -812,6 +812,7 @@ def _start_worker_service(self):
             f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}"
             f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
             f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
+            f" --chunked_moe_size {self.cfg.parallel_config.chunked_moe_size}"
             f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
             f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
             f" --ori_vocab_size {ori_vocab_size}"

fastdeploy/engine/engine.py

Lines changed: 2 additions & 0 deletions
@@ -544,6 +544,7 @@ def _start_worker_service(self):
             f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}"
             f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
             f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
+            f" --chunked_moe_size {self.cfg.parallel_config.chunked_moe_size}"
             f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
             f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
             f" --ori_vocab_size {ori_vocab_size}"
@@ -573,6 +574,7 @@ def _start_worker_service(self):

         worker_store_true_flag = {
             "enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel,
+            "enable_chunked_moe": self.cfg.parallel_config.enable_chunked_moe,
             "enable_prefix_caching": self.cfg.cache_config.enable_prefix_caching,
             "enable_chunked_prefill": self.cfg.cache_config.enable_chunked_prefill,
             "do_profile": self.do_profile,

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 56 additions & 1 deletion
@@ -612,6 +612,7 @@ def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer):
         multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, x.shape[1]], dtype=x.dtype)
         paddle.distributed.all_gather(multi_outs, out, self.tp_group)
         out = multi_outs[:token_num, :]
+
         return out

     def forward(self, x: paddle.Tensor, gate: nn.Layer):
@@ -633,9 +634,63 @@ def forward(self, x: paddle.Tensor, gate: nn.Layer):
             and token_num >= self.tp_size
         ):
             out = self.forward_split_allgather(x, gate)
+        elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
+            out = self.forward_chunked_moe(x, gate)
         else:
-            out = self.quant_method.apply(self, x, gate)
+            out = self.forward_normal(x, gate)

         if self.reduce_results and self.tp_size > 1:
             out = tensor_model_parallel_all_reduce(out, self.tp_group)
         return out
+
+    def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer):
+        """
+        Split input to multi chunk to reduce the memory usage of moe.
+
+        Args:
+            x (Tensor): Input tensor to the moe layer.
+
+        Returns:
+            Tensor: Output tensor.s
+        """
+        chunk_size = self.fd_config.parallel_config.chunked_moe_size
+        token_num = x.shape[0]
+        fake_x = paddle.empty(
+            shape=[0, self.fd_config.model_config.hidden_size],
+            dtype=paddle.get_default_dtype(),
+        )
+        # input size that are less than a chunk, less than the max size data or empty input
+        # need to be repeated until the max chunk data infer MOE finished.
+        if token_num > chunk_size:  # chunked moe
+            x_split_list = paddle.tensor_split(x, self.fd_config.parallel_config.moe_num_chunk, axis=0)
+            out_split_list = [None] * self.fd_config.parallel_config.moe_num_chunk
+
+            for i in range(self.fd_config.parallel_config.max_moe_num_chunk):
+                if i < self.fd_config.parallel_config.moe_num_chunk:
+                    out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
+                else:
+                    # just need to use real data to infer max_moe_num_chunk times.
+                    self.quant_method.apply(self, fake_x, gate)
+
+            out = paddle.concat(out_split_list, axis=0)
+        else:
+            # when only one chunk, just need to use real data to infer once.
+            out = self.quant_method.apply(self, x, gate)
+            for i in range(self.fd_config.parallel_config.max_moe_num_chunk - 1):
+                self.quant_method.apply(self, fake_x, gate)
+
+        return out
+
+    def forward_normal(self, x: paddle.Tensor, gate: nn.Layer):
+        """
+        Normal mode of forward.
+
+        Args:
+            x (Tensor): Input tensor to the moe layer.
+
+        Returns:
+            Tensor: Output tensor.s
+
+        """
+        out = self.quant_method.apply(self, x, gate)
+        return out
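The loop above encodes the key constraint of chunked MoE under expert parallelism: every rank must call into quant_method.apply the same number of times per layer, because the dispatch/combine inside it is collective, so ranks with fewer real chunks pad the remaining iterations with an empty fake_x. A stripped-down sketch of that pattern, assuming a standalone expert_fn in place of quant_method.apply (names are illustrative, not part of the commit):

    import paddle

    def chunked_expert_forward(x, expert_fn, local_chunks, global_max_chunks, hidden_size):
        # Empty tensor used for the padding calls so the collectives stay in lock-step.
        fake_x = paddle.empty(shape=[0, hidden_size], dtype=x.dtype)
        if local_chunks > 1:
            pieces = paddle.tensor_split(x, local_chunks, axis=0)
            outs = [expert_fn(p) for p in pieces]
            # Ranks that have run out of real chunks keep issuing empty calls.
            for _ in range(global_max_chunks - local_chunks):
                expert_fn(fake_x)
            return paddle.concat(outs, axis=0)
        out = expert_fn(x)
        for _ in range(global_max_chunks - 1):
            expert_fn(fake_x)
        return out

In the committed code the real and padding calls share a single loop over max_moe_num_chunk iterations, which has the same effect.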

fastdeploy/worker/gpu_model_runner.py

Lines changed: 62 additions & 4 deletions
@@ -95,7 +95,11 @@
 from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
 from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
 from fastdeploy.output.pooler import PoolerOutput
-from fastdeploy.worker.model_runner_base import ModelRunnerBase
+from fastdeploy.worker.model_runner_base import (
+    DistributedOut,
+    DistributedStatus,
+    ModelRunnerBase,
+)
 from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput


@@ -250,6 +254,56 @@ def only_prefill(self):

         return if_only_prefill

+    def collect_distributed_status(self):
+        """
+        Collect distributed status
+        """
+        dist_status_list = []
+        dist_status_obj = DistributedStatus()
+        dist_out = DistributedOut()
+
+        prefill_exists = None
+        if_only_decode = True
+        # mix ep in single node
+        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
+            prefill_exists = self.exist_prefill()
+            dist_status_obj.only_decode = not prefill_exists
+
+        # whether chunked moe
+        if self.fd_config.parallel_config.enable_chunked_moe:
+            chunk_size = self.fd_config.parallel_config.chunked_moe_size
+            token_num = self.share_inputs["ids_remove_padding"].shape[0]
+
+            if token_num > chunk_size:
+                self.fd_config.parallel_config.moe_num_chunk = (token_num + chunk_size - 1) // chunk_size
+            else:
+                self.fd_config.parallel_config.moe_num_chunk = 1
+
+            dist_status_obj.moe_num_chunk = self.fd_config.parallel_config.moe_num_chunk
+
+        # only ep need to collect and sync distributed status
+        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
+            # call once to gather all status
+            paddle.distributed.all_gather_object(dist_status_list, dist_status_obj)
+
+            # Update Batch type for cuda graph for if_only_decode
+            if_only_decode = all(dist_status.only_decode for dist_status in dist_status_list)
+
+            if_only_decode = if_only_decode and not (
+                prefill_exists if prefill_exists is not None else self.exist_prefill()
+            )
+
+            max_moe_num_chunk = None
+            if self.fd_config.parallel_config.enable_chunked_moe:
+                max_moe_num_chunk = max(dist_status.moe_num_chunk for dist_status in dist_status_list)
+
+            dist_out = DistributedOut(
+                if_only_decode=if_only_decode,
+                max_moe_num_chunk=max_moe_num_chunk,
+            )
+
+        return dist_out
+
     def only_decode(self):
         """
         check whether decode only
@@ -1355,7 +1409,7 @@ def get_model(self) -> nn.Layer:

     def initialize_forward_meta(self, is_dummy_or_profile_run=False):
         """
-        Initialize forward meta and attention meta data
+        Initialize forward meta, attention meta data and update some config.
         """
         # Initialize forward meta
         self.forward_meta = ForwardMeta(
@@ -1386,8 +1440,12 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False):
             kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
         )

-        # Update Batch type for cuda graph for only_decode_batch
-        if_only_decode = self.only_decode()
+        dist_status = self.collect_distributed_status()
+
+        if_only_decode = dist_status.if_only_decode
+        if self.fd_config.parallel_config.enable_chunked_moe:
+            self.fd_config.parallel_config.max_moe_num_chunk = dist_status.max_moe_num_chunk
+
         only_decode_use_cudagraph = self.use_cudagraph and if_only_decode

         # Update config about moe for better performance
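To make the agreement concrete, here is a local, non-distributed rendering of what the all_gather_object round above would produce for two expert-parallel ranks with chunked_moe_size=256; the token counts are made up for illustration:

    from fastdeploy.worker.model_runner_base import DistributedStatus

    # rank 0 holds 600 prefill tokens -> 3 chunks; rank 1 holds 100 decode tokens -> 1 chunk
    statuses = [
        DistributedStatus(only_decode=False, moe_num_chunk=3),
        DistributedStatus(only_decode=True, moe_num_chunk=1),
    ]
    if_only_decode = all(s.only_decode for s in statuses)       # False
    max_moe_num_chunk = max(s.moe_num_chunk for s in statuses)  # 3

    # Both ranks then run 3 MoE passes per layer; rank 1 pads its last two
    # passes with the empty fake_x input from forward_chunked_moe.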

fastdeploy/worker/model_runner_base.py

Lines changed: 14 additions & 0 deletions
@@ -15,6 +15,8 @@
 """

 from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Optional

 from paddle import nn

@@ -25,6 +27,18 @@
 logger = get_logger("model_runner_base", "model_runner_base.log")


+@dataclass
+class DistributedStatus:
+    only_decode: bool = True
+    moe_num_chunk: int = 1
+
+
+@dataclass
+class DistributedOut:
+    if_only_decode: bool = True
+    max_moe_num_chunk: Optional[int] = None
+
+
 class ModelRunnerBase(ABC):
     """
     Engine -> (WIP)Executor -> Worker -> ModelRunner -> Model

fastdeploy/worker/worker_process.py

Lines changed: 11 additions & 0 deletions
@@ -720,6 +720,17 @@ def parse_args():
         action="store_true",
         help="enable expert parallel",
     )
+    parser.add_argument(
+        "--enable_chunked_moe",
+        action="store_true",
+        help="enable chunked moe",
+    )
+    parser.add_argument(
+        "--chunked_moe_size",
+        type=int,
+        default=256,
+        help="chunk size of moe input",
+    )
     parser.add_argument("--ori_vocab_size", type=int, default=None)
     parser.add_argument("--think_end_id", type=int, default=-1)
     parser.add_argument("--image_patch_id", type=int, default=-1)

tests/ci_use/XPU_45T/run_w4a8.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def test_w4a8():
     )
     print(response.choices[0].message.content)
     # print(base_response)
-    assert any(keyword in response.choices[0].message.content for keyword in ["人工智能", "文心一言"])
+    assert any(keyword in response.choices[0].message.content for keyword in ["人工智能", "文心一言", "小度"])


 if __name__ == "__main__":
