@@ -79,6 +79,19 @@ def __init__(
7979 self .writer_parallel = writer_parallel
8080 self .writer_batch_size = writer_batch_size
8181
82+ self .max_model_len = kwargs .get ("max_model_len" )
83+ self .enable_chunked_prefill = kwargs .get ("enable_chunked_prefill" )
84+ self .max_num_partial_prefills = kwargs .get ("max_num_partial_prefills" )
85+ self .max_long_partial_prefills = kwargs .get ("max_long_partial_prefills" )
86+ self .long_prefill_token_threshold = kwargs .get ("long_prefill_token_threshold" )
87+
88+ assert self .enable_chunked_prefill is not None , "enable_chunked_prefill must be set"
89+ assert self .max_num_partial_prefills is not None , "max_num_partial_prefills must be set"
90+ assert self .max_long_partial_prefills is not None , "max_long_partial_prefills must be set"
91+ if self .long_prefill_token_threshold is None or self .long_prefill_token_threshold == 0 :
92+ assert self .max_model_len is not None , "max_model_len must be set"
93+ self .long_prefill_token_threshold = int (self .max_model_len * 0.04 )
94+
# Argument-check hook: the visible body is only `pass`, so calling it performs no validation.
8295 def check (self ):
8396 """check argument"""
8497 pass  # intentional no-op placeholder
@@ -674,6 +687,7 @@ class InferScheduler:
674687 """
675688
676689 def __init__ (self , config ):
690+ self .config = config
677691 self .nodeid = config .nodeid
678692 self .writer_parallel = config .writer_parallel
679693 self .writer_batch_size = config .writer_batch_size
@@ -792,9 +806,13 @@ def get_requests(
792806 reqs = []
793807 required_blocks = 0
794808 current_prefill_tokens = 0
809+ long_partial_requests , short_partial_requests = 0 , 0
795810 cur_time = time .time ()
796811 for i in range (batch ):
797812 try :
813+ if len (self .reqs_queue ) == 0 :
814+ break
815+
798816 req = self .reqs_queue .popleft ()
799817 if cur_time - req .arrival_time > self .ttl :
800818 logger .error (f"req({ req .request_id } ) is expired({ self .ttl } ) when InferScheduler Get Requests" )
@@ -803,9 +821,27 @@ def get_requests(
803821 current_prefill_tokens += req .prompt_token_ids_len
804822 required_input_blocks = (req .prompt_token_ids_len + block_size - 1 ) // block_size
805823 required_blocks += required_input_blocks + reserved_output_blocks
806- if required_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens :
824+ if required_blocks > available_blocks :
807825 self .reqs_queue .appendleft (req )
808826 return reqs
827+
828+ if self .config .enable_chunked_prefill :
829+ if req .prompt_token_ids_len > self .config .long_prefill_token_threshold :
830+ # long partial requests
831+ long_partial_requests += 1
832+ if long_partial_requests > self .config .max_long_partial_prefills :
833+ self .reqs_queue .appendleft (req )
834+ break
835+ else :
836+ short_partial_requests += 1
837+
838+ if short_partial_requests + long_partial_requests > self .config .max_num_partial_prefills :
839+ self .reqs_queue .appendleft (req )
840+ break
841+ else :
842+ if current_prefill_tokens > max_num_batched_tokens :
843+ self .reqs_queue .appendleft (req )
844+ break
809845 # logger.info(f"Get Requests from Scheduler: {req.request_id}")
810846 reqs .append (req )
811847 except Exception as e :
0 commit comments