|
95 | 95 | from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp |
96 | 96 | from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling |
97 | 97 | from fastdeploy.output.pooler import PoolerOutput |
98 | | -from fastdeploy.worker.model_runner_base import ModelRunnerBase |
| 98 | +from fastdeploy.worker.model_runner_base import ( |
| 99 | + DistributedOut, |
| 100 | + DistributedStatus, |
| 101 | + ModelRunnerBase, |
| 102 | +) |
99 | 103 | from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput |
100 | 104 |
|
101 | 105 |
|
@@ -250,6 +254,56 @@ def only_prefill(self): |
250 | 254 |
|
251 | 255 | return if_only_prefill |
252 | 256 |
|
| 257 | + def collect_distributed_status(self): |
| 258 | + """ |
| 259 | + Collect this rank's distributed status (decode-only flag, MoE chunk count) and, under EP, all-gather it across ranks into a DistributedOut. |
| 260 | + """ |
| 261 | + dist_status_list = [] |
| 262 | + dist_status_obj = DistributedStatus() |
| 263 | + dist_out = DistributedOut() |
| 264 | + |
| 265 | + prefill_exists = None |
| 266 | + if_only_decode = True |
| 267 | + # mixed EP within a single node: check whether any prefill request exists locally |
| 268 | + if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": |
| 269 | + prefill_exists = self.exist_prefill() |
| 270 | + dist_status_obj.only_decode = not prefill_exists |
| 271 | + |
| 272 | + # decide how many chunks the MoE computation is split into |
| 273 | + if self.fd_config.parallel_config.enable_chunked_moe: |
| 274 | + chunk_size = self.fd_config.parallel_config.chunked_moe_size |
| 275 | + token_num = self.share_inputs["ids_remove_padding"].shape[0] |
| 276 | + |
| 277 | + if token_num > chunk_size: |
| 278 | + self.fd_config.parallel_config.moe_num_chunk = (token_num + chunk_size - 1) // chunk_size |
| 279 | + else: |
| 280 | + self.fd_config.parallel_config.moe_num_chunk = 1 |
| 281 | + |
| 282 | + dist_status_obj.moe_num_chunk = self.fd_config.parallel_config.moe_num_chunk |
| 283 | + |
| 284 | + # only EP needs to collect and sync the distributed status across ranks |
| 285 | + if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": |
| 286 | + # call once to gather all status |
| 287 | + paddle.distributed.all_gather_object(dist_status_list, dist_status_obj) |
| 288 | + |
| 289 | + # the batch counts as decode-only for CUDA graph only if every rank reports decode-only |
| 290 | + if_only_decode = all(dist_status.only_decode for dist_status in dist_status_list) |
| 291 | + |
| 292 | + if_only_decode = if_only_decode and not ( |
| 293 | + prefill_exists if prefill_exists is not None else self.exist_prefill() |
| 294 | + ) |
| 295 | + |
| 296 | + max_moe_num_chunk = None |
| 297 | + if self.fd_config.parallel_config.enable_chunked_moe: |
| 298 | + max_moe_num_chunk = max(dist_status.moe_num_chunk for dist_status in dist_status_list) |
| 299 | + |
| 300 | + dist_out = DistributedOut( |
| 301 | + if_only_decode=if_only_decode, |
| 302 | + max_moe_num_chunk=max_moe_num_chunk, |
| 303 | + ) |
| 304 | + |
| 305 | + return dist_out |
| 306 | + |
253 | 307 | def only_decode(self): |
254 | 308 | """ |
255 | 309 | check whether decode only |
@@ -1355,7 +1409,7 @@ def get_model(self) -> nn.Layer: |
1355 | 1409 |
|
1356 | 1410 | def initialize_forward_meta(self, is_dummy_or_profile_run=False): |
1357 | 1411 | """ |
1358 | | - Initialize forward meta and attention meta data |
| 1412 | + Initialize forward meta and attention meta data, and update related parallel config. |
1359 | 1413 | """ |
1360 | 1414 | # Initialize forward meta |
1361 | 1415 | self.forward_meta = ForwardMeta( |
@@ -1386,8 +1440,12 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): |
1386 | 1440 | kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], |
1387 | 1441 | ) |
1388 | 1442 |
|
1389 | | - # Update Batch type for cuda graph for only_decode_batch |
1390 | | - if_only_decode = self.only_decode() |
| 1443 | + dist_status = self.collect_distributed_status() |
| 1444 | + |
| 1445 | + if_only_decode = dist_status.if_only_decode |
| 1446 | + if self.fd_config.parallel_config.enable_chunked_moe: |
| 1447 | + self.fd_config.parallel_config.max_moe_num_chunk = dist_status.max_moe_num_chunk |
| 1448 | + |
1391 | 1449 | only_decode_use_cudagraph = self.use_cudagraph and if_only_decode |
1392 | 1450 |
|
1393 | 1451 | # Update config about moe for better performance |
|
0 commit comments