diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 7180242c..08a77267 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -211,7 +211,7 @@ def __init__( ) @abstractmethod - def handle_input_requests(self, requests: List[Request]): + def handle_input_requests(self, requests: List[Request], from_previous_peer: bool = False): """Update requests states and status in scheduler and cache manager.""" @abstractmethod @@ -299,6 +299,11 @@ def recv_requests_from_peer(self) -> Tuple[List[Request], str]: if recv_req is not None and len(recv_req) > 0: for req in recv_req: if req.hidden_states is not None: + size_attr = getattr(req.hidden_states, "size", None) + hidden_size = size_attr() if callable(size_attr) else size_attr + logger.debug( + f"recv request {req.request_id} hidden_states.length: {hidden_size}" + ) if req.hidden_states.dtype != self.dtype: logger.debug( f"Converting hidden_states dtype from {req.hidden_states.dtype} to {self.dtype} for request {req.request_id}" @@ -386,9 +391,15 @@ def prepare_batch_inputs(self, batched_requests: List[Request]) -> Optional[Dict def prepare_next_batch_requests( self, requests: List[Request], batch_output: Any, context_lengths: Any - ) -> List[Request]: + ) -> Tuple[List[Request], List[Request]]: """Prepares a batch of requests for the next stage of the pipeline. + Returns two lists: + - chunked_reqs: requests that still have chunks to complete (e.g. chunked prefill); + these should be handled locally via handle_input_requests only, not sent to next peer. + - to_forward_reqs: requests to forward (single node: handle_input_requests; + multi-node and tp_rank==0: request_to_proto and send to next peer). + Args: requests: List of requests in the batch batch_output: Output from process_batch. 
Always a dict with: @@ -403,7 +414,8 @@ def prepare_next_batch_requests( hidden_states = batch_output["hidden_states"] token_probs = batch_output["probs"] - batched_requests = [] + chunked_reqs = [] + to_forward_reqs = [] pre_length = 0 for i, src_request in enumerate(requests): if self.is_last_peer: @@ -437,9 +449,9 @@ def prepare_next_batch_requests( next_req = self._prepare_next_single_request( src_request, hidden_state_for_req, token_prob ) - batched_requests.append(next_req) + to_forward_reqs.append(next_req) - return batched_requests + return chunked_reqs, to_forward_reqs def release_and_evict_request(self, rid: str): """Release per-request resources and evict from scheduler. Best-effort, never raises.""" @@ -471,7 +483,7 @@ def run_loop(self): if self.enable_weight_refit: self.check_and_refit_weight(refit_weight_path) - self.handle_input_requests(received_requests) + self.handle_input_requests(received_requests, from_previous_peer=True) # Send abort signals to P2P server to broadcast to all nodes if len(self.finished_batch) > 0 and self.tp_rank == 0: self.send_to_peer_socket.send_multipart( @@ -547,28 +559,62 @@ def run_loop(self): except Exception: pass # 7. Prepare requests for the next stage in the pipeline - next_batch = self.prepare_next_batch_requests( + chunked_reqs, to_forward_reqs = self.prepare_next_batch_requests( requests=prepared_inputs["requests"], batch_output=output, context_lengths=prepared_inputs.get("context_lengths"), ) - + if chunked_reqs: + logger.debug(f"Handle {len(chunked_reqs)} chunked requests.") + self.handle_input_requests(chunked_reqs) # 8. 
Dispatch to the appropriate destination - if self.is_last_peer and self.is_first_peer: - # Single node: handle locally - self.handle_input_requests(next_batch) - elif self.tp_rank == 0: - # Send output to next peer - self.send_to_peer_socket.send_multipart( - [ - b"forward", - request_to_proto(next_batch, self.device).SerializeToString(), - ] - ) - logger.debug( - f"Processed batch of type {batch_type} with {len(next_batch)} requests " - f"in {(time.time() - start_time) * 1000:.3f} ms" - ) + if to_forward_reqs or chunked_reqs: + logger.debug(f"dispatch to_forward and chunked_reqs to next peer.") + if ( + self.is_last_peer + and self.is_first_peer + and (to_forward_reqs is not None and len(to_forward_reqs) > 0) + ): + # Single node: handle to_forward locally + logger.debug(f"Handle {len(to_forward_reqs)} to_forward requests.") + self.handle_input_requests(to_forward_reqs) + elif self.tp_rank == 0: + # Send to_forward to next peer (do not send chunked_reqs if self is last_peer) + logger.debug( + f"self is last_peer: {self.is_last_peer}, self is first_peer: {self.is_first_peer}" + ) + if not self.is_last_peer: + logger.debug( + f"Send {len(to_forward_reqs + chunked_reqs)} to_forward and chunked_reqs to next peer." + ) + self.send_to_peer_socket.send_multipart( + [ + b"forward", + request_to_proto( + to_forward_reqs + chunked_reqs, self.device + ).SerializeToString(), + ] + ) + logger.debug( + f"Processed batch of with {len(to_forward_reqs + chunked_reqs)} to_forward and chunked_reqs " + f"in {(time.time() - start_time) * 1000:.3f} ms" + ) + elif to_forward_reqs is not None and len(to_forward_reqs) > 0: + logger.debug( + f"Send {len(to_forward_reqs)} to_forward to next peer." 
+ ) + self.send_to_peer_socket.send_multipart( + [ + b"forward", + request_to_proto( + to_forward_reqs, self.device + ).SerializeToString(), + ] + ) + logger.debug( + f"Processed batch of with {len(to_forward_reqs)} to_forward " + f"in {(time.time() - start_time) * 1000:.3f} ms" + ) except Exception as e: logger.exception(f"Error processing batch: {e}") @@ -744,7 +790,7 @@ def _prepare_next_single_request( request_id=request.request_id, status=RequestStatus.DECODING, current_position=request.total_length + 1, - input_ids=request.input_ids, + input_ids=request.origin_input_ids, hidden_states=hidden_states, next_token_id=next_token_id, routing_table=request.routing_table, @@ -763,7 +809,7 @@ def _prepare_next_single_request( request_id=request.request_id, status=RequestStatus.DECODING, # Last peer always changes status to DECODING current_position=request.total_length, - input_ids=request.input_ids, + input_ids=request.origin_input_ids, hidden_states=hidden_states, next_token_id=next_token_id, routing_table=request.routing_table, diff --git a/src/parallax/server/executor/factory.py b/src/parallax/server/executor/factory.py index 20ed93e4..cf2fde38 100755 --- a/src/parallax/server/executor/factory.py +++ b/src/parallax/server/executor/factory.py @@ -50,6 +50,7 @@ def create_executor_config(args: argparse.Namespace, shared_state=None, conn=Non "max_loaded_loras": args.max_loaded_loras, "enable_weight_refit": args.enable_weight_refit, "weight_refit_mode": args.weight_refit_mode, + "chunked_prefill_size": args.chunked_prefill_size, } if args.gpu_backend == "sglang": diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 268e035c..6fcf2c2c 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -20,6 +20,7 @@ ) from parallax.server.sampling.sampler import SamplingBatchInfo from parallax.server.shard_loader import MLXModelLoader +from 
parallax.utils.mac_prefill_addr import AddReqResult, MACPrefillAdder from parallax.utils.utils import ( combine_padding_and_causal_masks, create_causal_mask, @@ -91,6 +92,8 @@ def __init__( # Weight Refit enable_weight_refit: Optional[bool] = False, weight_refit_mode: Optional[str] = "disk", + # Chunked prefill (SGL-only; accepted here for factory config compatibility) + chunked_prefill_size: Optional[int] = None, # Pipe communication conn: Optional[List[Any]] = [], ): @@ -232,14 +235,18 @@ def __init__( # Prefix Cache Manager self.enable_prefix_cache = enable_prefix_cache - # self.prefix_cache = RadixCache( - # num_kv_heads=num_key_value_heads, - # head_dim=head_dim, - # head_dim_v=v_head_dim, - # num_layers=self.num_shard_layers, - # dtype=self.dtype, - # page_size=1, - # ) + if chunked_prefill_size is not None and chunked_prefill_size > 0: + # up align to page size + self.chunked_prefill_size = ( + (chunked_prefill_size + self.cache_manager.block_size - 1) + // self.cache_manager.block_size + * self.cache_manager.block_size + ) + else: + self.chunked_prefill_size = ( + max_sequence_length if max_sequence_length is not None else max_num_tokens_per_batch + ) + self.chunked_req = None logger.debug( f"mlx_executor initialized; wired_limit set; prefix_cache={'on' if self.enable_prefix_cache else 'off'}, total memory usage: {mx.get_active_memory() / 1024**3 :.3f} GB" ) @@ -270,7 +277,7 @@ def _tensor_parallel_broadcast_pyobj(self, broadcast_obj): data = pickle.loads(np.array(data_arr).tobytes()) return data - def handle_input_requests(self, requests: List[Request]): + def handle_input_requests(self, requests: List[Request], from_previous_peer: bool = False): """Update requests states and status in scheduler and cache manager.""" if self.tp_size > 1: requests = self._tensor_parallel_broadcast_pyobj(requests) @@ -302,7 +309,14 @@ def handle_input_requests(self, requests: List[Request]): continue if not req.abort and req.next_token_id is not None: - 
original_req.commit_new_token(req.next_token_id) + if ( + self.chunked_req is not None + and req.request_id == self.chunked_req.rid + and self.chunked_req.is_chunked > 0 + ): + original_req.status = RequestStatus.PREFILLING + else: + original_req.commit_new_token(req.next_token_id) if len(req.routing_table) > 0: original_req.routing_table = req.routing_table @@ -311,7 +325,16 @@ def handle_input_requests(self, requests: List[Request]): if req.abort: original_req.abort = True - if self.scheduler.check_and_update_request_status(original_req): + if ( + self.chunked_req is not None + and req.request_id == self.chunked_req.rid + and self.chunked_req.is_chunked > 0 + ): + self.chunked_req.is_chunked -= 1 + self.cache_manager.release_request(original_req.request_id) + self.scheduler.enque_request(original_req) + continue + elif self.scheduler.check_and_update_request_status(original_req): self.cache_manager.release_request(original_req.request_id) logger.debug( f"Released resources for finished request {req.request_id}, " @@ -355,11 +378,6 @@ def handle_input_requests(self, requests: List[Request]): req, IntermediateRequest ), "Non-first peers must receive IntermediateRequests." 
if req.is_finished or req.hidden_states is None: - if self.enable_prefix_cache: - keys, values = self.cache_manager.gather_kv_cache(req.request_id) - self.prefix_cache.cache_finished_request(req, keys, values) - self.prefix_cache.evict_request(req.request_id) - self.cache_manager.release_request(req.request_id) logger.debug( f"Released resources for finished request {req.request_id}, " @@ -368,10 +386,49 @@ def handle_input_requests(self, requests: List[Request]): self.scheduler.evict_request(req.request_id) if not self.is_last_peer and not req.abort: self.finished_batch.append(req) + elif ( + self.chunked_req is not None + and req.request_id == self.chunked_req.rid + and self.chunked_req.is_chunked > 0 + and not from_previous_peer + ): + self.chunked_req.is_chunked -= 1 + req.status = RequestStatus.PREFILLING + self.cache_manager.release_request(req.request_id) + self.scheduler.evict_request(req.request_id) + continue else: # This is an active request, add it to the scheduler queue to be processed. 
self.scheduler.enque_request(req) + def prepare_next_batch_requests( + self, requests: List[Request], batch_output: Any, context_lengths: Any + ) -> Tuple[List[Request], List[Request]]: + """Prepares a batch of requests for the next stage of the pipeline.""" + base_chunked, base_to_forward = super().prepare_next_batch_requests( + requests, batch_output, context_lengths + ) + if ( + self.chunked_req is None + or self.chunked_req.is_chunked <= 0 + or self.chunked_req.rid not in [req.request_id for req in requests] + ): + logger.debug( + f"mlx_executor: prepare_next_batch_requests: return base_chunked{len(base_chunked)} and base_to_forward{len(base_to_forward)} because chunked_req is None or is_chunked <= 0 or rid not in requests" + ) + return base_chunked, base_to_forward + chunked_rid = self.chunked_req.rid + for req in base_to_forward: + if req.request_id == chunked_rid: + req.status = RequestStatus.PREFILLING + base_chunked.append(req) + break + base_to_forward = [req for req in base_to_forward if req.request_id != chunked_rid] + logger.debug( + f"mlx_executor: prepare_next_batch_requests: return new_chunked{len(base_chunked)} and new_to_forward{len(base_to_forward)} because chunked_req is not None and is_chunked > 0 and rid in requests" + ) + return base_chunked, base_to_forward + def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True): """Process a batch of requests in MLX.""" # Run model and get updated cache @@ -521,6 +578,53 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A if batch_size == 0: return None + original_batched_requests = batched_requests + logger.debug(f"original_batched_requests_size: {len(original_batched_requests)}") + adder = MACPrefillAdder( + self.cache_manager.block_size, self.chunked_prefill_size, self.cache_manager + ) + + chunked_rid = self.chunked_req.rid if self.chunked_req is not None else None + + if self.chunked_req is not None and chunked_rid in [ + 
req.request_id for req in original_batched_requests + ]: + # update chunked_req use old_req + self.chunked_req = [ + req for req in original_batched_requests if req.request_id == chunked_rid + ][0] + logger.debug(f"before add_chunked_req, chunked_req is not None") + self.chunked_req = adder.add_chunked_req(self.chunked_req) + if self.chunked_req is None: + logger.debug( + f"chunked_req{chunked_rid} has chunk all after add_chunked_req, chunked_req is None" + ) + + for old_req in original_batched_requests: + if chunked_rid is not None and old_req.request_id == chunked_rid: + continue + res = adder.add_one_req(old_req) + if res != AddReqResult.CONTINUE: + logger.debug(f"macprefilladder has no token to add to prefill batch") + break + + if adder.new_chunked_req is not None: + self.chunked_req = adder.new_chunked_req + logger.debug(f"new chunked_req is {self.chunked_req.rid}") + + if self.chunked_req is not None and self.chunked_req.rid in [ + req.request_id for req in original_batched_requests + ]: + self.chunked_req.is_chunked += 1 + + can_run_by_id = {req.request_id: req for req in adder.can_run_list} + batched_requests = [ + can_run_by_id[req.request_id] + for req in original_batched_requests + if req.request_id in can_run_by_id + ] + logger.debug(f"after add_one_req, batched_requests size: {len(batched_requests)}") + h_or_tokens_list = [] block_tables_list = [] context_lengths_list = [] @@ -535,7 +639,9 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A token_ids = None if self.enable_prefix_cache and req.input_ids is not None: token_ids = req.input_ids - + logger.debug( + f"before allocate_request: {req.request_id}, req.total_length: {req.total_length}" + ) success, matched_tokens = self.cache_manager.allocate_request( req.request_id, req.total_length, token_ids=token_ids ) @@ -569,8 +675,8 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A actual_processed_lengths_list.append(len(req.input_ids)) 
else: if matched_tokens > 0 and self.enable_prefix_cache: - # Skip the prefix hidden states that correspond to cached tokens - new_hidden = req.hidden_states[matched_tokens:] + keep_len = req.total_length - matched_tokens + new_hidden = req.hidden_states[-keep_len:] if new_hidden.shape[0] == 0: # All tokens cached - keep the last hidden state new_hidden = req.hidden_states[-1:] diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 555db81c..8f18507d 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -7,7 +7,7 @@ import torch from sglang.srt.lora.lora_registry import LoRARef -from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.radix_cache import RadixCache as PageRadixCache from sglang.srt.model_executor.forward_batch_info import PPProxyTensors @@ -91,6 +91,8 @@ def __init__( weight_refit_mode: Optional[str] = "disk", # Pipe communication conn: Optional[List[Any]] = [], + chunked_prefill_size: Optional[int] = None, + max_prefill_tokens: Optional[int] = 4 * 1024, ): self.enable_lora = True if lora_paths is not None else enable_lora @@ -102,7 +104,21 @@ def __init__( self.lora_eviction_policy = lora_eviction_policy self.lora_backend = lora_backend self.max_lora_chunk_size = max_lora_chunk_size - + self.chunked_prefill_size = chunked_prefill_size + self.max_prefill_tokens = max_prefill_tokens + # Chunked prefill must be at least one page for correct KV/prefix cache alignment. 
+ if self.chunked_prefill_size is not None and self.chunked_prefill_size < kv_block_size: + logger.info( + f"chunked_prefill_size {self.chunked_prefill_size} < page_size (kv_block_size={kv_block_size}); " + f"clamping to {kv_block_size}" + ) + self.chunked_prefill_size = kv_block_size + elif self.chunked_prefill_size is not None: + self.chunked_prefill_size = chunked_prefill_size + else: + self.chunked_prefill_size = ( + max_sequence_length if max_sequence_length is not None else max_num_tokens_per_batch + ) if self.lora_paths is not None and len(self.lora_paths) > 0: self.check_lora_server_args() @@ -136,6 +152,7 @@ def __init__( "lora_eviction_policy": self.lora_eviction_policy, "lora_backend": self.lora_backend, "max_lora_chunk_size": self.max_lora_chunk_size, + "chunked_prefill_size": self.chunked_prefill_size, } logger.debug( f"Initializing SGLang model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" @@ -183,6 +200,11 @@ def __init__( self.tp_group = self.model_runner.tp_group self.tp_cpu_group = self.tp_group.cpu_group + # add chunked_prefill_size and chunked_req + if chunked_prefill_size is not None and chunked_prefill_size <= 0: + chunked_prefill_size = None + self.chunked_req = None + # create a page tree cache for sglang prefill if enable_prefix_cache: cache_params = CacheInitParams( @@ -302,7 +324,53 @@ def check_lora_server_args(self): and (self.max_lora_chunk_size & (self.max_lora_chunk_size - 1)) == 0 ), "--max-lora-chunk-size must be a power of 2 between 16 and 128." 
- def handle_input_requests(self, requests: List[Request]): + def stash_chunked_request(self, req: Req): + # #endregion + if req.req_pool_idx is None: + logger.warning( + "stash_chunked_request: skipping cache_unfinished_req and free because req.req_pool_idx is None (rid=%s, fill_ids_len=%s)", + getattr(req, "rid", None), + len(req.fill_ids), + ) + return + self.page_tree_cache.cache_unfinished_req(req, chunked=True) + # FIXME: it is unclear whether the code below is needed + # # Chunked request keeps its rid but will get a new req_pool_idx + if self.model_runner.mambaish_config is not None: + self.model_runner.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=False) + else: + self.model_runner.req_to_token_pool.free(req.req_pool_idx) + + def prepare_next_batch_requests( + self, requests: List[Request], batch_output: Any, context_lengths: Any + ) -> Tuple[List[Request], List[Request]]: + """Split out chunked prefill reqs; set their status to PREFILLING and return (chunked, to_forward).""" + base_chunked, base_to_forward = super().prepare_next_batch_requests( + requests, batch_output, context_lengths + ) + if ( + self.chunked_req is None + or self.chunked_req.is_chunked <= 0 + or self.chunked_req.rid not in [req.request_id for req in requests] + ): + logger.debug( + f"sglang_executor: prepare_next_batch_requests: return base_chunked and base_to_forward because chunked_req is None or is_chunked <= 0 or rid not in requests" + ) + return base_chunked, base_to_forward + chunked_rid = self.chunked_req.rid + self.stash_chunked_request(self.chunked_req) + for req in base_to_forward: + if req.request_id == chunked_rid: + req.status = RequestStatus.PREFILLING + base_chunked.append(req) + break + base_to_forward = [req for req in base_to_forward if req.request_id != chunked_rid] + logger.debug( + f"sglang_executor: prepare_next_batch_requests: return new_chunked{len(base_chunked)} and new_to_forward{len(base_to_forward)} because chunked_req is not None and is_chunked > 0 and rid 
in requests" + ) + return base_chunked, base_to_forward + + def handle_input_requests(self, requests: List[Request], from_previous_peer: bool = False): """Update requests states and status in scheduler and cache manager.""" if self.tp_size > 1: requests = self._tensor_parallel_broadcast_pyobj(requests) @@ -331,11 +399,19 @@ def handle_input_requests(self, requests: List[Request]): # If it's an abort signal (e.g. from OOM), next_token_id might be None or dummy if not req.abort and req.next_token_id is not None: - original_req.commit_new_token(req.next_token_id) - logger.debug( - f"[FirstPeer-CUDA] Committed token {req.next_token_id} for {req.request_id}, " - f"output_ids now has {len(original_req.output_ids)} tokens" - ) + # Keep chunked_req in PREFILLING so next round it goes to prefill again. + if ( + self.chunked_req is not None + and req.request_id == self.chunked_req.rid + and self.chunked_req.is_chunked > 0 + ): + original_req.status = RequestStatus.PREFILLING + else: + original_req.commit_new_token(req.next_token_id) + logger.debug( + f"[FirstPeer-CUDA] Committed token {req.next_token_id} for {req.request_id}, " + f"output_ids now has {len(original_req.output_ids)} tokens" + ) if len(req.routing_table) > 0: original_req.routing_table = req.routing_table @@ -344,7 +420,18 @@ def handle_input_requests(self, requests: List[Request]): if req.abort: original_req.abort = True - if self.scheduler.check_and_update_request_status(original_req): + # Chunked req is held by executor.chunked_req and enters next prefill via + # add_chunked_req; do not re-enqueue until all chunks are done. 
+ if ( + self.chunked_req is not None + and req.request_id == self.chunked_req.rid + and self.chunked_req.is_chunked > 0 + ): + self.chunked_req.is_chunked -= 1 + self.scheduler.enque_request(original_req) + continue + + elif self.scheduler.check_and_update_request_status(original_req): logger.debug(f"Releasing resources for finished request {req.request_id}") self.release_and_evict_request(req.request_id) if not self.is_last_peer and not req.abort: @@ -387,6 +474,16 @@ def handle_input_requests(self, requests: List[Request]): self.release_and_evict_request(req.request_id) if not self.is_last_peer and not req.abort: self.finished_batch.append(req) + elif ( + self.chunked_req is not None + and self.chunked_req.rid == req.request_id + and self.chunked_req.is_chunked > 0 + and not from_previous_peer + ): + self.chunked_req.is_chunked -= 1 + req.status = RequestStatus.PREFILLING + self.scheduler.evict_request(req.request_id) + continue else: # This is an active request, add it to the scheduler queue to be processed. self.scheduler.enque_request(req) @@ -410,7 +507,31 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: logits_output = out.logits_output # Merge prefill batch into running batch + chunked_req_to_exclude = set() + + if self.chunked_req is not None and self.chunked_req.is_chunked > 0: + # Move the chunked request out of the batch so that we can merge + # only finished requests to running_batch. + logger.debug(f"exclude chunked_req {self.chunked_req.rid} from running_batch") + chunked_req_to_exclude.add(self.chunked_req) + + if self.cur_batch and self.cur_batch.forward_mode.is_extend(): + if self.cur_batch.chunked_req is not None and self.cur_batch.chunked_req.is_chunked > 0: + chunked_req_to_exclude.add(self.cur_batch.chunked_req) + if self.cur_batch: + # Avoid duplicate rids in running_batch: exclude any cur_batch req + # whose rid is already in running_batch (filter_batch uses object identity). 
+ if not self.running_batch.is_empty(): + existing_rids = {req.rid for req in self.running_batch.reqs} + for req in list(self.cur_batch.reqs): + if req.rid in existing_rids: + chunked_req_to_exclude.add(req) + cur_batch_size = self.cur_batch.batch_size() + self.cur_batch.filter_batch(chunked_req_to_exclude=list[Req](chunked_req_to_exclude)) + + if self.cur_batch.batch_size() < cur_batch_size: + self.cur_batch.batch_is_full = False if self.cur_batch.forward_mode.is_extend(): # Merge the new batch into the running batch if not self.cur_batch.is_empty(): @@ -454,7 +575,32 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: def _release_request(self, rid: str): """Release per-request resources in SGLang.""" try: + # Debug: log running_batch before release to verify request eviction + before_size = 0 if self.running_batch.is_empty() else len(self.running_batch.reqs) + before_rids = ( + [] + if self.running_batch.is_empty() + else [req.rid for req in self.running_batch.reqs] + ) + logger.debug( + "[ChunkedPrefill-Debug] _release_request BEFORE: releasing rid=%s, " + "running_batch size=%s rids=%s", + rid, + before_size, + before_rids, + ) release_sglang_request(self.running_batch, rid) + after_size = 0 if self.running_batch.is_empty() else len(self.running_batch.reqs) + after_rids = ( + [] + if self.running_batch.is_empty() + else [req.rid for req in self.running_batch.reqs] + ) + logger.debug( + "[ChunkedPrefill-Debug] _release_request AFTER: running_batch size=%s rids=%s", + after_size, + after_rids, + ) except Exception: pass @@ -542,7 +688,12 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A return None # Pre-check: Verify KV cache has enough space for prefill - total_tokens_needed = sum(req.total_length for req in batched_requests) + def _get_total_length(req: Request) -> int: + if hasattr(req, "hidden_states") and req.hidden_states is not None: + return req.hidden_states.shape[0] + return 
req.total_length + + total_tokens_needed = sum(_get_total_length(req) for req in batched_requests) if not self._check_kv_cache_available(total_tokens_needed): self._abort_requests_due_to_kv_cache( batched_requests, @@ -592,13 +743,12 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A if self.lora_paths and len(self.lora_paths) > 0 else None ) - lengths.append(req.total_length) + lengths.append(_get_total_length(req)) lengths_tensor = torch.tensor(lengths, device=self.device) schedule_batch, forward_batch = form_sgl_batch_prefill( batched_requests, - self.model_runner, - self.page_tree_cache, + self, ) self.cur_batch = schedule_batch diff --git a/src/parallax/server/executor/vllm_executor.py b/src/parallax/server/executor/vllm_executor.py index d4350fe8..98be6c63 100755 --- a/src/parallax/server/executor/vllm_executor.py +++ b/src/parallax/server/executor/vllm_executor.py @@ -90,6 +90,8 @@ def __init__( weight_refit_mode: Optional[str] = "disk", # Routed experts enable_return_routed_experts: bool = False, + # Chunked prefill (SGL-only; accepted here for factory config compatibility) + chunked_prefill_size: Optional[int] = None, # Pipe communication conn: Optional[List[Any]] = [], ): @@ -205,7 +207,7 @@ def check_lora_server_args(self): "--enable-lora is set to False, any provided lora_paths will be ignored." 
) - def handle_input_requests(self, requests: List[Request]): + def handle_input_requests(self, requests: List[Request], from_previous_peer: bool = False): """Update requests states and status in scheduler and cache manager.""" if self.tp_size > 1: requests = self._tensor_parallel_broadcast_pyobj(requests) diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index f7c9bb90..62ae59dd 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -109,6 +109,11 @@ def __init__( self.last_updated_time: Optional[float] = None self.lora_id: Optional[str] = None self.lora_path = lora_path + self.is_chunked = 0 + self.rid = self.request_id + self.origin_input_ids = input_ids + # When set (e.g. by chunked prefill), total_length property returns this instead of computed value. + self._effective_total_length: Optional[int] = None @property def is_finished(self) -> bool: @@ -199,7 +204,9 @@ def output_length(self) -> int: @property def total_length(self) -> int: - """Total length of the sequence (input + output).""" + """Total length of the sequence (input + output). Overridable via _effective_total_length (e.g. chunked prefill).""" + if self._effective_total_length is not None: + return self._effective_total_length return self.prompt_len + self.output_length def get_model_input_for_first_peer(self) -> List[int]: @@ -300,7 +307,9 @@ def input_length(self) -> int: @property def total_length(self) -> int: - """Total length of the sequence (input + output).""" + """Total length of the sequence (input + output). Overridable via _effective_total_length (e.g. 
chunked prefill).""" + if self._effective_total_length is not None: + return self._effective_total_length return self.current_position @classmethod @@ -335,7 +344,7 @@ def from_initial_request( return IntermediateRequest( request_id=initial_request.request_id, status=initial_request.status, - input_ids=initial_request.input_ids, + input_ids=initial_request.origin_input_ids, next_token_id=next_token_id, current_position=initial_request.total_length, hidden_states=hidden_states, @@ -362,7 +371,7 @@ def from_intermediate_request( request_id=old_request.request_id, status=old_request.status, current_position=old_request.total_length, - input_ids=old_request.input_ids, + input_ids=old_request.origin_input_ids, next_token_id=old_request.next_token_id, hidden_states=new_hidden_states, routing_table=old_request.routing_table, diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 69ab5fbd..1c9afef2 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -126,6 +126,8 @@ def enque_request(self, request: Request | str): request.ready_for_next_step = True request.last_updated_time = time.time() + if request.origin_input_ids is not None: + request.input_ids = request.origin_input_ids # TODO: Handle chunked prefill. if request.is_decoding: rid = request.request_id @@ -242,34 +244,27 @@ def admit_requests(self): """Move requests from wait queue into running (inflight) set, up to capacity. Pushes admitted requests directly into the running set. + Each request is updated or added to running at most once per call. 
""" - while self._wait_queue and len(self._running_requests) < self.max_batch_size: + # One pass over wait_queue to avoid infinite loop when all front items are already in running_requests + initial_len = len(self._wait_queue) + for _ in range(initial_len): + if not self._wait_queue or len(self._running_requests) >= self.max_batch_size: + break req = self._wait_queue.popleft() rid = req.request_id if rid in self._running_requests: + self._wait_queue.append(req) continue - # Check kv cache pool - if self.cache_manager is not None: - if not self.cache_manager.has_request(req.request_id): - # TODO: Handle chunked prefill, and support preemption. - # Pass input_ids for prefix cache matching - token_ids = getattr(req, "input_ids", None) - success, matched_tokens = self.cache_manager.allocate_request( - req.request_id, req.total_length, token_ids=token_ids - ) - if not success: - logger.warning( - f"Request {rid} can't be admit to running batch due to KV cache size." - ) - # Put back to wait queue if allocation fails - self._wait_queue.appendleft(req) - # Stop admitting since we are out of memory - break - if matched_tokens > 0: - logger.debug( - f"Request {rid} matched {matched_tokens} tokens from prefix cache" - ) + # if cache_manager is not None and has allow attribute, break if allow is False + if ( + self.cache_manager is not None + and hasattr(self.cache_manager, "allow") + and not self.cache_manager.allow + ): + logger.debug(f"Cache manager does not allow request {rid},breaking admission.") + break # Add request to running requests self._running_requests[rid] = req @@ -320,7 +315,6 @@ def form_batch(self) -> List[Request]: self.admit_requests() if not self._running_requests: return [] - inflight_tokens = 0 batch: List[Request] = [] @@ -333,13 +327,15 @@ def form_batch(self) -> List[Request]: prefill_candidates.append(req) elif req.is_decoding: decode_candidates.append(req) - # 1) Fill with prefills first for req in prefill_candidates: if len(batch) >= 
self.micro_batch_size: break cost = req.prompt_len if cost + inflight_tokens > self.max_num_tokens_per_batch: + logger.debug( + f"prefill request {req.request_id} cost {cost} + inflight_tokens {inflight_tokens} > max_num_tokens_per_batch {self.max_num_tokens_per_batch}, breaking" + ) continue batch.append(req) inflight_tokens += cost diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py index 39303ce5..445bb487 100644 --- a/src/parallax/server/server_args.py +++ b/src/parallax/server/server_args.py @@ -109,6 +109,14 @@ def parse_args() -> argparse.Namespace: "--enable-prefix-cache", action="store_true", help="Enable prefix cache reuse" ) + # add --chunked-prefill-size + parser.add_argument( + "--chunked-prefill-size", + type=int, + default=None, + help="Chunk size for chunked prefill processing", + ) + # Scheduler configuration parser.add_argument( "--max-batch-size", @@ -347,6 +355,16 @@ def validate_args(args: argparse.Namespace) -> None: if args.kv_block_size <= 0: raise ValueError("kv_block_size must be positive") + # chunked-prefill depends on prefix-cache; it cannot be used on its own when prefix-cache is disabled + chunked_prefill_size = getattr(args, "chunked_prefill_size", None) + if chunked_prefill_size is not None and not args.enable_prefix_cache: + raise ValueError( + "chunked-prefill requires prefix-cache to be enabled. " + "Use --enable-prefix-cache when specifying --chunked-prefill-size." 
+ ) + if chunked_prefill_size is not None and chunked_prefill_size <= 0: + raise ValueError("chunked_prefill_size must be positive") + if args.micro_batch_ratio <= 0: raise ValueError("micro_batch_ratio must be positive") diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 56720040..785ceb04 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -5,11 +5,15 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch """ +from __future__ import annotations + from types import SimpleNamespace -from typing import List, Optional +from typing import TYPE_CHECKING, List import torch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.managers.schedule_policy import AddReqResult, PrefillAdder +from sglang.srt.mem_cache.radix_cache import RadixCache as PageRadixCache from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import ( @@ -18,13 +22,15 @@ from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from parallax.server.executor.sglang_executor import PageRadixCache from parallax.server.request import Request from parallax.server.sampling.sampling_params import ( SamplingParams as ParallaxSamplingParams, ) from parallax_utils.logging_config import get_logger +if TYPE_CHECKING: + from parallax.server.executor.sglang_executor import SGLExecutor + logger = get_logger(__name__) @@ -47,12 +53,64 @@ def transform_sampling_params_to_sglang(old_params: ParallaxSamplingParams) -> S return params +def _dummy_tree_cache_for_adder(): + """Dummy tree cache when prefix cache is disabled. 
PrefillAdder expects tree_cache + to have evictable_size(), inc_lock_ref(), dec_lock_ref() etc.; pass this instead of None.""" + return SimpleNamespace( + evictable_size=lambda: 0, + full_evictable_size=lambda: 0, + swa_evictable_size=lambda: 0, + inc_lock_ref=lambda node: None, + dec_lock_ref=lambda node, uuid=None: None, + ) + + def transform_requests_to_sglang( - old_requests: List[Request], page_tree_cache: Optional[PageRadixCache] = None + old_requests: List[Request], + executor: SGLExecutor, ) -> List[Req]: """Transforms Parallax Request to SGLang.Req format""" + model_runner = executor.model_runner + page_tree_cache = executor.page_tree_cache + chunked_prefill_size = executor.chunked_prefill_size + # When prefix cache is disabled, sglang's PrefillAdder still expects tree_cache to have + # evictable_size(), inc_lock_ref(), etc. Pass a dummy instead of None. + tree_cache_for_adder = ( + page_tree_cache if page_tree_cache is not None else _dummy_tree_cache_for_adder() + ) + # Prefill policy + adder = PrefillAdder( + model_runner.page_size, + tree_cache_for_adder, + model_runner.token_to_kv_pool_allocator, + None, + None, + executor.max_prefill_tokens, + chunked_prefill_size, + 0, + None, + None, + ) + # Save rid of chunked req added this round so we can reorder can_run_list to match old_requests. 
+ chunked_rid = executor.chunked_req.rid if executor.chunked_req is not None else None + + if executor.chunked_req is not None and executor.chunked_req.rid in [ + req.request_id for req in old_requests + ]: + logger.debug(f"before add_chunked_req, chunked_req is not None") + executor.chunked_req.init_next_round_input(page_tree_cache) + executor.chunked_req = adder.add_chunked_req(executor.chunked_req) + if executor.chunked_req is None: + logger.debug( + f"chunked_req{chunked_rid} has chunk all after add_chunked_req, chunked_req is None" + ) + reqs = [] + logger.debug(f"old_req size: {len(old_requests)}") for old_req in old_requests: + # Chunked req is added via add_chunked_req above; skip to avoid double-add. + if chunked_rid is not None and old_req.request_id == chunked_rid: + continue sampling_params = transform_sampling_params_to_sglang(old_req.sampling_params) req = Req( rid=old_req.request_id, @@ -72,6 +130,22 @@ def transform_requests_to_sglang( req.init_next_round_input(page_tree_cache) + res = adder.add_one_req( + req, + executor.chunked_req is not None, + None, + ) + + if res != AddReqResult.CONTINUE: + logger.warning( + f"Request {old_req.request_id} failed to add to prefill batch, result: {res},\ + req_len: {len(req.origin_input_ids)}" + ) + if res == AddReqResult.NO_TOKEN: + logger.warning(f"there is no token to add to prefill batch") + executor.running_batch.batch_is_full = True + break + # Debug: Log after cache lookup if page_tree_cache is not None: prefix_indices_len = len(req.prefix_indices) if hasattr(req, "prefix_indices") else 0 @@ -85,17 +159,43 @@ def transform_requests_to_sglang( ) reqs.append(req) - return reqs + + logger.debug(f"new reqs size: {len(reqs)}") + + if adder.new_chunked_req is not None: + # Update chunked prefill + assert executor.chunked_req is None + executor.chunked_req = adder.new_chunked_req + logger.debug(f"new chunked_req is {executor.chunked_req}") + + if executor.chunked_req is not None and executor.chunked_req.rid in [ 
+ req.request_id for req in old_requests + ]: + executor.chunked_req.is_chunked += 1 + + # Reorder so returned list follows old_requests order and each element is the Req + # that corresponds to that old_req (same rid). Use rid to map instead of assuming + # can_run_list order, so the relationship with reqs is explicit. + can_run_list = adder.can_run_list + if chunked_rid is None and ( + executor.chunked_req is None + or executor.chunked_req.rid not in [req.request_id for req in old_requests] + ): + return can_run_list + rid_to_req = {req.rid: req for req in can_run_list} + reordered: List[Req] = [rid_to_req[old_req.request_id] for old_req in old_requests] + return reordered def form_sgl_batch_prefill( requests: List[Request], - model_runner: ModelRunner, - page_tree_cache: Optional[PageRadixCache] = None, + executor: SGLExecutor, ) -> ForwardBatch: """Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow""" + model_runner = executor.model_runner + page_tree_cache = executor.page_tree_cache - sgl_reqs = transform_requests_to_sglang(requests, page_tree_cache) + sgl_reqs = transform_requests_to_sglang(requests, executor) def dummy_evict(*args): pass @@ -115,8 +215,20 @@ def dummy_evict(*args): model_config=model_runner.model_config, enable_overlap=False, spec_algorithm=SpeculativeAlgorithm.NONE, + chunked_req=( + executor.chunked_req + if executor.chunked_req is not None + and executor.chunked_req.rid in [req.request_id for req in requests] + else None + ), ) schedule_batch.prepare_for_extend() + if executor.chunked_req is not None and executor.chunked_req.rid in [ + req.request_id for req in requests + ]: + logger.debug( + f"chunked_req.rid={executor.chunked_req.rid}, chunked_req.req_pool_idx: {executor.chunked_req.req_pool_idx}" + ) num_tokens = schedule_batch.extend_num_tokens dp_size = model_runner.dp_size @@ -199,6 +311,14 @@ def find_index(running_batch: ScheduleBatch, request_id: str): return -1 +def find_index_safe(running_batch: 
ScheduleBatch, request_id: str) -> int: + """Find first index of request_id in running_batch; return -1 if not found (no log).""" + for index, req in enumerate(running_batch.reqs): + if req.rid == request_id: + return index + return -1 + + def form_sgl_batch_decode( requests: List[Request], model_runner: ModelRunner, @@ -249,42 +369,46 @@ def form_sgl_batch_decode( def release_sglang_request(running_batch: ScheduleBatch, request_id: str): - """Release KV Cache and other resources for finished/aborted requests.""" + """Release KV Cache and other resources for finished/aborted requests. + Removes all entries with the given request_id (handles duplicate rids).""" if running_batch is None or running_batch.is_empty(): return - seq_lens_cpu = running_batch.seq_lens.cpu().numpy() - idx = find_index(running_batch, request_id) - req = running_batch.reqs.pop(idx) - - # use running batch's tree cache to release kv cache - tree_cache = running_batch.tree_cache - - # for completed requests, is_insert=True to insert into prefix cache - # for aborted requests, is_insert=False to not insert into prefix cache - is_insert = True # can be adjusted based on request status - - if isinstance(tree_cache, PageRadixCache): - tree_cache.cache_finished_req(req) - else: - page_size = running_batch.token_to_kv_pool_allocator.page_size - last_uncached_pos = (len(req.prefix_indices) // page_size) * page_size - end_pos = last_uncached_pos + seq_lens_cpu[idx] - running_batch.seq_lens = torch.cat( - (running_batch.seq_lens[:idx], running_batch.seq_lens[idx + 1 :]) - ) - running_batch.seq_lens_cpu = torch.cat( - (running_batch.seq_lens_cpu[:idx], running_batch.seq_lens_cpu[idx + 1 :]) - ) - running_batch.orig_seq_lens = torch.cat( - (running_batch.orig_seq_lens[:idx], running_batch.orig_seq_lens[idx + 1 :]) - ) + while True: + idx = find_index_safe(running_batch, request_id) + if idx < 0: + break + seq_lens_cpu = running_batch.seq_lens.cpu().numpy() + req = running_batch.reqs.pop(idx) + + # use 
running batch's tree cache to release kv cache + tree_cache = running_batch.tree_cache + + # for completed requests, is_insert=True to insert into prefix cache + # for aborted requests, is_insert=False to not insert into prefix cache + is_insert = True # can be adjusted based on request status + + if isinstance(tree_cache, PageRadixCache): + tree_cache.cache_finished_req(req) + else: + page_size = running_batch.token_to_kv_pool_allocator.page_size + last_uncached_pos = (len(req.prefix_indices) // page_size) * page_size + end_pos = last_uncached_pos + seq_lens_cpu[idx] + running_batch.seq_lens = torch.cat( + (running_batch.seq_lens[:idx], running_batch.seq_lens[idx + 1 :]) + ) + running_batch.seq_lens_cpu = torch.cat( + (running_batch.seq_lens_cpu[:idx], running_batch.seq_lens_cpu[idx + 1 :]) + ) + running_batch.orig_seq_lens = torch.cat( + (running_batch.orig_seq_lens[:idx], running_batch.orig_seq_lens[idx + 1 :]) + ) - # Free kv cache - token_indices = running_batch.req_to_token_pool.req_to_token[req.req_pool_idx][ - last_uncached_pos:end_pos - ] - running_batch.token_to_kv_pool_allocator.free(token_indices) - running_batch.req_to_token_pool.free(req.req_pool_idx) - running_batch.req_pool_indices = torch.cat( - (running_batch.req_pool_indices[:idx], running_batch.req_pool_indices[idx + 1 :]) - ) + # Free kv cache + token_indices = running_batch.req_to_token_pool.req_to_token[req.req_pool_idx][ + last_uncached_pos:end_pos + ] + running_batch.token_to_kv_pool_allocator.free(token_indices) + running_batch.req_to_token_pool.free(req.req_pool_idx) + running_batch.req_pool_indices = torch.cat( + (running_batch.req_pool_indices[:idx], running_batch.req_pool_indices[idx + 1 :]) + ) diff --git a/src/parallax/utils/mac_prefill_addr.py b/src/parallax/utils/mac_prefill_addr.py new file mode 100644 index 00000000..199c63ec --- /dev/null +++ b/src/parallax/utils/mac_prefill_addr.py @@ -0,0 +1,82 @@ +from enum import Enum, auto +from typing import Optional + +from 
parallax.server.cache_manager import CacheManager +from parallax.server.request import Request + + +class AddReqResult(Enum): + CONTINUE = auto() # Continue to add requests + NO_TOKEN = auto() # No token left + OTHER = auto() # Other reasons to stop adding requests + + +class MACPrefillAdder: + """ + MACPrefillAdder is a class that adds prefill requests to the MAC prefill batch. + """ + + def __init__( + self, + page_size: int, + rem_chunk_tokens: Optional[int], + cache_manager: CacheManager, + ): + self.page_size = page_size + self.rem_chunk_tokens = rem_chunk_tokens + self.can_run_list = [] + self.new_chunked_req = None + self.cache_manager = cache_manager + + def add_chunked_req(self, chunked_req: Request) -> Optional[Request]: + if chunked_req is None: + return None + matched_tokens = 0 + if self.cache_manager.prefix_cache is not None: + _, matched_tokens = self.cache_manager.prefix_cache.match_prefix( + chunked_req.origin_input_ids + ) + extend_input_len = len(chunked_req.origin_input_ids) - matched_tokens + extend_input_len = 1 if extend_input_len <= 0 else extend_input_len + if self.rem_chunk_tokens is None: + truncated = False + chunked_req_offset = len(chunked_req.origin_input_ids) + else: + truncated = extend_input_len > self.rem_chunk_tokens + chunked_req_offset = min(self.rem_chunk_tokens, extend_input_len) + matched_tokens + chunked_req.input_ids = chunked_req.origin_input_ids[:chunked_req_offset] + chunked_req._effective_total_length = chunked_req_offset + self.can_run_list.append(chunked_req) + if self.rem_chunk_tokens is not None: + self.rem_chunk_tokens -= min(self.rem_chunk_tokens, extend_input_len) + return chunked_req if truncated else None + + def add_one_req(self, req: Request) -> AddReqResult: + matched_tokens = 0 + if self.cache_manager.prefix_cache is not None: + _, matched_tokens = self.cache_manager.prefix_cache.match_prefix(req.origin_input_ids) + extend_input_len = len(req.origin_input_ids) - matched_tokens + extend_input_len = 1 if 
extend_input_len <= 0 else extend_input_len + # align to page size + extend_input_len = ( + (extend_input_len + self.page_size - 1) // self.page_size * self.page_size + ) + if self.rem_chunk_tokens is None or extend_input_len <= self.rem_chunk_tokens: + self.can_run_list.append(req) + if self.rem_chunk_tokens is not None: + self.rem_chunk_tokens -= extend_input_len + else: + # make sure at least one page is available + trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size + if trunc_len <= 0: + return AddReqResult.OTHER + extend_input_len = trunc_len + chunked_req_offset = extend_input_len + matched_tokens + req.input_ids = req.origin_input_ids[:chunked_req_offset] + req._effective_total_length = chunked_req_offset + self.can_run_list.append(req) + self.new_chunked_req = req + self.rem_chunk_tokens -= extend_input_len + if self.rem_chunk_tokens is None or self.rem_chunk_tokens <= 0: + return AddReqResult.OTHER + return AddReqResult.CONTINUE diff --git a/tests/test_executor.py b/tests/test_executor.py index e20d4174..3586d8e2 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -65,7 +65,9 @@ def create_executor(start_layer, end_layer, device, kv_cache_memory_fraction=0.3 def run_executor_pipeline_stage(executor, requests, batch_type, is_last_peer): - """Run executor pipeline stage. Input and output should be requests""" + """Run executor pipeline stage. Input and output should be requests. + Returns (to_forward_reqs, batch_output); chunked_reqs are handled locally via handle_input_requests. 
+ """ executor.handle_input_requests(requests) executor.scheduler.admit_requests() input_batch = executor.scheduler.form_batch() @@ -73,12 +75,14 @@ def run_executor_pipeline_stage(executor, requests, batch_type, is_last_peer): assert prepared_batch is not None, "Failed to prepare batch inputs" batch_data = prepared_batch[batch_type] batch_output = executor.process_batch(batch_data, return_decoded_tokens=is_last_peer) - output_reqs = executor.prepare_next_batch_requests( + chunked_reqs, to_forward_reqs = executor.prepare_next_batch_requests( requests=batch_data["requests"], batch_output=batch_output, context_lengths=batch_data.get("context_lengths"), ) - return output_reqs, batch_output + if chunked_reqs: + executor.handle_input_requests(chunked_reqs) + return to_forward_reqs, batch_output @pytest.mark.parametrize( @@ -119,8 +123,8 @@ def test_decode_pipeline_multiple_steps(pipeline_devices, pp_end_layers, num_dec ref_cuda_model = AutoModelForCausalLM.from_pretrained( CUDA_MODEL_REPO, torch_dtype=torch.bfloat16, - device_map="cuda:0", ) + ref_cuda_model = ref_cuda_model.to("cuda:0") ref_cuda_tokenizer = AutoTokenizer.from_pretrained(CUDA_MODEL_REPO) if ref_cuda_tokenizer.pad_token is None: ref_cuda_tokenizer.pad_token = ref_cuda_tokenizer.eos_token diff --git a/tests/test_server_args.py b/tests/test_server_args.py index ac523b3c..ae6b6f20 100644 --- a/tests/test_server_args.py +++ b/tests/test_server_args.py @@ -143,6 +143,8 @@ def test_create_config(self): max_lora_chunk_size=128, enable_weight_refit=False, weight_refit_mode="cpu", + chunked_prefill_size=128, + max_prefill_tokens=1024, gpu_backend="sglang", )