Skip to content

Commit f72be7a

Browse files
kevincheng2ltd0924Jiang-Jia-Jun
authored
[BUG] fix ep bug (#4275)
* fix ep bug * update code * update code * update code * [BugFix] fix config bugs (#4370) * Update expert_service.py * Update common_engine.py * Update expert_service.py * Update expert_service.py * Update expert_service.py --------- Co-authored-by: Jiang-Jia-Jun <[email protected]> * update code --------- Co-authored-by: ltd0924 <[email protected]> Co-authored-by: Jiang-Jia-Jun <[email protected]>
1 parent 5abf597 commit f72be7a

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

fastdeploy/scheduler/splitwise_scheduler.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,19 @@ def __init__(
7979
self.writer_parallel = writer_parallel
8080
self.writer_batch_size = writer_batch_size
8181

82+
self.max_model_len = kwargs.get("max_model_len")
83+
self.enable_chunked_prefill = kwargs.get("enable_chunked_prefill")
84+
self.max_num_partial_prefills = kwargs.get("max_num_partial_prefills")
85+
self.max_long_partial_prefills = kwargs.get("max_long_partial_prefills")
86+
self.long_prefill_token_threshold = kwargs.get("long_prefill_token_threshold")
87+
88+
assert self.enable_chunked_prefill is not None, "enable_chunked_prefill must be set"
89+
assert self.max_num_partial_prefills is not None, "max_num_partial_prefills must be set"
90+
assert self.max_long_partial_prefills is not None, "max_long_partial_prefills must be set"
91+
if self.long_prefill_token_threshold is None or self.long_prefill_token_threshold == 0:
92+
assert self.max_model_len is not None, "max_model_len must be set"
93+
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
94+
8295
def check(self):
8396
"""check argument"""
8497
pass
@@ -674,6 +687,7 @@ class InferScheduler:
674687
"""
675688

676689
def __init__(self, config):
690+
self.config = config
677691
self.nodeid = config.nodeid
678692
self.writer_parallel = config.writer_parallel
679693
self.writer_batch_size = config.writer_batch_size
@@ -792,9 +806,13 @@ def get_requests(
792806
reqs = []
793807
required_blocks = 0
794808
current_prefill_tokens = 0
809+
long_partial_requests, short_partial_requests = 0, 0
795810
cur_time = time.time()
796811
for i in range(batch):
797812
try:
813+
if len(self.reqs_queue) == 0:
814+
break
815+
798816
req = self.reqs_queue.popleft()
799817
if cur_time - req.arrival_time > self.ttl:
800818
logger.error(f"req({req.request_id}) is expired({self.ttl}) when InferScheduler Get Requests")
@@ -803,9 +821,27 @@ def get_requests(
803821
current_prefill_tokens += req.prompt_token_ids_len
804822
required_input_blocks = (req.prompt_token_ids_len + block_size - 1) // block_size
805823
required_blocks += required_input_blocks + reserved_output_blocks
806-
if required_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens:
824+
if required_blocks > available_blocks:
807825
self.reqs_queue.appendleft(req)
808826
return reqs
827+
828+
if self.config.enable_chunked_prefill:
829+
if req.prompt_token_ids_len > self.config.long_prefill_token_threshold:
830+
# long partial requests
831+
long_partial_requests += 1
832+
if long_partial_requests > self.config.max_long_partial_prefills:
833+
self.reqs_queue.appendleft(req)
834+
break
835+
else:
836+
short_partial_requests += 1
837+
838+
if short_partial_requests + long_partial_requests > self.config.max_num_partial_prefills:
839+
self.reqs_queue.appendleft(req)
840+
break
841+
else:
842+
if current_prefill_tokens > max_num_batched_tokens:
843+
self.reqs_queue.appendleft(req)
844+
break
809845
# logger.info(f"Get Requests from Scheduler: {req.request_id}")
810846
reqs.append(req)
811847
except Exception as e:

0 commit comments

Comments
 (0)